From 0c629f7cee5758c381c2f7807dc1523067d91689 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 30 Jan 2026 18:53:27 -0800 Subject: [PATCH 001/141] added `BidiAction` and `BidiFlow` --- go/ai/embedder.go | 8 +- go/ai/evaluator.go | 10 +- go/ai/generate.go | 12 +- go/ai/prompt.go | 12 +- go/ai/resource.go | 14 +- go/ai/retriever.go | 8 +- go/core/action.go | 317 +++++++++++++++++++++++++++++++---- go/core/action_test.go | 16 +- go/core/api/action.go | 4 +- go/core/background_action.go | 28 ++-- go/core/flow.go | 63 +++---- go/core/flow_test.go | 304 ++++++++++++++++++++++++++++++++- go/genkit/genkit.go | 14 +- 13 files changed, 683 insertions(+), 127 deletions(-) diff --git a/go/ai/embedder.go b/go/ai/embedder.go index b84e3df81b..d93d5b52fa 100644 --- a/go/ai/embedder.go +++ b/go/ai/embedder.go @@ -85,7 +85,7 @@ type EmbedderOptions struct { // embedder is an action with functions specific to converting documents to multidimensional vectors such as Embed(). type embedder struct { - core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}] + core.Action[*EmbedRequest, *EmbedResponse, struct{}, struct{}] } // NewEmbedder creates a new [Embedder]. @@ -127,7 +127,7 @@ func NewEmbedder(name string, opts *EmbedderOptions, fn EmbedderFunc) Embedder { } return &embedder{ - ActionDef: *core.NewAction(name, api.ActionTypeEmbedder, metadata, inputSchema, fn), + Action: *core.NewAction(name, api.ActionTypeEmbedder, metadata, inputSchema, fn), } } @@ -143,12 +143,12 @@ func DefineEmbedder(r api.Registry, name string, opts *EmbedderOptions, fn Embed // It will try to resolve the embedder dynamically if the embedder is not found. // It returns nil if the embedder was not resolved. func LookupEmbedder(r api.Registry, name string) Embedder { - action := core.ResolveActionFor[*EmbedRequest, *EmbedResponse, struct{}](r, api.ActionTypeEmbedder, name) + action := core.ResolveActionFor[*EmbedRequest, *EmbedResponse, struct{}, struct{}](r, api.ActionTypeEmbedder, name) if action == nil { return nil } return &embedder{ - ActionDef: *action, + Action: *action, } } diff --git a/go/ai/evaluator.go b/go/ai/evaluator.go index aa536fac9b..dd79a511ba 100644 --- a/go/ai/evaluator.go +++ b/go/ai/evaluator.go @@ -72,7 +72,7 @@ func (e EvaluatorRef) Config() any { // evaluator is an action with functions specific to evaluating a dataset. type evaluator struct { - core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}] + core.Action[*EvaluatorRequest, *EvaluatorResponse, struct{}, struct{}] } // Example is a single example that requires evaluation @@ -190,7 +190,7 @@ func NewEvaluator(name string, opts *EvaluatorOptions, fn EvaluatorFunc) Evaluat } return &evaluator{ - ActionDef: *core.NewAction(name, api.ActionTypeEvaluator, metadata, inputSchema, func(ctx context.Context, req *EvaluatorRequest) (output *EvaluatorResponse, err error) { + Action: *core.NewAction(name, api.ActionTypeEvaluator, metadata, inputSchema, func(ctx context.Context, req *EvaluatorRequest) (output *EvaluatorResponse, err error) { var results []EvaluationResult for _, datapoint := range req.Dataset { if datapoint.TestCaseId == "" { @@ -275,7 +275,7 @@ func NewBatchEvaluator(name string, opts *EvaluatorOptions, fn BatchEvaluatorFun } return &evaluator{ - ActionDef: *core.NewAction(name, api.ActionTypeEvaluator, metadata, nil, fn), + Action: *core.NewAction(name, api.ActionTypeEvaluator, metadata, nil, fn), } } @@ -291,12 +291,12 @@ func DefineBatchEvaluator(r api.Registry, name string, opts *EvaluatorOptions, f // LookupEvaluator looks up an [Evaluator] registered by [DefineEvaluator]. // It returns nil if the evaluator was not defined. func LookupEvaluator(r api.Registry, name string) Evaluator { - action := core.ResolveActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, api.ActionTypeEvaluator, name) + action := core.ResolveActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}, struct{}](r, api.ActionTypeEvaluator, name) if action == nil { return nil } return &evaluator{ - ActionDef: *action, + Action: *action, } } diff --git a/go/ai/generate.go b/go/ai/generate.go index 003eb0b653..6aeb1e6642 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -71,11 +71,11 @@ type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResp // model is an action with functions specific to model generation such as Generate(). type model struct { - core.ActionDef[*ModelRequest, *ModelResponse, *ModelResponseChunk] + core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk, struct{}] } // generateAction is the type for a utility model generation action that takes in a GenerateActionOptions instead of a ModelRequest. -type generateAction = core.ActionDef[*GenerateActionOptions, *ModelResponse, *ModelResponseChunk] +type generateAction = core.Action[*GenerateActionOptions, *ModelResponse, *ModelResponseChunk, struct{}] // result is a generic struct for parallel operation results with index, value, and error. type result[T any] struct { @@ -191,12 +191,12 @@ func DefineModel(r api.Registry, name string, opts *ModelOptions, fn ModelFunc) // It will try to resolve the model dynamically if the model is not found. // It returns nil if the model was not resolved. func LookupModel(r api.Registry, name string) Model { - action := core.ResolveActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk](r, api.ActionTypeModel, name) + action := core.ResolveActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk, struct{}](r, api.ActionTypeModel, name) if action == nil { return nil } return &model{ - ActionDef: *action, + Action: *action, } } @@ -699,7 +699,7 @@ func (m *model) Generate(ctx context.Context, req *ModelRequest, cb ModelStreamC return nil, core.NewError(core.INVALID_ARGUMENT, "Model.Generate: generate called on a nil model; check that all models are defined") } - return m.ActionDef.Run(ctx, req, cb) + return m.Action.Run(ctx, req, cb) } // supportsConstrained returns whether the model supports constrained output. @@ -708,7 +708,7 @@ func (m *model) supportsConstrained(hasTools bool) bool { return false } - metadata := m.ActionDef.Desc().Metadata + metadata := m.Action.Desc().Metadata if metadata == nil { return false } diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 4d0151c4c8..b176805b36 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -52,7 +52,7 @@ type Prompt interface { // prompt is a prompt template that can be executed to generate a model response. type prompt struct { - core.ActionDef[any, *GenerateActionOptions, struct{}] + core.Action[any, *GenerateActionOptions, struct{}, struct{}] promptOptions registry api.Registry } @@ -124,7 +124,7 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { metadata["prompt"] = promptMetadata } - p.ActionDef = *core.DefineAction(r, name, api.ActionTypeExecutablePrompt, metadata, p.InputSchema, p.buildRequest) + p.Action = *core.DefineAction(r, name, api.ActionTypeExecutablePrompt, metadata, p.InputSchema, p.buildRequest) return p } @@ -132,13 +132,13 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { // LookupPrompt looks up a [Prompt] registered by [DefinePrompt]. // It returns nil if the prompt was not defined. func LookupPrompt(r api.Registry, name string) Prompt { - action := core.ResolveActionFor[any, *GenerateActionOptions, struct{}](r, api.ActionTypeExecutablePrompt, name) + action := core.ResolveActionFor[any, *GenerateActionOptions, struct{}, struct{}](r, api.ActionTypeExecutablePrompt, name) if action == nil { return nil } return &prompt{ - ActionDef: *action, - registry: r, + Action: *action, + registry: r, } } @@ -312,7 +312,7 @@ func (p *prompt) Render(ctx context.Context, input any) (*GenerateActionOptions, // Desc returns a descriptor of the prompt with resolved schema references. func (p *prompt) Desc() api.ActionDesc { - desc := p.ActionDef.Desc() + desc := p.Action.Desc() promptMeta := desc.Metadata["prompt"].(map[string]any) if inputMeta, ok := promptMeta["input"].(map[string]any); ok { if inputSchema, ok := inputMeta["schema"].(map[string]any); ok { diff --git a/go/ai/resource.go b/go/ai/resource.go index e84ca1193c..18e1823019 100644 --- a/go/ai/resource.go +++ b/go/ai/resource.go @@ -109,7 +109,7 @@ type ResourceFunc = func(context.Context, *ResourceInput) (*ResourceOutput, erro // It holds the underlying core action and allows looking up resources // by name without knowing their specific input/output api. type resource struct { - core.ActionDef[*ResourceInput, *ResourceOutput, struct{}] + core.Action[*ResourceInput, *ResourceOutput, struct{}, struct{}] } // Resource represents an instance of a resource. @@ -129,7 +129,7 @@ type Resource interface { // DefineResource creates a resource and registers it with the given Registry. func DefineResource(r api.Registry, name string, opts *ResourceOptions, fn ResourceFunc) Resource { metadata := resourceMetadata(name, opts) - return &resource{ActionDef: *core.DefineAction(r, name, api.ActionTypeResource, metadata, nil, fn)} + return &resource{Action: *core.DefineAction(r, name, api.ActionTypeResource, metadata, nil, fn)} } // NewResource creates a resource but does not register it in the registry. @@ -137,7 +137,7 @@ func DefineResource(r api.Registry, name string, opts *ResourceOptions, fn Resou func NewResource(name string, opts *ResourceOptions, fn ResourceFunc) Resource { metadata := resourceMetadata(name, opts) metadata["dynamic"] = true - return &resource{ActionDef: *core.NewAction(name, api.ActionTypeResource, metadata, nil, fn)} + return &resource{Action: *core.NewAction(name, api.ActionTypeResource, metadata, nil, fn)} } // resourceMetadata creates the metadata common to both DefineResource and NewResource. @@ -227,8 +227,8 @@ func (r *resource) Execute(ctx context.Context, input *ResourceInput) (*Resource // FindMatchingResource finds a resource that matches the given URI. func FindMatchingResource(r api.Registry, uri string) (Resource, *ResourceInput, error) { for _, a := range r.ListActions() { - if action, ok := a.(*core.ActionDef[*ResourceInput, *ResourceOutput, struct{}]); ok { - res := &resource{ActionDef: *action} + if action, ok := a.(*core.Action[*ResourceInput, *ResourceOutput, struct{}, struct{}]); ok { + res := &resource{Action: *action} if res.Matches(uri) { variables, err := res.ExtractVariables(uri) if err != nil { @@ -244,9 +244,9 @@ func FindMatchingResource(r api.Registry, uri string) (Resource, *ResourceInput, // LookupResource looks up the resource in the registry by provided name and returns it. func LookupResource(r api.Registry, name string) Resource { - action := core.ResolveActionFor[*ResourceInput, *ResourceOutput, struct{}](r, api.ActionTypeResource, name) + action := core.ResolveActionFor[*ResourceInput, *ResourceOutput, struct{}, struct{}](r, api.ActionTypeResource, name) if action == nil { return nil } - return &resource{ActionDef: *action} + return &resource{Action: *action} } diff --git a/go/ai/retriever.go b/go/ai/retriever.go index 64f048d6d8..9ab97f17ce 100644 --- a/go/ai/retriever.go +++ b/go/ai/retriever.go @@ -40,7 +40,7 @@ type Retriever interface { // retriever is an action with functions specific to document retrieval such as Retrieve(). type retriever struct { - core.ActionDef[*RetrieverRequest, *RetrieverResponse, struct{}] + core.Action[*RetrieverRequest, *RetrieverResponse, struct{}, struct{}] } // RetrieverArg is the interface for retriever arguments. It can either be the retriever action itself or a reference to be looked up. @@ -121,7 +121,7 @@ func NewRetriever(name string, opts *RetrieverOptions, fn RetrieverFunc) Retriev } return &retriever{ - ActionDef: *core.NewAction(name, api.ActionTypeRetriever, metadata, inputSchema, fn), + Action: *core.NewAction(name, api.ActionTypeRetriever, metadata, inputSchema, fn), } } @@ -136,12 +136,12 @@ func DefineRetriever(r api.Registry, name string, opts *RetrieverOptions, fn Ret // It will try to resolve the retriever dynamically if the retriever is not found. // It returns nil if the retriever was not resolved. func LookupRetriever(r api.Registry, name string) Retriever { - action := core.ResolveActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](r, api.ActionTypeRetriever, name) + action := core.ResolveActionFor[*RetrieverRequest, *RetrieverResponse, struct{}, struct{}](r, api.ActionTypeRetriever, name) if action == nil { return nil } return &retriever{ - ActionDef: *action, + Action: *action, } } diff --git a/go/core/action.go b/go/core/action.go index 50c1aa63a5..9d5e4c31a8 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -19,7 +19,9 @@ package core import ( "context" "encoding/json" + "iter" "reflect" + "sync" "time" "github.com/firebase/genkit/go/core/api" @@ -38,19 +40,25 @@ type StreamingFunc[In, Out, Stream any] = func(context.Context, In, StreamCallba // StreamCallback is a function that is called during streaming to return the next chunk of the stream. type StreamCallback[Stream any] = func(context.Context, Stream) error -// An ActionDef is a named, observable operation that underlies all Genkit primitives. +// BidiFunc is the function signature for bidirectional streaming actions. +// It receives initialization data, reads inputs from inCh, and writes +// streamed outputs to outCh. It returns a final output when complete. +type BidiFunc[In, Out, Stream, Init any] = func(ctx context.Context, init Init, inCh <-chan In, outCh chan<- Stream) (Out, error) + +// An Action is a named, observable operation that underlies all Genkit primitives. // It consists of a function that takes an input of type I and returns an output // of type O, optionally streaming values of type S incrementally by invoking a callback. // It optionally has other metadata, like a description and JSON Schemas for its input and // output which it validates against. // -// Each time an ActionDef is run, it results in a new trace span. +// Each time an Action is run, it results in a new trace span. // // For internal use only. -type ActionDef[In, Out, Stream any] struct { - fn StreamingFunc[In, Out, Stream] // Function that is called during runtime. May not actually support streaming. - desc *api.ActionDesc // Descriptor of the action. - registry api.Registry // Registry for schema resolution. Set when registered. +type Action[In, Out, Stream, Init any] struct { + fn StreamingFunc[In, Out, Stream] // Function that is called during runtime. May not actually support streaming. + bidiFn BidiFunc[In, Out, Stream, Init] // Non-nil for bidi actions only. + desc *api.ActionDesc // Descriptor of the action. + registry api.Registry // Registry for schema resolution. Set when registered. } type noStream = func(context.Context, struct{}) error @@ -63,8 +71,8 @@ func NewAction[In, Out any]( metadata map[string]any, inputSchema map[string]any, fn Func[In, Out], -) *ActionDef[In, Out, struct{}] { - return newAction(name, atype, metadata, inputSchema, +) *Action[In, Out, struct{}, struct{}] { + return newAction[In, Out, struct{}, struct{}](name, atype, metadata, inputSchema, func(ctx context.Context, in In, cb noStream) (Out, error) { return fn(ctx, in) }) @@ -78,8 +86,69 @@ func NewStreamingAction[In, Out, Stream any]( metadata map[string]any, inputSchema map[string]any, fn StreamingFunc[In, Out, Stream], -) *ActionDef[In, Out, Stream] { - return newAction(name, atype, metadata, inputSchema, fn) +) *Action[In, Out, Stream, struct{}] { + return newAction[In, Out, Stream, struct{}](name, atype, metadata, inputSchema, fn) +} + +// ActionOptions configures a bidi action. Nil schema fields are inferred from type parameters. +type ActionOptions struct { + Metadata map[string]any // Arbitrary key-value data attached to the action descriptor. + InputSchema map[string]any // JSON schema for the action's input. Inferred from In if nil. + OutputSchema map[string]any // JSON schema for the action's output. Inferred from Out if nil. + StreamSchema map[string]any // JSON schema for streamed chunks. Inferred from Stream if nil. Not used for non-streaming actions. + InitSchema map[string]any // JSON schema for bidi initialization data. Inferred from Init if nil. Not used for non-bidi actions. +} + +// NewBidiAction creates a new bidirectional streaming [Action] without registering it. +func NewBidiAction[In, Out, Stream, Init any]( + name string, + atype api.ActionType, + opts *ActionOptions, + fn BidiFunc[In, Out, Stream, Init], +) *Action[In, Out, Stream, Init] { + if opts == nil { + opts = &ActionOptions{} + } + + metadata := opts.Metadata + if metadata == nil { + metadata = map[string]any{} + } + metadata["bidi"] = true + + a := newAction[In, Out, Stream, Init](name, atype, metadata, opts.InputSchema, wrapBidiAsStreaming(fn)) + a.bidiFn = fn + + if opts.OutputSchema != nil { + a.desc.OutputSchema = opts.OutputSchema + } + if opts.StreamSchema != nil { + a.desc.StreamSchema = opts.StreamSchema + } + + if opts.InitSchema != nil { + a.desc.InitSchema = opts.InitSchema + } else { + var init Init + if reflect.ValueOf(init).Kind() != reflect.Invalid { + a.desc.InitSchema = InferSchemaMap(init) + } + } + + return a +} + +// DefineBidiAction creates and registers a bidirectional streaming [Action]. +func DefineBidiAction[In, Out, Stream, Init any]( + r api.Registry, + name string, + atype api.ActionType, + opts *ActionOptions, + fn BidiFunc[In, Out, Stream, Init], +) *Action[In, Out, Stream, Init] { + a := NewBidiAction(name, atype, opts, fn) + a.Register(r) + return a } // DefineAction creates a new non-streaming Action and registers it. @@ -91,8 +160,8 @@ func DefineAction[In, Out any]( metadata map[string]any, inputSchema map[string]any, fn Func[In, Out], -) *ActionDef[In, Out, struct{}] { - return defineAction(r, name, atype, metadata, inputSchema, +) *Action[In, Out, struct{}, struct{}] { + return defineAction[In, Out, struct{}, struct{}](r, name, atype, metadata, inputSchema, func(ctx context.Context, in In, cb noStream) (Out, error) { return fn(ctx, in) }) @@ -107,20 +176,20 @@ func DefineStreamingAction[In, Out, Stream any]( metadata map[string]any, inputSchema map[string]any, fn StreamingFunc[In, Out, Stream], -) *ActionDef[In, Out, Stream] { - return defineAction(r, name, atype, metadata, inputSchema, fn) +) *Action[In, Out, Stream, struct{}] { + return defineAction[In, Out, Stream, struct{}](r, name, atype, metadata, inputSchema, fn) } // defineAction creates an action and registers it with the given Registry. -func defineAction[In, Out, Stream any]( +func defineAction[In, Out, Stream, Init any]( r api.Registry, name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, fn StreamingFunc[In, Out, Stream], -) *ActionDef[In, Out, Stream] { - a := newAction(name, atype, metadata, inputSchema, fn) +) *Action[In, Out, Stream, Init] { + a := newAction[In, Out, Stream, Init](name, atype, metadata, inputSchema, fn) a.Register(r) return a } @@ -128,13 +197,13 @@ func defineAction[In, Out, Stream any]( // newAction creates a new Action with the given name and arguments. // If registry is nil, tracing state is left nil to be set later. // If inputSchema is nil, it is inferred from In. -func newAction[In, Out, Stream any]( +func newAction[In, Out, Stream, Init any]( name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, fn StreamingFunc[In, Out, Stream], -) *ActionDef[In, Out, Stream] { +) *Action[In, Out, Stream, Init] { if inputSchema == nil { var i In if reflect.ValueOf(i).Kind() != reflect.Invalid { @@ -148,12 +217,18 @@ func newAction[In, Out, Stream any]( outputSchema = InferSchemaMap(o) } + var s Stream + var streamSchema map[string]any + if reflect.ValueOf(s).Kind() != reflect.Invalid { + streamSchema = InferSchemaMap(s) + } + var description string if desc, ok := metadata["description"].(string); ok { description = desc } - return &ActionDef[In, Out, Stream]{ + return &Action[In, Out, Stream, Init]{ fn: func(ctx context.Context, input In, cb StreamCallback[Stream]) (Out, error) { return fn(ctx, input, cb) }, @@ -164,16 +239,17 @@ func newAction[In, Out, Stream any]( Description: description, InputSchema: inputSchema, OutputSchema: outputSchema, + StreamSchema: streamSchema, Metadata: metadata, }, } } // Name returns the Action's Name. -func (a *ActionDef[In, Out, Stream]) Name() string { return a.desc.Name } +func (a *Action[In, Out, Stream, Init]) Name() string { return a.desc.Name } // Run executes the Action's function in a new trace span. -func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb StreamCallback[Stream]) (output Out, err error) { +func (a *Action[In, Out, Stream, Init]) Run(ctx context.Context, input In, cb StreamCallback[Stream]) (output Out, err error) { r, err := a.runWithTelemetry(ctx, input, cb) if err != nil { return base.Zero[Out](), err @@ -182,7 +258,7 @@ func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb Strea } // Run executes the Action's function in a new trace span. -func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream]) (output api.ActionRunResult[Out], err error) { +func (a *Action[In, Out, Stream, Init]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream]) (output api.ActionRunResult[Out], err error) { inputBytes, _ := json.Marshal(input) logger.FromContext(ctx).Debug("Action.Run", "name", a.Name(), @@ -263,7 +339,7 @@ func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input } // RunJSON runs the action with a JSON input, and returns a JSON result. -func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { +func (a *Action[In, Out, Stream, Init]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { r, err := a.RunJSONWithTelemetry(ctx, input, cb) if err != nil { return nil, err @@ -272,7 +348,7 @@ func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.Raw } // RunJSONWithTelemetry runs the action with a JSON input, and returns a JSON result along with telemetry info. -func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { +func (a *Action[In, Out, Stream, Init]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { i, err := base.UnmarshalAndNormalize[In](input, a.desc.InputSchema) if err != nil { return nil, NewError(INVALID_ARGUMENT, err.Error()) @@ -310,27 +386,79 @@ func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, i } // Desc returns a descriptor of the action with resolved schema references. -func (a *ActionDef[In, Out, Stream]) Desc() api.ActionDesc { +func (a *Action[In, Out, Stream, Init]) Desc() api.ActionDesc { return *a.desc } // Register registers the action with the given registry. -func (a *ActionDef[In, Out, Stream]) Register(r api.Registry) { +func (a *Action[In, Out, Stream, Init]) Register(r api.Registry) { a.registry = r r.RegisterAction(a.desc.Key, a) } +// StreamBidi starts a bidirectional streaming connection. +// Returns an error if the action is not a bidi action. +// A trace span is created that remains open for the lifetime of the connection. +func (a *Action[In, Out, Stream, Init]) StreamBidi(ctx context.Context, init Init) (*BidiConnection[In, Out, Stream], error) { + if a.bidiFn == nil { + return nil, NewError(FAILED_PRECONDITION, "StreamBidi called on non-bidi action %q", a.desc.Name) + } + + ctx, cancel := context.WithCancel(ctx) + conn := &BidiConnection[In, Out, Stream]{ + inputCh: make(chan In, 1), + streamCh: make(chan Stream, 1), + doneCh: make(chan struct{}), + ctx: ctx, + cancel: cancel, + } + + spanMetadata := &tracing.SpanMetadata{ + Name: a.desc.Name, + Type: "action", + Subtype: string(a.desc.Type), + Metadata: make(map[string]string), + } + if flowName := FlowNameFromContext(ctx); flowName != "" { + spanMetadata.Metadata["flow:name"] = flowName + } + + go func() { + defer close(conn.doneCh) + defer close(conn.streamCh) + output, err := tracing.RunInNewSpan(conn.ctx, spanMetadata, init, + func(ctx context.Context, init Init) (Out, error) { + start := time.Now() + output, err := a.bidiFn(ctx, init, conn.inputCh, conn.streamCh) + latency := time.Since(start) + if err != nil { + metrics.WriteActionFailure(ctx, a.desc.Name, latency, err) + } else { + metrics.WriteActionSuccess(ctx, a.desc.Name, latency) + } + return output, err + }, + ) + conn.mu.Lock() + conn.output = output + conn.err = err + conn.mu.Unlock() + }() + + return conn, nil +} + // ResolveActionFor returns the action for the given key in the global registry, // or nil if there is none. // It panics if the action is of the wrong api. -func ResolveActionFor[In, Out, Stream any](r api.Registry, atype api.ActionType, name string) *ActionDef[In, Out, Stream] { +func ResolveActionFor[In, Out, Stream, Init any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, Stream, Init] { provider, id := api.ParseName(name) key := api.NewKey(atype, provider, id) a := r.ResolveAction(key) if a == nil { return nil } - return a.(*ActionDef[In, Out, Stream]) + return a.(*Action[In, Out, Stream, Init]) } // LookupActionFor returns the action for the given key in the global registry, @@ -338,12 +466,139 @@ func ResolveActionFor[In, Out, Stream any](r api.Registry, atype api.ActionType, // It panics if the action is of the wrong api. // // Deprecated: Use ResolveActionFor. -func LookupActionFor[In, Out, Stream any](r api.Registry, atype api.ActionType, name string) *ActionDef[In, Out, Stream] { +func LookupActionFor[In, Out, Stream, Init any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, Stream, Init] { provider, id := api.ParseName(name) key := api.NewKey(atype, provider, id) a := r.LookupAction(key) if a == nil { return nil } - return a.(*ActionDef[In, Out, Stream]) + return a.(*Action[In, Out, Stream, Init]) +} + +// wrapBidiAsStreaming wraps a BidiFunc into a StreamingFunc for use with Run/RunJSON. +// The input is sent as a single message, and stream chunks are forwarded to the callback. +func wrapBidiAsStreaming[In, Out, Stream, Init any](fn BidiFunc[In, Out, Stream, Init]) StreamingFunc[In, Out, Stream] { + return func(ctx context.Context, input In, cb StreamCallback[Stream]) (Out, error) { + inCh := make(chan In, 1) + outCh := make(chan Stream, 1) + doneCh := make(chan struct{}) + + var output Out + var fnErr error + + go func() { + defer close(doneCh) + defer close(outCh) + var init Init + output, fnErr = fn(ctx, init, inCh, outCh) + }() + + // Send the single input and close. + inCh <- input + close(inCh) + + // Forward streamed chunks to the callback. + if cb != nil { + for chunk := range outCh { + if err := cb(ctx, chunk); err != nil { + return base.Zero[Out](), err + } + } + } else { + // Drain the channel even without a callback. + for range outCh { + } + } + + <-doneCh + return output, fnErr + } +} + +// BidiConnection represents an active bidirectional streaming session. +type BidiConnection[In, Out, Stream any] struct { + inputCh chan In + streamCh chan Stream + doneCh chan struct{} + output Out + err error + ctx context.Context + cancel context.CancelFunc + mu sync.Mutex + closed bool +} + +// Send sends an input message to the bidi action. +// Returns an error if the connection is closed or the context is cancelled. +func (c *BidiConnection[In, Out, Stream]) Send(input In) error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return NewError(FAILED_PRECONDITION, "connection is closed") + } + c.mu.Unlock() + + select { + case c.inputCh <- input: + return nil + case <-c.ctx.Done(): + return c.ctx.Err() + case <-c.doneCh: + return NewError(FAILED_PRECONDITION, "action has completed") + } +} + +// Close signals that no more inputs will be sent. +func (c *BidiConnection[In, Out, Stream]) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil + } + c.closed = true + close(c.inputCh) + return nil +} + +// Receive returns an iterator for receiving streamed response chunks. +// The iterator completes when the action finishes. +func (c *BidiConnection[In, Out, Stream]) Receive() iter.Seq2[Stream, error] { + return func(yield func(Stream, error) bool) { + for { + select { + case chunk, ok := <-c.streamCh: + if !ok { + return + } + if !yield(chunk, nil) { + c.cancel() + return + } + case <-c.ctx.Done(): + var zero Stream + yield(zero, c.ctx.Err()) + return + } + } + } +} + +// Output returns the final output after the action completes. +// Blocks until done or context cancelled. +func (c *BidiConnection[In, Out, Stream]) Output() (Out, error) { + select { + case <-c.doneCh: + c.mu.Lock() + defer c.mu.Unlock() + return c.output, c.err + case <-c.ctx.Done(): + var zero Out + return zero, c.ctx.Err() + } +} + +// Done returns a channel that is closed when the connection completes. +func (c *BidiConnection[In, Out, Stream]) Done() <-chan struct{} { + return c.doneCh } diff --git a/go/core/action_test.go b/go/core/action_test.go index 65309d850d..93cd1471eb 100644 --- a/go/core/action_test.go +++ b/go/core/action_test.go @@ -34,7 +34,7 @@ func inc(_ context.Context, x int, _ noStream) (int, error) { func TestActionRun(t *testing.T) { r := registry.New() - a := defineAction(r, "test/inc", api.ActionTypeCustom, nil, nil, inc) + a := DefineStreamingAction(r, "test/inc", api.ActionTypeCustom, nil, nil, inc) got, err := a.Run(context.Background(), 3, nil) if err != nil { t.Fatal(err) @@ -46,7 +46,7 @@ func TestActionRun(t *testing.T) { func TestActionRunJSON(t *testing.T) { r := registry.New() - a := defineAction(r, "test/inc", api.ActionTypeCustom, nil, nil, inc) + a := DefineStreamingAction(r, "test/inc", api.ActionTypeCustom, nil, nil, inc) input := []byte("3") want := []byte("4") got, err := a.RunJSON(context.Background(), input, nil) @@ -73,7 +73,7 @@ func count(ctx context.Context, n int, cb func(context.Context, int) error) (int func TestActionStreaming(t *testing.T) { ctx := context.Background() r := registry.New() - a := defineAction(r, "test/count", api.ActionTypeCustom, nil, nil, count) + a := DefineStreamingAction(r, "test/count", api.ActionTypeCustom, nil, nil, count) const n = 3 // Non-streaming. @@ -108,7 +108,7 @@ func TestActionTracing(t *testing.T) { tc := tracing.NewTestOnlyTelemetryClient() tracing.WriteTelemetryImmediate(tc) name := api.NewName("test", "TestTracing-inc") - a := defineAction(r, name, api.ActionTypeCustom, nil, nil, inc) + a := DefineStreamingAction(r, name, api.ActionTypeCustom, nil, nil, inc) if _, err := a.Run(context.Background(), 3, nil); err != nil { t.Fatal(err) } @@ -309,7 +309,7 @@ func TestResolveActionFor(t *testing.T) { } DefineAction(r, "test/resolvable", api.ActionTypeCustom, nil, nil, fn) - found := ResolveActionFor[int, int, struct{}](r, api.ActionTypeCustom, "test/resolvable") + found := ResolveActionFor[int, int, struct{}, struct{}](r, api.ActionTypeCustom, "test/resolvable") if found == nil { t.Fatal("ResolveActionFor returned nil") @@ -322,7 +322,7 @@ func TestResolveActionFor(t *testing.T) { t.Run("returns nil for non-existent action", func(t *testing.T) { r := registry.New() - found := ResolveActionFor[int, int, struct{}](r, api.ActionTypeCustom, "test/nonexistent") + found := ResolveActionFor[int, int, struct{}, struct{}](r, api.ActionTypeCustom, "test/nonexistent") if found != nil { t.Errorf("ResolveActionFor returned %v, want nil", found) @@ -338,7 +338,7 @@ func TestLookupActionFor(t *testing.T) { } DefineAction(r, "test/lookupable", api.ActionTypeCustom, nil, nil, fn) - found := LookupActionFor[string, string, struct{}](r, api.ActionTypeCustom, "test/lookupable") + found := LookupActionFor[string, string, struct{}, struct{}](r, api.ActionTypeCustom, "test/lookupable") if found == nil { t.Fatal("LookupActionFor returned nil") @@ -348,7 +348,7 @@ func TestLookupActionFor(t *testing.T) { t.Run("returns nil for non-existent action", func(t *testing.T) { r := registry.New() - found := LookupActionFor[string, string, struct{}](r, api.ActionTypeCustom, "test/missing") + found := LookupActionFor[string, string, struct{}, struct{}](r, api.ActionTypeCustom, "test/missing") if found != nil { t.Errorf("LookupActionFor returned %v, want nil", found) diff --git a/go/core/api/action.go b/go/core/api/action.go index 18ffcfa67b..48e8045596 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -74,5 +74,7 @@ type ActionDesc struct { Description string `json:"description"` // Description of the action. InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. - Metadata map[string]any `json:"metadata"` // Metadata for the action. + Metadata map[string]any `json:"metadata"` // Metadata for the action. + StreamSchema map[string]any `json:"streamSchema,omitempty"` // JSON schema for streamed chunks. + InitSchema map[string]any `json:"initSchema,omitempty"` // JSON schema for initialization data. } diff --git a/go/core/background_action.go b/go/core/background_action.go index e6af50399b..c41aaed05a 100644 --- a/go/core/background_action.go +++ b/go/core/background_action.go @@ -45,10 +45,10 @@ type Operation[Out any] struct { // // For internal use only. type BackgroundActionDef[In, Out any] struct { - *ActionDef[In, *Operation[Out], struct{}] + *Action[In, *Operation[Out], struct{}, struct{}] - check *ActionDef[*Operation[Out], *Operation[Out], struct{}] // Sub-action that checks the status of a background operation. - cancel *ActionDef[*Operation[Out], *Operation[Out], struct{}] // Sub-action that cancels a background operation. + check *Action[*Operation[Out], *Operation[Out], struct{}, struct{}] // Sub-action that checks the status of a background operation. + cancel *Action[*Operation[Out], *Operation[Out], struct{}, struct{}] // Sub-action that cancels a background operation. } // Start starts a background operation. @@ -77,7 +77,7 @@ func (b *BackgroundActionDef[In, Out]) SupportsCancel() bool { // Register registers the model with the given registry. func (b *BackgroundActionDef[In, Out]) Register(r api.Registry) { - b.ActionDef.Register(r) + b.Action.Register(r) b.check.Register(r) if b.cancel != nil { b.cancel.Register(r) @@ -140,7 +140,7 @@ func NewBackgroundAction[In, Out any]( return updatedOp, nil }) - var cancelAction *ActionDef[*Operation[Out], *Operation[Out], struct{}] + var cancelAction *Action[*Operation[Out], *Operation[Out], struct{}, struct{}] if cancelFn != nil { cancelAction = NewAction(name, api.ActionTypeCancelOperation, metadata, nil, func(ctx context.Context, op *Operation[Out]) (*Operation[Out], error) { @@ -154,9 +154,9 @@ func NewBackgroundAction[In, Out any]( } return &BackgroundActionDef[In, Out]{ - ActionDef: startAction, - check: checkAction, - cancel: cancelAction, + Action: startAction, + check: checkAction, + cancel: cancelAction, } } @@ -165,22 +165,22 @@ func LookupBackgroundAction[In, Out any](r api.Registry, key string) *Background atype, provider, id := api.ParseKey(key) name := api.NewName(provider, id) - startAction := ResolveActionFor[In, *Operation[Out], struct{}](r, atype, name) + startAction := ResolveActionFor[In, *Operation[Out], struct{}, struct{}](r, atype, name) if startAction == nil { return nil } - checkAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}](r, api.ActionTypeCheckOperation, name) + checkAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}, struct{}](r, api.ActionTypeCheckOperation, name) if checkAction == nil { return nil } - cancelAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}](r, api.ActionTypeCancelOperation, name) + cancelAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}, struct{}](r, api.ActionTypeCancelOperation, name) return &BackgroundActionDef[In, Out]{ - ActionDef: startAction, - check: checkAction, - cancel: cancelAction, + Action: startAction, + check: checkAction, + cancel: cancelAction, } } diff --git a/go/core/flow.go b/go/core/flow.go index ea514365c2..5ec2503478 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -18,7 +18,6 @@ package core import ( "context" - "encoding/json" "errors" "fmt" @@ -27,8 +26,11 @@ import ( "github.com/firebase/genkit/go/internal/base" ) -// A Flow is a user-defined Action. A Flow[In, Out, Stream] represents a function from In to Out. The Stream parameter is for flows that support streaming: providing their results incrementally. -type Flow[In, Out, Stream any] ActionDef[In, Out, Stream] +// A Flow is a user-defined Action. A Flow[In, Out, Stream, Init] represents a function from In to Out. +// The Stream parameter is for flows that support streaming: providing their results incrementally. The Init parameter is for bidi flows. +type Flow[In, Out, Stream, Init any] struct { + *Action[In, Out, Stream, Init] +} // StreamingFlowValue is either a streamed value or a final output of a flow. type StreamingFlowValue[Out, Stream any] struct { @@ -46,14 +48,14 @@ type flowContext struct { } // DefineFlow creates a Flow that runs fn, and registers it as an action. fn takes an input of type In and returns an output of type Out. -func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flow[In, Out, struct{}] { - return (*Flow[In, Out, struct{}])(DefineAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In) (Out, error) { +func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flow[In, Out, struct{}, struct{}] { + return &Flow[In, Out, struct{}, struct{}]{DefineAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In) (Out, error) { fc := &flowContext{ flowName: name, } ctx = flowContextKey.NewContext(ctx, fc) return fn(ctx, input) - })) + })} } // DefineStreamingFlow creates a streaming Flow that runs fn, and registers it as an action. @@ -65,8 +67,8 @@ func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flo // stream the results by invoking the callback periodically, ultimately returning // with a final return value that includes all the streamed data. // Otherwise, it should ignore the callback and just return a result. -func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream] { - return (*Flow[In, Out, Stream])(DefineStreamingAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) { +func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream, struct{}] { + return &Flow[In, Out, Stream, struct{}]{DefineStreamingAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) { fc := &flowContext{ flowName: name, } @@ -75,7 +77,17 @@ func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn St cb = func(context.Context, Stream) error { return nil } } return fn(ctx, input, cb) - })) + })} +} + +// DefineBidiFlow creates a bidirectional streaming Flow that runs fn, and registers it as an action. +// Flow context is injected so that [Run] works inside the bidi function. +func DefineBidiFlow[In, Out, Stream, Init any](r api.Registry, name string, fn BidiFunc[In, Out, Stream, Init]) *Flow[In, Out, Stream, Init] { + wrapped := func(ctx context.Context, init Init, inCh <-chan In, outCh chan<- Stream) (Out, error) { + ctx = flowContextKey.NewContext(ctx, &flowContext{flowName: name}) + return fn(ctx, init, inCh, outCh) + } + return &Flow[In, Out, Stream, Init]{DefineBidiAction(r, name, api.ActionTypeFlow, nil, wrapped)} } // Run runs the function f in the context of the current flow @@ -105,29 +117,9 @@ func Run[Out any](ctx context.Context, name string, fn func() (Out, error)) (Out }) } -// Name returns the name of the flow. -func (f *Flow[In, Out, Stream]) Name() string { - return (*ActionDef[In, Out, Stream])(f).Name() -} - -// RunJSON runs the flow with JSON input and streaming callback and returns the output as JSON. -func (f *Flow[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { - return (*ActionDef[In, Out, Stream])(f).RunJSON(ctx, input, cb) -} - -// RunJSONWithTelemetry runs the flow with JSON input and streaming callback and returns the output as JSON along with telemetry info. -func (f *Flow[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { - return (*ActionDef[In, Out, Stream])(f).RunJSONWithTelemetry(ctx, input, cb) -} - -// Desc returns the descriptor of the flow. -func (f *Flow[In, Out, Stream]) Desc() api.ActionDesc { - return (*ActionDef[In, Out, Stream])(f).Desc() -} - // Run runs the flow in the context of another flow. -func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In) (Out, error) { - return (*ActionDef[In, Out, Stream])(f).Run(ctx, input, nil) +func (f *Flow[In, Out, Stream, Init]) Run(ctx context.Context, input In) (Out, error) { + return f.Action.Run(ctx, input, nil) } // Stream runs the flow in the context of another flow and streams the output. @@ -142,7 +134,7 @@ func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In) (Out, error) // again. // // Otherwise the Stream field of the passed [StreamingFlowValue] holds a streamed result. -func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func(*StreamingFlowValue[Out, Stream], error) bool) { +func (f *Flow[In, Out, Stream, Init]) Stream(ctx context.Context, input In) func(func(*StreamingFlowValue[Out, Stream], error) bool) { return func(yield func(*StreamingFlowValue[Out, Stream], error) bool) { cb := func(ctx context.Context, s Stream) error { if ctx.Err() != nil { @@ -153,7 +145,7 @@ func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func( } return nil } - output, err := (*ActionDef[In, Out, Stream])(f).Run(ctx, input, cb) + output, err := f.Action.Run(ctx, input, cb) if err != nil { yield(nil, err) } else { @@ -162,11 +154,6 @@ func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func( } } -// Register registers the flow with the given registry. -func (f *Flow[In, Out, Stream]) Register(r api.Registry) { - (*ActionDef[In, Out, Stream])(f).Register(r) -} - var errStop = errors.New("stop") // FlowNameFromContext returns the flow name from context if we're in a flow, empty string otherwise. diff --git a/go/core/flow_test.go b/go/core/flow_test.go index e3c3e6b463..7da8d31778 100644 --- a/go/core/flow_test.go +++ b/go/core/flow_test.go @@ -18,9 +18,12 @@ package core import ( "context" + "fmt" "slices" + "strings" "testing" + "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/internal/registry" ) @@ -69,7 +72,7 @@ func TestRunFlow(t *testing.T) { func TestFlowNameFromContext(t *testing.T) { r := registry.New() - flows := []*Flow[struct{}, string, struct{}]{ + flows := []*Flow[struct{}, string, struct{}, struct{}]{ DefineFlow(r, "DefineFlow", func(ctx context.Context, _ struct{}) (string, error) { return FlowNameFromContext(ctx), nil }), @@ -257,3 +260,302 @@ func TestFlowNameFromContextOutsideFlow(t *testing.T) { } }) } + +func TestBidiActionEcho(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "echo", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + var count int + for input := range inCh { + count++ + outCh <- fmt.Sprintf("echo: %s", input) + } + return fmt.Sprintf("processed %d messages", count), nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + // With unbuffered channels, we must send and receive concurrently. + go func() { + conn.Send("hello") + conn.Send("world") + conn.Close() + }() + + var chunks []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatal(err) + } + chunks = append(chunks, chunk) + } + + if len(chunks) != 2 { + t.Fatalf("expected 2 chunks, got %d: %v", len(chunks), chunks) + } + if chunks[0] != "echo: hello" { + t.Errorf("expected 'echo: hello', got %q", chunks[0]) + } + if chunks[1] != "echo: world" { + t.Errorf("expected 'echo: world', got %q", chunks[1]) + } + + output, err := conn.Output() + if err != nil { + t.Fatal(err) + } + if output != "processed 2 messages" { + t.Errorf("expected 'processed 2 messages', got %q", output) + } +} + +func TestBidiActionWithInit(t *testing.T) { + ctx := context.Background() + + type Config struct { + Prefix string + } + + action := NewBidiAction( + "prefixed", api.ActionTypeCustom, nil, + func(ctx context.Context, init Config, inCh <-chan string, outCh chan<- string) (string, error) { + for input := range inCh { + outCh <- fmt.Sprintf("%s: %s", init.Prefix, input) + } + return "done", nil + }, + ) + + conn, err := action.StreamBidi(ctx, Config{Prefix: "INFO"}) + if err != nil { + t.Fatal(err) + } + + go func() { + conn.Send("test message") + conn.Close() + }() + + var chunks []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatal(err) + } + chunks = append(chunks, chunk) + } + + if len(chunks) != 1 || chunks[0] != "INFO: test message" { + t.Errorf("unexpected chunks: %v", chunks) + } +} + +func TestBidiConnectionSendAfterClose(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "test", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return "", nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + conn.Close() + // Wait for completion so we know the state is settled. + <-conn.Done() + + if err := conn.Send("after close"); err == nil { + t.Error("expected error sending after close") + } +} + +func TestBidiConnectionContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + action := NewBidiAction( + "blocking", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + <-ctx.Done() + return "", ctx.Err() + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + cancel() + + _, err = conn.Output() + if err == nil { + t.Error("expected error after context cancellation") + } +} + +func TestBidiFlowRegistration(t *testing.T) { + r := registry.New() + + flow := DefineBidiFlow( + r, "echoFlow", + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for input := range inCh { + outCh <- input + } + return "done", nil + }, + ) + + if flow.Name() != "echoFlow" { + t.Errorf("expected name 'echoFlow', got %q", flow.Name()) + } + + desc := flow.Desc() + if desc.Type != api.ActionTypeFlow { + t.Errorf("expected type %q, got %q", api.ActionTypeFlow, desc.Type) + } + + // Verify bidi metadata is set. + if bidi, ok := desc.Metadata["bidi"].(bool); !ok || !bidi { + t.Error("expected metadata[\"bidi\"] = true") + } + + // Verify registered in registry. + action := r.LookupAction(desc.Key) + if action == nil { + t.Error("expected action to be registered") + } +} + +func TestBidiFlowEcho(t *testing.T) { + r := registry.New() + ctx := context.Background() + + flow := DefineBidiFlow( + r, "echoFlow", + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + var count int + for input := range inCh { + count++ + outCh <- fmt.Sprintf("echo: %s", input) + } + return fmt.Sprintf("processed %d", count), nil + }, + ) + + conn, err := flow.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + go func() { + conn.Send("a") + conn.Send("b") + conn.Close() + }() + + var chunks []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatal(err) + } + chunks = append(chunks, chunk) + } + + if len(chunks) != 2 { + t.Fatalf("expected 2 chunks, got %d", len(chunks)) + } + + output, err := conn.Output() + if err != nil { + t.Fatal(err) + } + if output != "processed 2" { + t.Errorf("expected 'processed 2', got %q", output) + } +} + +func TestBidiFlowCoreRunWorks(t *testing.T) { + r := registry.New() + ctx := context.Background() + + flow := DefineBidiFlow( + r, "withSteps", + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for input := range inCh { + // core.Run should work inside a BidiFlow. + result, err := Run(ctx, "uppercase", func() (string, error) { + return strings.ToUpper(input), nil + }) + if err != nil { + return "", err + } + outCh <- result + } + return "done", nil + }, + ) + + conn, err := flow.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + go func() { + conn.Send("hello") + conn.Close() + }() + + var chunks []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatal(err) + } + chunks = append(chunks, chunk) + } + + if len(chunks) != 1 || chunks[0] != "HELLO" { + t.Errorf("expected [HELLO], got %v", chunks) + } +} + +func TestBidiActionDone(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "quick", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return "finished", nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + conn.Close() + <-conn.Done() + + output, err := conn.Output() + if err != nil { + t.Fatal(err) + } + if output != "finished" { + t.Errorf("expected 'finished', got %q", output) + } +} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 377fb5e836..79983ce2e3 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -304,7 +304,7 @@ func RegisterAction(g *Genkit, action api.Registerable) { // // handle error // } // fmt.Println(result) // Output: Hello, World! -func DefineFlow[In, Out any](g *Genkit, name string, fn core.Func[In, Out]) *core.Flow[In, Out, struct{}] { +func DefineFlow[In, Out any](g *Genkit, name string, fn core.Func[In, Out]) *core.Flow[In, Out, struct{}, struct{}] { return core.DefineFlow(g.reg, name, fn) } @@ -355,10 +355,20 @@ func DefineFlow[In, Out any](g *Genkit, name string, fn core.Func[In, Out]) *cor // fmt.Println("Stream Chunk:", result.Stream) // Outputs: 1, 2, 3, 4, 5 // } // } -func DefineStreamingFlow[In, Out, Stream any](g *Genkit, name string, fn core.StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream] { +func DefineStreamingFlow[In, Out, Stream any](g *Genkit, name string, fn core.StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream, struct{}] { return core.DefineStreamingFlow(g.reg, name, fn) } +// DefineBidiFlow defines a bidirectional streaming flow, registers it, and +// returns a [core.Flow] that can be used to start bidi connections. +// +// The provided function receives initialization data, reads inputs from +// a channel, and writes streamed outputs to a channel. It returns a final +// output when complete. +func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.BidiFunc[In, Out, Stream, Init]) *core.Flow[In, Out, Stream, Init] { + return core.DefineBidiFlow[In, Out, Stream, Init](g.reg, name, fn) +} + // Run executes the given function `fn` within the context of the current flow run, // creating a distinct trace span for this step. It's used to add observability // to specific sub-operations within a flow defined by [DefineFlow] or [DefineStreamingFlow]. From f805988d1bdbdfb0fa47213779901f4178e2753e Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 30 Jan 2026 19:02:06 -0800 Subject: [PATCH 002/141] added `NewBidiFlow` --- go/core/api/action.go | 18 +++++++++--------- go/core/flow.go | 14 +++++++++++--- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/go/core/api/action.go b/go/core/api/action.go index 48e8045596..a38958af51 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -68,13 +68,13 @@ const ( // ActionDesc is a descriptor of an action. type ActionDesc struct { - Type ActionType `json:"type"` // Type of the action. - Key string `json:"key"` // Key of the action. - Name string `json:"name"` // Name of the action. - Description string `json:"description"` // Description of the action. - InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. - OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. - Metadata map[string]any `json:"metadata"` // Metadata for the action. - StreamSchema map[string]any `json:"streamSchema,omitempty"` // JSON schema for streamed chunks. - InitSchema map[string]any `json:"initSchema,omitempty"` // JSON schema for initialization data. + Type ActionType `json:"type"` // Type of the action. + Key string `json:"key"` // Key of the action. + Name string `json:"name"` // Name of the action. + Description string `json:"description"` // Description of the action. + InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. + OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. + StreamSchema map[string]any `json:"streamSchema,omitempty"` // JSON schema to validate against the action's streamed chunks. + InitSchema map[string]any `json:"initSchema,omitempty"` // JSON schema to validate against the action's initialization data. + Metadata map[string]any `json:"metadata"` // Metadata for the action. } diff --git a/go/core/flow.go b/go/core/flow.go index 5ec2503478..c173a0306c 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -80,14 +80,22 @@ func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn St })} } -// DefineBidiFlow creates a bidirectional streaming Flow that runs fn, and registers it as an action. +// NewBidiFlow creates a bidirectional streaming Flow without registering it. // Flow context is injected so that [Run] works inside the bidi function. -func DefineBidiFlow[In, Out, Stream, Init any](r api.Registry, name string, fn BidiFunc[In, Out, Stream, Init]) *Flow[In, Out, Stream, Init] { +func NewBidiFlow[In, Out, Stream, Init any](name string, fn BidiFunc[In, Out, Stream, Init]) *Flow[In, Out, Stream, Init] { wrapped := func(ctx context.Context, init Init, inCh <-chan In, outCh chan<- Stream) (Out, error) { ctx = flowContextKey.NewContext(ctx, &flowContext{flowName: name}) return fn(ctx, init, inCh, outCh) } - return &Flow[In, Out, Stream, Init]{DefineBidiAction(r, name, api.ActionTypeFlow, nil, wrapped)} + return &Flow[In, Out, Stream, Init]{NewBidiAction(name, api.ActionTypeFlow, nil, wrapped)} +} + +// DefineBidiFlow creates a bidirectional streaming Flow that runs fn, and registers it as an action. +// Flow context is injected so that [Run] works inside the bidi function. +func DefineBidiFlow[In, Out, Stream, Init any](r api.Registry, name string, fn BidiFunc[In, Out, Stream, Init]) *Flow[In, Out, Stream, Init] { + f := NewBidiFlow[In, Out, Stream, Init](name, fn) + f.Register(r) + return f } // Run runs the function f in the context of the current flow From f0f52952e7ee50ac026f28497eaf579d45b38f0c Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 30 Jan 2026 19:04:54 -0800 Subject: [PATCH 003/141] Update genkit.go --- go/genkit/genkit.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 79983ce2e3..85936cadce 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -366,7 +366,7 @@ func DefineStreamingFlow[In, Out, Stream any](g *Genkit, name string, fn core.St // a channel, and writes streamed outputs to a channel. It returns a final // output when complete. func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.BidiFunc[In, Out, Stream, Init]) *core.Flow[In, Out, Stream, Init] { - return core.DefineBidiFlow[In, Out, Stream, Init](g.reg, name, fn) + return core.DefineBidiFlow(g.reg, name, fn) } // Run executes the given function `fn` within the context of the current flow run, From 9b9bdea0a953112c8217c44c44600ee50b031437 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 30 Jan 2026 19:06:57 -0800 Subject: [PATCH 004/141] Update genkit.go --- go/genkit/genkit.go | 48 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 85936cadce..40e137a2b8 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -359,12 +359,50 @@ func DefineStreamingFlow[In, Out, Stream any](g *Genkit, name string, fn core.St return core.DefineStreamingFlow(g.reg, name, fn) } -// DefineBidiFlow defines a bidirectional streaming flow, registers it, and -// returns a [core.Flow] that can be used to start bidi connections. +// DefineBidiFlow defines a bidirectional streaming flow, registers it as a [core.Action] of type Flow, +// and returns a [core.Flow] capable of bidirectional streaming. // -// The provided function receives initialization data, reads inputs from -// a channel, and writes streamed outputs to a channel. It returns a final -// output when complete. +// The provided function `fn` receives initialization data of type `Init`, reads +// inputs of type `In` from an input channel, and writes streamed outputs of type +// `Stream` to an output channel. It returns a final output of type `Out` when complete. +// +// Example: +// +// chatFlow := genkit.DefineBidiFlow(g, "chat", +// func(ctx context.Context, init struct{}, inCh <-chan string, outCh chan<- string) (string, error) { +// var count int +// for input := range inCh { +// count++ +// outCh <- fmt.Sprintf("reply: %s", input) +// } +// return fmt.Sprintf("processed %d messages", count), nil +// }, +// ) +// +// // Start a bidi connection: +// conn, err := chatFlow.StreamBidi(ctx, struct{}{}) +// if err != nil { +// // handle error +// } +// +// // Send messages concurrently: +// go func() { +// conn.Send("hello") +// conn.Send("world") +// conn.Close() +// }() +// +// // Receive streamed responses: +// for chunk, err := range conn.Receive() { +// if err != nil { +// // handle error +// } +// fmt.Println(chunk) // Outputs: "reply: hello", "reply: world" +// } +// +// // Get the final output: +// output, err := conn.Output() +// fmt.Println(output) // Output: "processed 2 messages" func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.BidiFunc[In, Out, Stream, Init]) *core.Flow[In, Out, Stream, Init] { return core.DefineBidiFlow(g.reg, name, fn) } From 1382035ebf595777b7362d7aa61ad7586ee6a489 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 5 Feb 2026 21:46:49 -0800 Subject: [PATCH 005/141] added `SessionFlow` and related --- go/ai/x/option.go | 98 ++++ go/ai/x/session_flow.go | 660 +++++++++++++++++++++++ go/ai/x/session_flow_test.go | 746 ++++++++++++++++++++++++++ go/ai/x/snapshot.go | 161 ++++++ go/core/api/action.go | 2 + go/genkit/session_flow.go | 60 +++ go/samples/basic-session-flow/main.go | 122 +++++ 7 files changed, 1849 insertions(+) create mode 100644 go/ai/x/option.go create mode 100644 go/ai/x/session_flow.go create mode 100644 go/ai/x/session_flow_test.go create mode 100644 go/ai/x/snapshot.go create mode 100644 go/genkit/session_flow.go create mode 100644 go/samples/basic-session-flow/main.go diff --git a/go/ai/x/option.go b/go/ai/x/option.go new file mode 100644 index 0000000000..1ab09a9cbf --- /dev/null +++ b/go/ai/x/option.go @@ -0,0 +1,98 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package aix + +import "errors" + +// --- SessionFlowOption --- + +// SessionFlowOption configures a SessionFlow. +type SessionFlowOption[State any] interface { + applySessionFlow(*sessionFlowOptions[State]) error +} + +type sessionFlowOptions[State any] struct { + store SnapshotStore[State] + callback SnapshotCallback[State] +} + +func (o *sessionFlowOptions[State]) applySessionFlow(opts *sessionFlowOptions[State]) error { + if o.store != nil { + if opts.store != nil { + return errors.New("cannot set snapshot store more than once (WithSnapshotStore)") + } + opts.store = o.store + } + if o.callback != nil { + if opts.callback != nil { + return errors.New("cannot set snapshot callback more than once (WithSnapshotCallback)") + } + opts.callback = o.callback + } + return nil +} + +// WithSnapshotStore sets the store for persisting snapshots. +func WithSnapshotStore[State any](store SnapshotStore[State]) SessionFlowOption[State] { + return &sessionFlowOptions[State]{store: store} +} + +// WithSnapshotCallback configures when snapshots are created. +// If not provided and a store is configured, snapshots are always created. +func WithSnapshotCallback[State any](cb SnapshotCallback[State]) SessionFlowOption[State] { + return &sessionFlowOptions[State]{callback: cb} +} + +// --- StreamBidiOption --- + +// StreamBidiOption configures a StreamBidi call. +type StreamBidiOption[State any] interface { + applyStreamBidi(*streamBidiOptions[State]) error +} + +type streamBidiOptions[State any] struct { + state *SessionState[State] + snapshotID string +} + +func (o *streamBidiOptions[State]) applyStreamBidi(opts *streamBidiOptions[State]) error { + if o.state != nil { + if opts.state != nil { + return errors.New("cannot set state more than once (WithState)") + } + opts.state = o.state + } + if o.snapshotID != "" { + if opts.snapshotID != "" { + return errors.New("cannot set snapshot ID more than once (WithSnapshotID)") + } + opts.snapshotID = o.snapshotID + } + return nil +} + +// WithState sets the initial state for the invocation. +// Use this for client-managed state where the client sends state directly. +func WithState[State any](state *SessionState[State]) StreamBidiOption[State] { + return &streamBidiOptions[State]{state: state} +} + +// WithSnapshotID loads state from a persisted snapshot by ID. +// Use this for server-managed state where snapshots are stored. +func WithSnapshotID[State any](id string) StreamBidiOption[State] { + return &streamBidiOptions[State]{snapshotID: id} +} diff --git a/go/ai/x/session_flow.go b/go/ai/x/session_flow.go new file mode 100644 index 0000000000..29c0648549 --- /dev/null +++ b/go/ai/x/session_flow.go @@ -0,0 +1,660 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package aix provides experimental AI primitives for Genkit. +// +// APIs in this package are under active development and may change in any +// minor version release. +package aix + +import ( + "context" + "encoding/json" + "fmt" + "iter" + "log/slog" + "sync" + "time" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/core/tracing" + "github.com/google/uuid" + "go.opentelemetry.io/otel/attribute" + oteltrace "go.opentelemetry.io/otel/trace" +) + +// SessionFlowArtifact represents a named collection of parts produced during a session. +// Examples: generated files, images, code snippets, diagrams, etc. +type SessionFlowArtifact struct { + // Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). + Name string `json:"name,omitempty"` + // Parts contains the artifact content (text, media, etc.). + Parts []*ai.Part `json:"parts"` + // Metadata contains additional artifact-specific data. + Metadata map[string]any `json:"metadata,omitempty"` +} + +// SessionFlowInput is the input sent to a session flow during a conversation turn. +type SessionFlowInput struct { + // Messages contains the user's input for this turn. + Messages []*ai.Message `json:"messages,omitempty"` +} + +// SessionFlowInit is the input for starting a session flow invocation. +// Provide either SnapshotID (to load from store) or State (direct state). +type SessionFlowInit[State any] struct { + // SnapshotID loads state from a persisted snapshot. + // Mutually exclusive with State. + SnapshotID string `json:"snapshotId,omitempty"` + // State provides direct state for the invocation. + // Mutually exclusive with SnapshotID. + State *SessionState[State] `json:"state,omitempty"` +} + +// SessionFlowOutput is the output when a session flow invocation completes. +type SessionFlowOutput[State any] struct { + // SnapshotID is the ID of the snapshot created at the end of this invocation. + // Empty if no snapshot was created (callback returned false or no store configured). + SnapshotID string `json:"snapshotId,omitempty"` + // State contains the final conversation state. + State *SessionState[State] `json:"state"` +} + +// SessionFlowStreamChunk represents a single item in the session flow's output stream. +// Multiple fields can be populated in a single chunk. +type SessionFlowStreamChunk[Stream any] struct { + // Chunk contains token-level generation data. + Chunk *ai.ModelResponseChunk `json:"chunk,omitempty"` + // Status contains user-defined structured status information. + // The Stream type parameter defines the shape of this data. + Status Stream `json:"status,omitempty"` + // Artifact contains a newly produced artifact. + Artifact *SessionFlowArtifact `json:"artifact,omitempty"` + // SnapshotCreated contains the ID of a snapshot that was just persisted. + SnapshotCreated string `json:"snapshotCreated,omitempty"` + // EndTurn signals that the session flow has finished processing the current input. + // When true, the client should stop iterating and may send the next input. + EndTurn bool `json:"endTurn,omitempty"` +} + +// --- Session --- + +// Session holds the working state during a session flow invocation. +// It is propagated through context and provides read/write access to state. +type Session[State any] struct { + mu sync.RWMutex + state SessionState[State] + store SnapshotStore[State] + + snapshotCallback SnapshotCallback[State] + + // onEndTurn is set by the framework; triggers snapshot + EndTurn chunk. + onEndTurn func(ctx context.Context) + inCh <-chan *SessionFlowInput + + // Snapshot tracking + lastSnapshot *SessionSnapshot[State] + turnIndex int +} + +// Run loops over the input channel, calling fn for each turn. Each turn is +// wrapped in a trace span for observability. Input messages are automatically +// added to the session before fn is called. After fn returns successfully, an +// EndTurn chunk is sent and a snapshot check is triggered. +func (s *Session[State]) Run( + ctx context.Context, + fn func(ctx context.Context, input *SessionFlowInput) error, +) error { + for input := range s.inCh { + spanMeta := &tracing.SpanMetadata{ + Name: fmt.Sprintf("sessionFlow/turn/%d", s.turnIndex), + Type: "sessionFlowTurn", + Subtype: "sessionFlowTurn", + } + + _, err := tracing.RunInNewSpan(ctx, spanMeta, input, + func(ctx context.Context, input *SessionFlowInput) (struct{}, error) { + s.AddMessages(input.Messages...) + + if err := fn(ctx, input); err != nil { + return struct{}{}, err + } + + s.onEndTurn(ctx) + s.turnIndex++ + return struct{}{}, nil + }, + ) + if err != nil { + return err + } + } + return nil +} + +// State returns a copy of the current session flow state. +func (s *Session[State]) State() *SessionState[State] { + s.mu.RLock() + defer s.mu.RUnlock() + copied := s.copyStateLocked() + return &copied +} + +// Messages returns the current conversation history. +func (s *Session[State]) Messages() []*ai.Message { + s.mu.RLock() + defer s.mu.RUnlock() + msgs := make([]*ai.Message, len(s.state.Messages)) + copy(msgs, s.state.Messages) + return msgs +} + +// AddMessages appends messages to the conversation history. +func (s *Session[State]) AddMessages(messages ...*ai.Message) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Messages = append(s.state.Messages, messages...) +} + +// SetMessages replaces the entire conversation history. +func (s *Session[State]) SetMessages(messages []*ai.Message) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Messages = messages +} + +// Custom returns the current user-defined custom state. +func (s *Session[State]) Custom() State { + s.mu.RLock() + defer s.mu.RUnlock() + return s.state.Custom +} + +// SetCustom updates the user-defined custom state. +func (s *Session[State]) SetCustom(custom State) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Custom = custom +} + +// UpdateCustom atomically reads the current custom state, applies the given +// function, and writes the result back. +func (s *Session[State]) UpdateCustom(fn func(State) State) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Custom = fn(s.state.Custom) +} + +// Artifacts returns the current artifacts. +func (s *Session[State]) Artifacts() []*SessionFlowArtifact { + s.mu.RLock() + defer s.mu.RUnlock() + arts := make([]*SessionFlowArtifact, len(s.state.Artifacts)) + copy(arts, s.state.Artifacts) + return arts +} + +// AddArtifacts adds artifacts to the session. If an artifact with the same +// name already exists, it is replaced. +func (s *Session[State]) AddArtifacts(artifacts ...*SessionFlowArtifact) { + s.mu.Lock() + defer s.mu.Unlock() + for _, a := range artifacts { + replaced := false + if a.Name != "" { + for i, existing := range s.state.Artifacts { + if existing.Name == a.Name { + s.state.Artifacts[i] = a + replaced = true + break + } + } + } + if !replaced { + s.state.Artifacts = append(s.state.Artifacts, a) + } + } +} + +// SetArtifacts replaces the entire artifact list. +func (s *Session[State]) SetArtifacts(artifacts []*SessionFlowArtifact) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Artifacts = artifacts +} + +// maybeSnapshot creates a snapshot if conditions are met (store configured, +// callback approves). Returns the snapshot ID or empty string. +func (s *Session[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { + if s.store == nil { + return "" + } + + s.mu.RLock() + currentState := s.copyStateLocked() + turnIndex := s.turnIndex + s.mu.RUnlock() + + shouldSnapshot := true + if s.snapshotCallback != nil { + var prevState *SessionState[State] + if s.lastSnapshot != nil { + prevState = &s.lastSnapshot.State + } + shouldSnapshot = s.snapshotCallback(ctx, &SnapshotContext[State]{ + State: ¤tState, + PrevState: prevState, + TurnIndex: turnIndex, + Event: event, + }) + } + + if !shouldSnapshot { + return "" + } + + snapshot := &SessionSnapshot[State]{ + SnapshotID: uuid.New().String(), + CreatedAt: time.Now(), + TurnIndex: turnIndex, + Event: event, + State: currentState, + } + if s.lastSnapshot != nil { + snapshot.ParentID = s.lastSnapshot.SnapshotID + } + + if err := s.store.SaveSnapshot(ctx, snapshot); err != nil { + slog.Error("session flow: failed to save snapshot", "err", err) + return "" + } + + // Set snapshotId in last message metadata. + s.mu.Lock() + if msgs := s.state.Messages; len(msgs) > 0 { + lastMsg := msgs[len(msgs)-1] + if lastMsg.Metadata == nil { + lastMsg.Metadata = make(map[string]any) + } + lastMsg.Metadata["snapshotId"] = snapshot.SnapshotID + } + s.mu.Unlock() + + s.lastSnapshot = snapshot + + // Record on OTel span. + span := oteltrace.SpanFromContext(ctx) + span.SetAttributes( + attribute.String("genkit:metadata:snapshotId", snapshot.SnapshotID), + ) + + return snapshot.SnapshotID +} + +// copyStateLocked returns a deep copy of the state. Caller must hold mu (read or write). +func (s *Session[State]) copyStateLocked() SessionState[State] { + bytes, err := json.Marshal(s.state) + if err != nil { + panic(fmt.Sprintf("session flow: failed to marshal state: %v", err)) + } + var copied SessionState[State] + if err := json.Unmarshal(bytes, &copied); err != nil { + panic(fmt.Sprintf("session flow: failed to unmarshal state: %v", err)) + } + return copied +} + +// --- Session context --- + +type sessionContextKey struct{} + +type sessionHolder struct { + session any +} + +// NewSessionContext returns a new context with the session attached. +func NewSessionContext[State any](ctx context.Context, s *Session[State]) context.Context { + return context.WithValue(ctx, sessionContextKey{}, &sessionHolder{session: s}) +} + +// SessionFromContext retrieves the current session from context. +// Returns nil if no session is in context or if the type doesn't match. +func SessionFromContext[State any](ctx context.Context) *Session[State] { + holder, ok := ctx.Value(sessionContextKey{}).(*sessionHolder) + if !ok || holder == nil { + return nil + } + session, ok := holder.session.(*Session[State]) + if !ok { + return nil + } + return session +} + +// --- Responder --- + +// Responder is the output channel for a session flow. Chunks sent through it +// are automatically inspected: if a chunk contains an artifact, it is added to +// the session before being forwarded to the client. +// +// Convenience methods are provided for common chunk types. +type Responder[Stream any] chan<- *SessionFlowStreamChunk[Stream] + +// SendChunk sends a generation chunk (token-level streaming). +func (r Responder[Stream]) SendChunk(chunk *ai.ModelResponseChunk) { + r <- &SessionFlowStreamChunk[Stream]{Chunk: chunk} +} + +// SendStatus sends a user-defined status update. +func (r Responder[Stream]) SendStatus(status Stream) { + r <- &SessionFlowStreamChunk[Stream]{Status: status} +} + +// SendArtifact sends an artifact to the stream and adds it to the session. +// If an artifact with the same name already exists in the session, it is replaced. +func (r Responder[Stream]) SendArtifact(artifact *SessionFlowArtifact) { + r <- &SessionFlowStreamChunk[Stream]{Artifact: artifact} +} + +// --- SessionFlowParams --- + +// SessionFlowParams contains the parameters passed to a session flow function. +type SessionFlowParams[State any] struct { + // Session provides access to the working state. + Session *Session[State] +} + +// --- SessionFlowFunc --- + +// SessionFlowFunc is the function signature for session flows. +// Type parameters: +// - Stream: Type for status updates sent via the responder +// - State: Type for user-defined state in snapshots +type SessionFlowFunc[Stream, State any] func( + ctx context.Context, + resp Responder[Stream], + params *SessionFlowParams[State], +) error + +// --- SessionFlow --- + +// SessionFlow is a bidirectional streaming flow with automatic snapshot management. +type SessionFlow[Stream, State any] struct { + flow *core.Flow[*SessionFlowInput, *SessionFlowOutput[State], *SessionFlowStreamChunk[Stream], *SessionFlowInit[State]] + store SnapshotStore[State] + snapshotCallback SnapshotCallback[State] +} + +// DefineSessionFlow creates a SessionFlow with automatic snapshot management and registers it. +func DefineSessionFlow[Stream, State any]( + r api.Registry, + name string, + fn SessionFlowFunc[Stream, State], + opts ...SessionFlowOption[State], +) *SessionFlow[Stream, State] { + sfOpts := &sessionFlowOptions[State]{} + for _, opt := range opts { + if err := opt.applySessionFlow(sfOpts); err != nil { + panic(fmt.Errorf("DefineSessionFlow %q: %w", name, err)) + } + } + + sf := &SessionFlow[Stream, State]{ + store: sfOpts.store, + snapshotCallback: sfOpts.callback, + } + + bidiFn := func( + ctx context.Context, + init *SessionFlowInit[State], + inCh <-chan *SessionFlowInput, + outCh chan<- *SessionFlowStreamChunk[Stream], + ) (*SessionFlowOutput[State], error) { + return sf.runWrapped(ctx, init, inCh, outCh, fn) + } + + sf.flow = core.DefineBidiFlow(r, name, bidiFn) + + // Register snapshot store action for reflection API. + if sfOpts.store != nil { + registerSnapshotStoreAction(r, name, sfOpts.store) + } + + return sf +} + +// StreamBidi starts a new session flow invocation. +func (sf *SessionFlow[Stream, State]) StreamBidi( + ctx context.Context, + opts ...StreamBidiOption[State], +) (*SessionFlowConnection[Stream, State], error) { + sbOpts := &streamBidiOptions[State]{} + for _, opt := range opts { + if err := opt.applyStreamBidi(sbOpts); err != nil { + return nil, fmt.Errorf("SessionFlow.StreamBidi %q: %w", sf.flow.Name(), err) + } + } + + init := &SessionFlowInit[State]{ + SnapshotID: sbOpts.snapshotID, + State: sbOpts.state, + } + + conn, err := sf.flow.StreamBidi(ctx, init) + if err != nil { + return nil, err + } + + return &SessionFlowConnection[Stream, State]{conn: conn}, nil +} + +// runWrapped is the BidiFunc implementation. It sets up the session, +// responder, and wiring, then delegates to the user's function. +func (sf *SessionFlow[Stream, State]) runWrapped( + ctx context.Context, + init *SessionFlowInit[State], + inCh <-chan *SessionFlowInput, + outCh chan<- *SessionFlowStreamChunk[Stream], + fn SessionFlowFunc[Stream, State], +) (*SessionFlowOutput[State], error) { + session, err := newSessionFromInit(ctx, init, sf.store, sf.snapshotCallback) + if err != nil { + return nil, err + } + session.inCh = inCh + ctx = NewSessionContext(ctx, session) + + // Intermediary channel: intercepts artifacts before forwarding to outCh. + respCh := make(chan *SessionFlowStreamChunk[Stream]) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for chunk := range respCh { + if chunk.Artifact != nil { + session.AddArtifacts(chunk.Artifact) + } + outCh <- chunk + } + }() + + // Wire up onEndTurn: triggers snapshot + sends EndTurn chunk. + // Writes through respCh to preserve ordering with user chunks. + session.onEndTurn = func(turnCtx context.Context) { + snapshotID := session.maybeSnapshot(turnCtx, TurnEnd) + if snapshotID != "" { + respCh <- &SessionFlowStreamChunk[Stream]{SnapshotCreated: snapshotID} + } + respCh <- &SessionFlowStreamChunk[Stream]{EndTurn: true} + } + + params := &SessionFlowParams[State]{ + Session: session, + } + + fnErr := fn(ctx, Responder[Stream](respCh), params) + close(respCh) + wg.Wait() + + if fnErr != nil { + return nil, fnErr + } + + // Final snapshot at invocation end. + snapshotID := session.maybeSnapshot(ctx, InvocationEnd) + + return &SessionFlowOutput[State]{ + State: session.State(), + SnapshotID: snapshotID, + }, nil +} + +// newSessionFromInit creates a session from initialization data. +func newSessionFromInit[State any]( + ctx context.Context, + init *SessionFlowInit[State], + store SnapshotStore[State], + cb SnapshotCallback[State], +) (*Session[State], error) { + s := &Session[State]{ + store: store, + snapshotCallback: cb, + } + + if init != nil { + if init.SnapshotID != "" && store != nil { + snapshot, err := store.GetSnapshot(ctx, init.SnapshotID) + if err != nil { + return nil, core.NewError(core.INTERNAL, "failed to load snapshot %q: %v", init.SnapshotID, err) + } + if snapshot == nil { + return nil, core.NewError(core.NOT_FOUND, "snapshot %q not found", init.SnapshotID) + } + s.state = snapshot.State + s.lastSnapshot = snapshot + s.turnIndex = snapshot.TurnIndex + } else if init.State != nil { + s.state = *init.State + } + } + + return s, nil +} + +// --- Snapshot store reflection action --- + +type getSnapshotInput struct { + SnapshotID string `json:"snapshotId"` +} + +func registerSnapshotStoreAction[State any](r api.Registry, flowName string, store SnapshotStore[State]) { + core.DefineAction(r, flowName+"/getSnapshot", api.ActionTypeSnapshotStore, nil, nil, + func(ctx context.Context, input getSnapshotInput) (*SessionSnapshot[State], error) { + return store.GetSnapshot(ctx, input.SnapshotID) + }, + ) +} + +// --- SessionFlowConnection --- + +// SessionFlowConnection wraps BidiConnection with session flow-specific functionality. +// It provides a Receive() iterator that supports multi-turn patterns: breaking out +// of the iterator between turns does not cancel the underlying connection. +type SessionFlowConnection[Stream, State any] struct { + conn *core.BidiConnection[*SessionFlowInput, *SessionFlowOutput[State], *SessionFlowStreamChunk[Stream]] + + // chunks buffers stream chunks from the underlying connection so that + // breaking from Receive() between turns doesn't cancel the context. + chunks chan *SessionFlowStreamChunk[Stream] + chunkErr error + initOnce sync.Once +} + +// initReceiver starts a goroutine that drains the underlying BidiConnection's +// Receive into a channel. This goroutine never breaks from the underlying +// iterator, preventing context cancellation. +func (c *SessionFlowConnection[Stream, State]) initReceiver() { + c.initOnce.Do(func() { + c.chunks = make(chan *SessionFlowStreamChunk[Stream], 1) + go func() { + defer close(c.chunks) + for chunk, err := range c.conn.Receive() { + if err != nil { + c.chunkErr = err + return + } + c.chunks <- chunk + } + }() + }) +} + +// Send sends a SessionFlowInput to the session flow. +func (c *SessionFlowConnection[Stream, State]) Send(input *SessionFlowInput) error { + return c.conn.Send(input) +} + +// SendMessages sends messages to the session flow. +func (c *SessionFlowConnection[Stream, State]) SendMessages(messages ...*ai.Message) error { + return c.conn.Send(&SessionFlowInput{Messages: messages}) +} + +// SendText sends a single user text message to the session flow. +func (c *SessionFlowConnection[Stream, State]) SendText(text string) error { + return c.conn.Send(&SessionFlowInput{ + Messages: []*ai.Message{ai.NewUserTextMessage(text)}, + }) +} + +// Close signals that no more inputs will be sent. +func (c *SessionFlowConnection[Stream, State]) Close() error { + return c.conn.Close() +} + +// Receive returns an iterator for receiving stream chunks. +// Unlike the underlying BidiConnection.Receive, breaking out of this iterator +// does not cancel the connection. This enables multi-turn patterns where the +// caller breaks on EndTurn, sends the next input, then calls Receive again. +func (c *SessionFlowConnection[Stream, State]) Receive() iter.Seq2[*SessionFlowStreamChunk[Stream], error] { + c.initReceiver() + return func(yield func(*SessionFlowStreamChunk[Stream], error) bool) { + for { + chunk, ok := <-c.chunks + if !ok { + if err := c.chunkErr; err != nil { + var zero *SessionFlowStreamChunk[Stream] + yield(zero, err) + } + return + } + if !yield(chunk, nil) { + return + } + } + } +} + +// Output returns the final response after the session flow completes. +func (c *SessionFlowConnection[Stream, State]) Output() (*SessionFlowOutput[State], error) { + return c.conn.Output() +} + +// Done returns a channel closed when the connection completes. +func (c *SessionFlowConnection[Stream, State]) Done() <-chan struct{} { + return c.conn.Done() +} diff --git a/go/ai/x/session_flow_test.go b/go/ai/x/session_flow_test.go new file mode 100644 index 0000000000..d5722863fd --- /dev/null +++ b/go/ai/x/session_flow_test.go @@ -0,0 +1,746 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package aix + +import ( + "context" + "fmt" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/internal/registry" +) + +type testState struct { + Counter int `json:"counter"` + Topics []string `json:"topics,omitempty"` +} + +type testStatus struct { + Phase string `json:"phase"` +} + +func newTestRegistry(t *testing.T) *registry.Registry { + t.Helper() + return registry.New() +} + +func TestSessionFlow_BasicMultiTurn(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + sf := DefineSessionFlow(reg, "basicFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + resp.SendStatus(testStatus{Phase: "generating"}) + // Echo back the user's message. + if len(input.Messages) > 0 { + reply := ai.NewModelTextMessage("echo: " + input.Messages[0].Content[0].Text) + sess.AddMessages(reply) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + resp.SendStatus(testStatus{Phase: "complete"}) + return nil + }) + }, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Turn 1. + if err := conn.SendText("hello"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + var turn1Chunks int + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + turn1Chunks++ + if chunk.EndTurn { + break + } + } + if turn1Chunks < 2 { // at least status + endTurn + t.Errorf("expected at least 2 chunks in turn 1, got %d", turn1Chunks) + } + + // Turn 2. + if err := conn.SendText("world"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // 2 user messages + 2 echo replies = 4. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + } + if got := response.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2, got %d", got) + } +} + +func TestSessionFlow_WithSnapshotStore(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySnapshotStore[testState]() + + sf := DefineSessionFlow(reg, "snapshotFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + if len(input.Messages) > 0 { + sess.AddMessages(ai.NewModelTextMessage("reply")) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSnapshotStore[testState](store), + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("turn1") + + var snapshotIDs []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.SnapshotCreated != "" { + snapshotIDs = append(snapshotIDs, chunk.SnapshotCreated) + } + if chunk.EndTurn { + break + } + } + + if len(snapshotIDs) != 1 { + t.Fatalf("expected 1 snapshot from turn, got %d", len(snapshotIDs)) + } + + // Verify the snapshot was persisted. + snap, err := store.GetSnapshot(ctx, snapshotIDs[0]) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap == nil { + t.Fatal("expected snapshot, got nil") + } + if snap.State.Custom.Counter != 1 { + t.Errorf("expected counter=1 in snapshot, got %d", snap.State.Custom.Counter) + } + if snap.TurnIndex != 0 { + t.Errorf("expected turnIndex=0, got %d", snap.TurnIndex) + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Final snapshot at invocation end. + if response.SnapshotID == "" { + t.Error("expected final snapshot ID") + } +} + +func TestSessionFlow_ResumeFromSnapshot(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySnapshotStore[testState]() + + sf := DefineSessionFlow(reg, "resumeFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + if len(input.Messages) > 0 { + sess.AddMessages(ai.NewModelTextMessage("reply")) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSnapshotStore[testState](store), + ) + + // First invocation: create a snapshot. + conn1, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + conn1.SendText("first message") + for chunk, err := range conn1.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn1.Close() + resp1, err := conn1.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + if resp1.SnapshotID == "" { + t.Fatal("expected snapshot ID from first invocation") + } + + // Second invocation: resume from snapshot. + conn2, err := sf.StreamBidi(ctx, WithSnapshotID[testState](resp1.SnapshotID)) + if err != nil { + t.Fatalf("StreamBidi with snapshot failed: %v", err) + } + conn2.SendText("continued message") + for chunk, err := range conn2.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn2.Close() + resp2, err := conn2.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Should have messages from both invocations: + // first: user + reply (2) + second: user + reply (2) = 4. + if got := len(resp2.State.Messages); got != 4 { + t.Errorf("expected 4 messages after resume, got %d", got) + } + // Counter should be 2 (1 from first + 1 from second). + if got := resp2.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2, got %d", got) + } + + // The new snapshot should reference the previous as parent. + if resp2.SnapshotID == "" { + t.Fatal("expected snapshot ID from second invocation") + } + snap2, err := store.GetSnapshot(ctx, resp2.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + // The parent chain: snap2's parent is a turn-end snapshot from the second invocation, + // which itself has a parent from the first invocation's final snapshot. + // We just verify that the parent chain exists (not empty). + if snap2.ParentID == "" { + t.Error("expected parent ID on resumed snapshot") + } +} + +func TestSessionFlow_ClientManagedState(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + sf := DefineSessionFlow(reg, "clientStateFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + if len(input.Messages) > 0 { + sess.AddMessages(ai.NewModelTextMessage("reply")) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + ) + + // Start with client-provided state. + clientState := &SessionState[testState]{ + Messages: []*ai.Message{ + ai.NewUserTextMessage("previous message"), + ai.NewModelTextMessage("previous reply"), + }, + Custom: testState{Counter: 5}, + } + + conn, err := sf.StreamBidi(ctx, WithState(clientState)) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("new message") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // 2 previous + 1 new user + 1 reply = 4. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + } + // Counter should be 6 (started at 5, incremented once). + if got := response.State.Custom.Counter; got != 6 { + t.Errorf("expected counter=6, got %d", got) + } + // No snapshot since no store was configured. + if response.SnapshotID != "" { + t.Errorf("expected no snapshot ID without store, got %q", response.SnapshotID) + } +} + +func TestSessionFlow_Artifacts(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + sf := DefineSessionFlow(reg, "artifactFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + + resp.SendArtifact(&SessionFlowArtifact{ + Name: "code.go", + Parts: []*ai.Part{ai.NewTextPart("package main")}, + }) + + // Replace artifact with same name. + resp.SendArtifact(&SessionFlowArtifact{ + Name: "code.go", + Parts: []*ai.Part{ai.NewTextPart("package main\nfunc main() {}")}, + }) + + // Add another artifact. + resp.SendArtifact(&SessionFlowArtifact{ + Name: "readme.md", + Parts: []*ai.Part{ai.NewTextPart("# README")}, + }) + + sess.AddMessages(ai.NewModelTextMessage("done")) + return nil + }) + }, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("generate code") + var receivedArtifacts []*SessionFlowArtifact + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.Artifact != nil { + receivedArtifacts = append(receivedArtifacts, chunk.Artifact) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + if len(receivedArtifacts) != 3 { // all 3 sends are streamed + t.Errorf("expected 3 streamed artifacts, got %d", len(receivedArtifacts)) + } + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Session should have 2 unique artifacts (code.go was replaced). + if got := len(response.State.Artifacts); got != 2 { + t.Errorf("expected 2 artifacts in state, got %d", got) + } +} + +func TestSessionFlow_SnapshotCallback(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySnapshotStore[testState]() + + // Only snapshot on even turns. + callbackCalls := 0 + sf := DefineSessionFlow(reg, "callbackFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSnapshotStore[testState](store), + WithSnapshotCallback(func(ctx context.Context, sc *SnapshotContext[testState]) bool { + callbackCalls++ + return sc.TurnIndex%2 == 0 // only snapshot on even turns + }), + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + var snapshotIDs []string + for i := 0; i < 3; i++ { + conn.SendText(fmt.Sprintf("turn %d", i)) + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error on turn %d: %v", i, err) + } + if chunk.SnapshotCreated != "" { + snapshotIDs = append(snapshotIDs, chunk.SnapshotCreated) + } + if chunk.EndTurn { + break + } + } + } + conn.Close() + conn.Output() // drain + + // Turn 0 (even) → snapshot, Turn 1 (odd) → no, Turn 2 (even) → snapshot. + // That's 2 turn snapshots from the callback. + if got := len(snapshotIDs); got != 2 { + t.Errorf("expected 2 turn snapshots, got %d", got) + } + // Callback should have been called 3 times (once per turn). + if callbackCalls < 3 { + t.Errorf("expected at least 3 callback calls, got %d", callbackCalls) + } +} + +func TestSessionFlow_SendMessages(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + sf := DefineSessionFlow(reg, "sendMsgsFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + return nil + }) + }, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Send multiple messages at once. + err = conn.SendMessages( + ai.NewUserTextMessage("msg1"), + ai.NewUserTextMessage("msg2"), + ) + if err != nil { + t.Fatalf("SendMessages failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Both messages should have been added. + if got := len(response.State.Messages); got != 2 { + t.Errorf("expected 2 messages, got %d", got) + } +} + +func TestSessionFlow_SessionContext(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + var retrievedCounter int + sf := DefineSessionFlow(reg, "contextFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + // Session should be retrievable from context. + sess := SessionFromContext[testState](ctx) + if sess == nil { + t.Error("expected session from context") + return nil + } + sess.UpdateCustom(func(s testState) testState { + s.Counter = 42 + return s + }) + retrievedCounter = sess.Custom().Counter + return nil + }) + }, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("test") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + conn.Output() + + if retrievedCounter != 42 { + t.Errorf("expected counter=42 from context, got %d", retrievedCounter) + } +} + +func TestSessionFlow_ErrorInTurn(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + sf := DefineSessionFlow(reg, "errorFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + return fmt.Errorf("turn failed") + }) + }, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("trigger error") + conn.Close() + + _, err = conn.Output() + if err == nil { + t.Fatal("expected error from failed turn") + } +} + +func TestSessionFlow_SetMessages(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + sf := DefineSessionFlow(reg, "setMsgsFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + // Replace all messages with just one. + sess.SetMessages([]*ai.Message{ai.NewModelTextMessage("replaced")}) + return nil + }) + }, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("original") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // SetMessages replaced everything with 1 message. + if got := len(response.State.Messages); got != 1 { + t.Errorf("expected 1 message after SetMessages, got %d", got) + } +} + +func TestSessionFlow_SnapshotIDInMessageMetadata(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySnapshotStore[testState]() + + sf := DefineSessionFlow(reg, "metadataFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + sess.AddMessages(ai.NewModelTextMessage("reply")) + return nil + }) + }, + WithSnapshotStore[testState](store), + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("hello") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // The last message should have snapshotId in its metadata. + msgs := response.State.Messages + if len(msgs) == 0 { + t.Fatal("expected messages in response") + } + lastMsg := msgs[len(msgs)-1] + if lastMsg.Metadata == nil { + t.Fatal("expected metadata on last message") + } + if _, ok := lastMsg.Metadata["snapshotId"]; !ok { + t.Error("expected snapshotId in last message metadata") + } +} + +func TestInMemorySnapshotStore(t *testing.T) { + ctx := context.Background() + store := NewInMemorySnapshotStore[testState]() + + // Get non-existent. + snap, err := store.GetSnapshot(ctx, "nonexistent") + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap != nil { + t.Errorf("expected nil, got %v", snap) + } + + // Save and retrieve. + snapshot := &SessionSnapshot[testState]{ + SnapshotID: "snap-1", + TurnIndex: 0, + State: SessionState[testState]{ + Custom: testState{Counter: 1}, + }, + } + if err := store.SaveSnapshot(ctx, snapshot); err != nil { + t.Fatalf("SaveSnapshot failed: %v", err) + } + + retrieved, err := store.GetSnapshot(ctx, "snap-1") + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if retrieved == nil { + t.Fatal("expected snapshot") + } + if retrieved.State.Custom.Counter != 1 { + t.Errorf("expected counter=1, got %d", retrieved.State.Custom.Counter) + } + + // Verify isolation. + snapshot.State.Custom.Counter = 999 + retrieved2, _ := store.GetSnapshot(ctx, "snap-1") + if retrieved2.State.Custom.Counter != 1 { + t.Errorf("expected counter=1 (isolation), got %d", retrieved2.State.Custom.Counter) + } +} + +func TestSessionFlow_SnapshotStoreReflectionAction(t *testing.T) { + _ = context.Background() + reg := newTestRegistry(t) + store := NewInMemorySnapshotStore[testState]() + + DefineSessionFlow(reg, "reflectFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return nil + }, + WithSnapshotStore[testState](store), + ) + + // The getSnapshot action should be registered. + action := reg.LookupAction("/snapshot-store/reflectFlow/getSnapshot") + if action == nil { + t.Fatal("expected getSnapshot action to be registered") + } +} diff --git a/go/ai/x/snapshot.go b/go/ai/x/snapshot.go new file mode 100644 index 0000000000..c81bfcff81 --- /dev/null +++ b/go/ai/x/snapshot.go @@ -0,0 +1,161 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package aix + +import ( + "context" + "encoding/json" + "sync" + "time" + + "github.com/firebase/genkit/go/ai" +) + +// SessionState is the portable conversation state that flows between client +// and server. It contains only the data needed for conversation continuity. +type SessionState[State any] struct { + // Messages is the conversation history (user/model exchanges). + // Does NOT include prompt-rendered messages — those are rendered fresh each turn. + Messages []*ai.Message `json:"messages,omitempty"` + // Custom is the user-defined state associated with this conversation. + Custom State `json:"custom,omitempty"` + // Artifacts are named collections of parts produced during the conversation. + Artifacts []*SessionFlowArtifact `json:"artifacts,omitempty"` +} + +// SnapshotEvent identifies what triggered a snapshot. +type SnapshotEvent string + +const ( + // TurnEnd indicates the snapshot was triggered at the end of a turn. + TurnEnd SnapshotEvent = "turnEnd" + // InvocationEnd indicates the snapshot was triggered at the end of the invocation. + InvocationEnd SnapshotEvent = "invocationEnd" +) + +// SessionSnapshot is a persisted point-in-time capture of session state. +type SessionSnapshot[State any] struct { + // SnapshotID is the unique identifier for this snapshot (UUID). + SnapshotID string `json:"snapshotId"` + // ParentID is the ID of the previous snapshot in this timeline. + ParentID string `json:"parentId,omitempty"` + // CreatedAt is when the snapshot was created. + CreatedAt time.Time `json:"createdAt"` + // TurnIndex is the turn number when this snapshot was created (0-indexed). + TurnIndex int `json:"turnIndex"` + // Event is what triggered this snapshot. + Event SnapshotEvent `json:"event"` + // State is the actual conversation state. + State SessionState[State] `json:"state"` +} + +// SnapshotContext provides context for snapshot decision callbacks. +type SnapshotContext[State any] struct { + // State is the current state that will be snapshotted if the callback returns true. + State *SessionState[State] + // PrevState is the state at the last snapshot, or nil if none exists. + PrevState *SessionState[State] + // TurnIndex is the current turn number. + TurnIndex int + // Event is what triggered this snapshot check. + Event SnapshotEvent +} + +// SnapshotCallback decides whether to create a snapshot. +// If not provided and a store is configured, snapshots are always created. +type SnapshotCallback[State any] = func(ctx context.Context, sc *SnapshotContext[State]) bool + +// SnapshotStore persists and retrieves snapshots. +type SnapshotStore[State any] interface { + // GetSnapshot retrieves a snapshot by ID. Returns nil if not found. + GetSnapshot(ctx context.Context, snapshotID string) (*SessionSnapshot[State], error) + // SaveSnapshot persists a snapshot. + SaveSnapshot(ctx context.Context, snapshot *SessionSnapshot[State]) error +} + +// InMemorySnapshotStore provides a thread-safe in-memory snapshot store. +type InMemorySnapshotStore[State any] struct { + snapshots map[string]*SessionSnapshot[State] + mu sync.RWMutex +} + +// NewInMemorySnapshotStore creates a new in-memory snapshot store. +func NewInMemorySnapshotStore[State any]() *InMemorySnapshotStore[State] { + return &InMemorySnapshotStore[State]{ + snapshots: make(map[string]*SessionSnapshot[State]), + } +} + +// GetSnapshot retrieves a snapshot by ID. Returns nil if not found. +func (s *InMemorySnapshotStore[State]) GetSnapshot(_ context.Context, snapshotID string) (*SessionSnapshot[State], error) { + s.mu.RLock() + defer s.mu.RUnlock() + + snap, exists := s.snapshots[snapshotID] + if !exists { + return nil, nil + } + + copied, err := copySnapshot(snap) + if err != nil { + return nil, err + } + return copied, nil +} + +// SaveSnapshot persists a snapshot. +func (s *InMemorySnapshotStore[State]) SaveSnapshot(_ context.Context, snapshot *SessionSnapshot[State]) error { + s.mu.Lock() + defer s.mu.Unlock() + + copied, err := copySnapshot(snapshot) + if err != nil { + return err + } + s.snapshots[copied.SnapshotID] = copied + return nil +} + +// copySnapshot creates a deep copy of a snapshot using JSON marshaling. +func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[State], error) { + if snap == nil { + return nil, nil + } + bytes, err := json.Marshal(snap) + if err != nil { + return nil, err + } + var copied SessionSnapshot[State] + if err := json.Unmarshal(bytes, &copied); err != nil { + return nil, err + } + return &copied, nil +} + +// SnapshotOn returns a SnapshotCallback that only allows snapshots for the +// specified events. For example, SnapshotOn[MyState](TurnEnd) will skip the +// invocation-end snapshot. +func SnapshotOn[State any](events ...SnapshotEvent) SnapshotCallback[State] { + set := make(map[SnapshotEvent]struct{}, len(events)) + for _, e := range events { + set[e] = struct{}{} + } + return func(_ context.Context, sc *SnapshotContext[State]) bool { + _, ok := set[sc.Event] + return ok + } +} diff --git a/go/core/api/action.go b/go/core/api/action.go index a38958af51..3150fe88ae 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -64,6 +64,8 @@ const ( ActionTypeCustom ActionType = "custom" ActionTypeCheckOperation ActionType = "check-operation" ActionTypeCancelOperation ActionType = "cancel-operation" + ActionTypeSessionFlow ActionType = "session-flow" + ActionTypeSnapshotStore ActionType = "snapshot-store" ) // ActionDesc is a descriptor of an action. diff --git a/go/genkit/session_flow.go b/go/genkit/session_flow.go new file mode 100644 index 0000000000..6ff02e99e9 --- /dev/null +++ b/go/genkit/session_flow.go @@ -0,0 +1,60 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package genkit + +import ( + aix "github.com/firebase/genkit/go/ai/x" +) + +// DefineSessionFlow creates a SessionFlow with automatic snapshot management +// and registers it as a flow action. +// +// A SessionFlow is a stateful, multi-turn conversational flow with automatic +// snapshot persistence and turn semantics. It builds on bidirectional streaming +// to enable ongoing conversations with managed state. +// +// Type parameters: +// - Stream: Type for status updates sent via the responder +// - State: Type for user-defined state in snapshots +// +// Example: +// +// type ChatState struct { +// TopicHistory []string `json:"topicHistory,omitempty"` +// } +// +// type ChatStatus struct { +// Phase string `json:"phase"` +// } +// +// chatFlow := genkit.DefineSessionFlow(g, "chatFlow", +// func(ctx context.Context, resp aix.Responder[ChatStatus], params *aix.SessionFlowParams[ChatState]) error { +// return params.Session.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { +// // ... handle each turn ... +// return nil +// }) +// }, +// aix.WithSnapshotStore(store), +// ) +func DefineSessionFlow[Stream, State any]( + g *Genkit, + name string, + fn aix.SessionFlowFunc[Stream, State], + opts ...aix.SessionFlowOption[State], +) *aix.SessionFlow[Stream, State] { + return aix.DefineSessionFlow(g.reg, name, fn, opts...) +} diff --git a/go/samples/basic-session-flow/main.go b/go/samples/basic-session-flow/main.go new file mode 100644 index 0000000000..7d1707a807 --- /dev/null +++ b/go/samples/basic-session-flow/main.go @@ -0,0 +1,122 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This sample demonstrates the SessionFlow API for multi-turn conversation +// with token-level streaming. It runs a CLI REPL where conversation history +// is managed automatically by the session. +// +// To run: +// +// GOOGLE_GENAI_API_KEY=... go run . +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/x" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "google.golang.org/genai" +) + +func main() { + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + store := aix.NewInMemorySnapshotStore[struct{}]() + + chatFlow := genkit.DefineSessionFlow(g, "chat", + func(ctx context.Context, resp aix.Responder[any], params *aix.SessionFlowParams[struct{}]) error { + sess := params.Session + return sess.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { + for chunk, err := range genkit.GenerateStream(ctx, g, + ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are a helpful assistant. Keep responses concise."), + ai.WithMessages(sess.Messages()...), + ) { + if err != nil { + return err + } + if chunk.Done { + sess.AddMessages(chunk.Response.Message) + break + } + resp.SendChunk(chunk.Chunk) + } + + return nil + }) + }, + aix.WithSnapshotStore(store), + aix.WithSnapshotCallback(aix.SnapshotOn[struct{}](aix.TurnEnd)), + ) + + fmt.Println("Session Flow Chat (type 'quit' to exit)") + fmt.Println() + + conn, err := chatFlow.StreamBidi(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + reader := bufio.NewReader(os.Stdin) + for { + fmt.Print("> ") + input, _ := reader.ReadString('\n') + input = strings.TrimSpace(input) + + if input == "quit" || input == "exit" { + break + } + if input == "" { + continue + } + + if err := conn.SendText(input); err != nil { + fmt.Fprintf(os.Stderr, "Send error: %v\n", err) + break + } + + fmt.Println() + + for chunk, err := range conn.Receive() { + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + break + } + if chunk.Chunk != nil { + fmt.Print(chunk.Chunk.Text()) + } + if chunk.SnapshotCreated != "" { + fmt.Printf("\n[snapshot: %s]", chunk.SnapshotCreated) + } + if chunk.EndTurn { + fmt.Println("\n") + break + } + } + } + + conn.Close() +} From fe91d76ea9e9b751abb2859d9a2f7fcf63dac6ee Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 07:06:58 -0800 Subject: [PATCH 006/141] Update main.go --- go/samples/basic-session-flow/main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/go/samples/basic-session-flow/main.go b/go/samples/basic-session-flow/main.go index 7d1707a807..9531604890 100644 --- a/go/samples/basic-session-flow/main.go +++ b/go/samples/basic-session-flow/main.go @@ -112,7 +112,8 @@ func main() { fmt.Printf("\n[snapshot: %s]", chunk.SnapshotCreated) } if chunk.EndTurn { - fmt.Println("\n") + fmt.Println() + fmt.Println() break } } From ad323d23eb0df8929d9c456dc21f143439084e78 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:46:10 -0800 Subject: [PATCH 007/141] Update main.go --- go/samples/basic-session-flow/main.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/go/samples/basic-session-flow/main.go b/go/samples/basic-session-flow/main.go index 9531604890..bc714ef409 100644 --- a/go/samples/basic-session-flow/main.go +++ b/go/samples/basic-session-flow/main.go @@ -15,10 +15,6 @@ // This sample demonstrates the SessionFlow API for multi-turn conversation // with token-level streaming. It runs a CLI REPL where conversation history // is managed automatically by the session. -// -// To run: -// -// GOOGLE_GENAI_API_KEY=... go run . package main import ( From 3913771d8a848339b7e4d8a17d4fc19e6d0aeb50 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 09:35:43 -0800 Subject: [PATCH 008/141] moved files --- go/genkit/genkit.go | 40 ++++++++++++++++++++++++++ go/genkit/session_flow.go | 60 --------------------------------------- 2 files changed, 40 insertions(+), 60 deletions(-) delete mode 100644 go/genkit/session_flow.go diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 40e137a2b8..94a717ba98 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -30,6 +30,7 @@ import ( "syscall" "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/x" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/internal/registry" @@ -407,6 +408,45 @@ func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.B return core.DefineBidiFlow(g.reg, name, fn) } +// DefineSessionFlow creates a SessionFlow with automatic snapshot management +// and registers it as a flow action. +// +// A SessionFlow is a stateful, multi-turn conversational flow with automatic +// snapshot persistence and turn semantics. It builds on bidirectional streaming +// to enable ongoing conversations with managed state. +// +// Type parameters: +// - Stream: Type for status updates sent via the responder +// - State: Type for user-defined state in snapshots +// +// Example: +// +// type ChatState struct { +// TopicHistory []string `json:"topicHistory,omitempty"` +// } +// +// type ChatStatus struct { +// Phase string `json:"phase"` +// } +// +// chatFlow := genkit.DefineSessionFlow(g, "chatFlow", +// func(ctx context.Context, resp aix.Responder[ChatStatus], params *aix.SessionFlowParams[ChatState]) error { +// return params.Session.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { +// // ... handle each turn ... +// return nil +// }) +// }, +// aix.WithSnapshotStore(store), +// ) +func DefineSessionFlow[Stream, State any]( + g *Genkit, + name string, + fn aix.SessionFlowFunc[Stream, State], + opts ...aix.SessionFlowOption[State], +) *aix.SessionFlow[Stream, State] { + return aix.DefineSessionFlow(g.reg, name, fn, opts...) +} + // Run executes the given function `fn` within the context of the current flow run, // creating a distinct trace span for this step. It's used to add observability // to specific sub-operations within a flow defined by [DefineFlow] or [DefineStreamingFlow]. diff --git a/go/genkit/session_flow.go b/go/genkit/session_flow.go deleted file mode 100644 index 6ff02e99e9..0000000000 --- a/go/genkit/session_flow.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 - -package genkit - -import ( - aix "github.com/firebase/genkit/go/ai/x" -) - -// DefineSessionFlow creates a SessionFlow with automatic snapshot management -// and registers it as a flow action. -// -// A SessionFlow is a stateful, multi-turn conversational flow with automatic -// snapshot persistence and turn semantics. It builds on bidirectional streaming -// to enable ongoing conversations with managed state. -// -// Type parameters: -// - Stream: Type for status updates sent via the responder -// - State: Type for user-defined state in snapshots -// -// Example: -// -// type ChatState struct { -// TopicHistory []string `json:"topicHistory,omitempty"` -// } -// -// type ChatStatus struct { -// Phase string `json:"phase"` -// } -// -// chatFlow := genkit.DefineSessionFlow(g, "chatFlow", -// func(ctx context.Context, resp aix.Responder[ChatStatus], params *aix.SessionFlowParams[ChatState]) error { -// return params.Session.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { -// // ... handle each turn ... -// return nil -// }) -// }, -// aix.WithSnapshotStore(store), -// ) -func DefineSessionFlow[Stream, State any]( - g *Genkit, - name string, - fn aix.SessionFlowFunc[Stream, State], - opts ...aix.SessionFlowOption[State], -) *aix.SessionFlow[Stream, State] { - return aix.DefineSessionFlow(g.reg, name, fn, opts...) -} From 15880f58819edadf6c9355a7fbdd780fb6e49465 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 12:45:11 -0800 Subject: [PATCH 009/141] added `DefineSessionFlowFromPrompt` --- go/ai/x/option.go | 17 +- go/ai/x/prompt_session_flow_test.go | 375 ++++++++++++++++++ go/ai/x/session_flow.go | 93 ++++- go/ai/x/snapshot.go | 7 +- go/genkit/genkit.go | 21 + go/samples/basic-session-flow/main.go | 2 +- go/samples/prompt-session-flow/main.go | 101 +++++ .../prompt-session-flow/prompts/chat.prompt | 12 + 8 files changed, 619 insertions(+), 9 deletions(-) create mode 100644 go/ai/x/prompt_session_flow_test.go create mode 100644 go/samples/prompt-session-flow/main.go create mode 100644 go/samples/prompt-session-flow/prompts/chat.prompt diff --git a/go/ai/x/option.go b/go/ai/x/option.go index 1ab09a9cbf..d7380b4b83 100644 --- a/go/ai/x/option.go +++ b/go/ai/x/option.go @@ -65,8 +65,9 @@ type StreamBidiOption[State any] interface { } type streamBidiOptions[State any] struct { - state *SessionState[State] - snapshotID string + state *SessionState[State] + snapshotID string + promptInput any } func (o *streamBidiOptions[State]) applyStreamBidi(opts *streamBidiOptions[State]) error { @@ -82,6 +83,12 @@ func (o *streamBidiOptions[State]) applyStreamBidi(opts *streamBidiOptions[State } opts.snapshotID = o.snapshotID } + if o.promptInput != nil { + if opts.promptInput != nil { + return errors.New("cannot set prompt input more than once (WithPromptInput)") + } + opts.promptInput = o.promptInput + } return nil } @@ -96,3 +103,9 @@ func WithState[State any](state *SessionState[State]) StreamBidiOption[State] { func WithSnapshotID[State any](id string) StreamBidiOption[State] { return &streamBidiOptions[State]{snapshotID: id} } + +// WithPromptInput overrides the default prompt input for a prompt-backed session flow. +// Used with DefineSessionFlowFromPrompt to customize the prompt rendering per invocation. +func WithPromptInput[State any](input any) StreamBidiOption[State] { + return &streamBidiOptions[State]{promptInput: input} +} diff --git a/go/ai/x/prompt_session_flow_test.go b/go/ai/x/prompt_session_flow_test.go new file mode 100644 index 0000000000..76bbe7fee9 --- /dev/null +++ b/go/ai/x/prompt_session_flow_test.go @@ -0,0 +1,375 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package aix + +import ( + "context" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/internal/registry" +) + +// setupPromptTestRegistry creates a registry with an echo model and generate action. +func setupPromptTestRegistry(t *testing.T) *registry.Registry { + t.Helper() + reg := registry.New() + ctx := context.Background() + + ai.ConfigureFormats(reg) + ai.DefineModel(reg, "test/echo", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + // Echo back the last user message text. + var text string + for i := len(req.Messages) - 1; i >= 0; i-- { + if req.Messages[i].Role == ai.RoleUser { + text = req.Messages[i].Text() + break + } + } + if text == "" { + text = "no input" + } + + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage("echo: " + text), + } + + if cb != nil { + if err := cb(ctx, &ai.ModelResponseChunk{ + Content: resp.Message.Content, + }); err != nil { + return nil, err + } + } + + return resp, nil + }, + ) + ai.DefineGenerateAction(ctx, reg) + return reg +} + +func TestPromptSessionFlow_Basic(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + prompt := ai.DefinePrompt(reg, "testPrompt", + ai.WithModelName("test/echo"), + ai.WithSystem("You are a test assistant."), + ) + + sf := DefineSessionFlowFromPrompt[testStatus, testState]( + reg, "promptFlow", prompt, nil, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Turn 1. + if err := conn.SendText("hello"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + + var gotChunk bool + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.Chunk != nil { + gotChunk = true + } + if chunk.EndTurn { + break + } + } + if !gotChunk { + t.Error("expected at least one streaming chunk") + } + + // Turn 2. + if err := conn.SendText("world"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // 2 user messages + 2 model replies = 4. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + for i, m := range response.State.Messages { + t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) + } + } +} + +func TestPromptSessionFlow_PromptInputOverride(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + type greetInput struct { + Name string `json:"name"` + } + + prompt := ai.DefineDataPrompt[greetInput, string](reg, "greetPrompt", + ai.WithModelName("test/echo"), + ai.WithPrompt("Hello {{name}}!"), + ) + + sf := DefineSessionFlowFromPrompt[testStatus, testState]( + reg, "promptInputFlow", prompt, greetInput{Name: "default"}, + ) + + // Use WithPromptInput to override. + conn, err := sf.StreamBidi(ctx, + WithPromptInput[testState](greetInput{Name: "override"}), + ) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Verify the override was stored in session state. + if response.State.PromptInput == nil { + t.Fatal("expected PromptInput in state") + } + + // The model echoes the last user message, which is "hi". + // But the prompt was rendered with "override" so "Hello override!" should appear + // in the messages sent to the model (verified via the echo). + // We primarily verify the state was set correctly. + inputMap, ok := response.State.PromptInput.(map[string]any) + if !ok { + t.Fatalf("expected PromptInput to be map[string]any, got %T", response.State.PromptInput) + } + if name, _ := inputMap["name"].(string); name != "override" { + t.Errorf("expected PromptInput name='override', got %q", name) + } +} + +func TestPromptSessionFlow_MultiTurnHistory(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + // Use a model that echoes all message count so we can verify history grows. + ai.DefineModel(reg, "test/history", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + // Count total messages received (includes prompt-rendered + history). + var parts []string + for _, m := range req.Messages { + parts = append(parts, string(m.Role)+":"+m.Text()) + } + text := strings.Join(parts, "|") + + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage(text), + } + if cb != nil { + cb(ctx, &ai.ModelResponseChunk{Content: resp.Message.Content}) + } + return resp, nil + }, + ) + + prompt := ai.DefinePrompt(reg, "historyPrompt", + ai.WithModelName("test/history"), + ai.WithSystem("system prompt"), + ) + + sf := DefineSessionFlowFromPrompt[testStatus, testState]( + reg, "historyFlow", prompt, nil, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Turn 1. + conn.SendText("turn1") + var turn1Response string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.Chunk != nil { + turn1Response += chunk.Chunk.Text() + } + if chunk.EndTurn { + break + } + } + + // Turn 1 should have: system message + user message "turn1" (2 messages total from prompt + history). + // The system message comes from the prompt, "turn1" from session history. + if !strings.Contains(turn1Response, "turn1") { + t.Errorf("turn1 response should contain 'turn1', got: %s", turn1Response) + } + + // Turn 2. + conn.SendText("turn2") + var turn2Response string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.Chunk != nil { + turn2Response += chunk.Chunk.Text() + } + if chunk.EndTurn { + break + } + } + + // Turn 2 should have: system + turn1 user + turn1 model reply + turn2 user (4 messages from prompt + history). + if !strings.Contains(turn2Response, "turn1") || !strings.Contains(turn2Response, "turn2") { + t.Errorf("turn2 response should contain both 'turn1' and 'turn2', got: %s", turn2Response) + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Session should have: turn1 user + turn1 model + turn2 user + turn2 model = 4 messages. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages in session, got %d", got) + for i, m := range response.State.Messages { + t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) + } + } +} + +func TestPromptSessionFlow_SnapshotPersistsPromptInput(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + store := NewInMemorySnapshotStore[testState]() + + prompt := ai.DefinePrompt(reg, "snapPrompt", + ai.WithModelName("test/echo"), + ai.WithSystem("You are a test assistant."), + ) + + sf := DefineSessionFlowFromPrompt[testStatus]( + reg, "snapPromptFlow", prompt, nil, + WithSnapshotStore(store), + ) + + // Start with prompt input. + conn, err := sf.StreamBidi(ctx, + WithPromptInput[testState](map[string]any{"key": "value"}), + ) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("hello") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + resp, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + if resp.SnapshotID == "" { + t.Fatal("expected snapshot ID") + } + + // Verify the snapshot contains PromptInput. + snap, err := store.GetSnapshot(ctx, resp.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap.State.PromptInput == nil { + t.Error("expected PromptInput in snapshot state") + } + + // Resume from snapshot — the PromptInput should be preserved. + conn2, err := sf.StreamBidi(ctx, WithSnapshotID[testState](resp.SnapshotID)) + if err != nil { + t.Fatalf("StreamBidi with snapshot failed: %v", err) + } + + conn2.SendText("continued") + for chunk, err := range conn2.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn2.Close() + + resp2, err := conn2.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Should have messages from both invocations. + if got := len(resp2.State.Messages); got != 4 { + t.Errorf("expected 4 messages after resume, got %d", got) + } + + // PromptInput should still be present. + if resp2.State.PromptInput == nil { + t.Error("expected PromptInput preserved after resume") + } +} diff --git a/go/ai/x/session_flow.go b/go/ai/x/session_flow.go index 29c0648549..ce83953703 100644 --- a/go/ai/x/session_flow.go +++ b/go/ai/x/session_flow.go @@ -64,6 +64,9 @@ type SessionFlowInit[State any] struct { // State provides direct state for the invocation. // Mutually exclusive with SnapshotID. State *SessionState[State] `json:"state,omitempty"` + // PromptInput overrides the default prompt input for this invocation. + // Used by prompt-backed session flows (DefineSessionFlowFromPrompt). + PromptInput any `json:"promptInput,omitempty"` } // SessionFlowOutput is the output when a session flow invocation completes. @@ -200,6 +203,13 @@ func (s *Session[State]) UpdateCustom(fn func(State) State) { s.state.Custom = fn(s.state.Custom) } +// PromptInput returns the prompt input stored in the session state. +func (s *Session[State]) PromptInput() any { + s.mu.RLock() + defer s.mu.RUnlock() + return s.state.PromptInput +} + // Artifacts returns the current artifacts. func (s *Session[State]) Artifacts() []*SessionFlowArtifact { s.mu.RLock() @@ -451,8 +461,9 @@ func (sf *SessionFlow[Stream, State]) StreamBidi( } init := &SessionFlowInit[State]{ - SnapshotID: sbOpts.snapshotID, - State: sbOpts.state, + SnapshotID: sbOpts.snapshotID, + State: sbOpts.state, + PromptInput: sbOpts.promptInput, } conn, err := sf.flow.StreamBidi(ctx, init) @@ -496,7 +507,7 @@ func (sf *SessionFlow[Stream, State]) runWrapped( // Wire up onEndTurn: triggers snapshot + sends EndTurn chunk. // Writes through respCh to preserve ordering with user chunks. session.onEndTurn = func(turnCtx context.Context) { - snapshotID := session.maybeSnapshot(turnCtx, TurnEnd) + snapshotID := session.maybeSnapshot(turnCtx, SnapshotEventTurnEnd) if snapshotID != "" { respCh <- &SessionFlowStreamChunk[Stream]{SnapshotCreated: snapshotID} } @@ -516,7 +527,7 @@ func (sf *SessionFlow[Stream, State]) runWrapped( } // Final snapshot at invocation end. - snapshotID := session.maybeSnapshot(ctx, InvocationEnd) + snapshotID := session.maybeSnapshot(ctx, SnapshotEventInvocationEnd) return &SessionFlowOutput[State]{ State: session.State(), @@ -551,6 +562,9 @@ func newSessionFromInit[State any]( } else if init.State != nil { s.state = *init.State } + if init.PromptInput != nil { + s.state.PromptInput = init.PromptInput + } } return s, nil @@ -658,3 +672,74 @@ func (c *SessionFlowConnection[Stream, State]) Output() (*SessionFlowOutput[Stat func (c *SessionFlowConnection[Stream, State]) Done() <-chan struct{} { return c.conn.Done() } + +// --- Prompt-backed SessionFlow --- + +// PromptRenderer renders a prompt with typed input into GenerateActionOptions. +// This interface is satisfied by both ai.Prompt (with In=any) and +// *ai.DataPrompt[In, Out]. +type PromptRenderer[In any] interface { + Render(ctx context.Context, input In) (*ai.GenerateActionOptions, error) +} + +// DefineSessionFlowFromPrompt creates a prompt-backed SessionFlow with an +// automatic conversation loop. Each turn renders the prompt, appends +// conversation history, calls GenerateWithRequest, streams chunks to the +// client, and adds the model response to the session. +// +// The defaultInput is used for prompt rendering unless overridden per +// invocation via WithPromptInput. +func DefineSessionFlowFromPrompt[Stream, State, PromptIn any]( + r api.Registry, + name string, + p PromptRenderer[PromptIn], + defaultInput PromptIn, + opts ...SessionFlowOption[State], +) *SessionFlow[Stream, State] { + fn := func(ctx context.Context, resp Responder[Stream], params *SessionFlowParams[State]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + + // Resolve prompt input: session state override > default. + var promptInput PromptIn + if stored := sess.PromptInput(); stored != nil { + typed, ok := stored.(PromptIn) + if !ok { + return fmt.Errorf("prompt input type mismatch: got %T, want %T", stored, promptInput) + } + promptInput = typed + } else { + promptInput = defaultInput + } + + // Render the prompt template. + actionOpts, err := p.Render(ctx, promptInput) + if err != nil { + return fmt.Errorf("prompt render: %w", err) + } + + // Append conversation history after the prompt-rendered messages. + actionOpts.Messages = append(actionOpts.Messages, sess.Messages()...) + + // Call the model with streaming. + modelResp, err := ai.GenerateWithRequest(ctx, r, actionOpts, nil, + func(ctx context.Context, chunk *ai.ModelResponseChunk) error { + resp.SendChunk(chunk) + return nil + }, + ) + if err != nil { + return fmt.Errorf("generate: %w", err) + } + + // Add the model response message to session history. + if modelResp.Message != nil { + sess.AddMessages(modelResp.Message) + } + + return nil + }) + } + + return DefineSessionFlow(r, name, fn, opts...) +} diff --git a/go/ai/x/snapshot.go b/go/ai/x/snapshot.go index c81bfcff81..e9411aa32f 100644 --- a/go/ai/x/snapshot.go +++ b/go/ai/x/snapshot.go @@ -35,6 +35,9 @@ type SessionState[State any] struct { Custom State `json:"custom,omitempty"` // Artifacts are named collections of parts produced during the conversation. Artifacts []*SessionFlowArtifact `json:"artifacts,omitempty"` + // PromptInput is the input used for prompt rendering in prompt-backed session flows. + // Stored as any to support type-erased persistence across snapshot boundaries. + PromptInput any `json:"promptInput,omitempty"` } // SnapshotEvent identifies what triggered a snapshot. @@ -42,9 +45,9 @@ type SnapshotEvent string const ( // TurnEnd indicates the snapshot was triggered at the end of a turn. - TurnEnd SnapshotEvent = "turnEnd" + SnapshotEventTurnEnd SnapshotEvent = "turnEnd" // InvocationEnd indicates the snapshot was triggered at the end of the invocation. - InvocationEnd SnapshotEvent = "invocationEnd" + SnapshotEventInvocationEnd SnapshotEvent = "invocationEnd" ) // SessionSnapshot is a persisted point-in-time capture of session state. diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 94a717ba98..b873efc6e2 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -447,6 +447,27 @@ func DefineSessionFlow[Stream, State any]( return aix.DefineSessionFlow(g.reg, name, fn, opts...) } +// DefineSessionFlowFromPrompt creates a prompt-backed SessionFlow with an +// automatic conversation loop. Each turn renders the prompt, appends +// conversation history, calls the model with streaming, and updates session state. +// +// The defaultInput is used for prompt rendering unless overridden per +// invocation via [aix.WithPromptInput]. +// +// Type parameters: +// - In: The prompt input type +// - Stream: Type for status updates sent via the responder +// - State: Type for user-defined state in snapshots +func DefineSessionFlowFromPrompt[Stream, State, PromptIn any]( + g *Genkit, + name string, + p aix.PromptRenderer[PromptIn], + defaultInput PromptIn, + opts ...aix.SessionFlowOption[State], +) *aix.SessionFlow[Stream, State] { + return aix.DefineSessionFlowFromPrompt[Stream](g.reg, name, p, defaultInput, opts...) +} + // Run executes the given function `fn` within the context of the current flow run, // creating a distinct trace span for this step. It's used to add observability // to specific sub-operations within a flow defined by [DefineFlow] or [DefineStreamingFlow]. diff --git a/go/samples/basic-session-flow/main.go b/go/samples/basic-session-flow/main.go index bc714ef409..50a940c253 100644 --- a/go/samples/basic-session-flow/main.go +++ b/go/samples/basic-session-flow/main.go @@ -64,7 +64,7 @@ func main() { }) }, aix.WithSnapshotStore(store), - aix.WithSnapshotCallback(aix.SnapshotOn[struct{}](aix.TurnEnd)), + aix.WithSnapshotCallback(aix.SnapshotOn[struct{}](aix.SnapshotEventTurnEnd)), ) fmt.Println("Session Flow Chat (type 'quit' to exit)") diff --git a/go/samples/prompt-session-flow/main.go b/go/samples/prompt-session-flow/main.go new file mode 100644 index 0000000000..d05b4b8929 --- /dev/null +++ b/go/samples/prompt-session-flow/main.go @@ -0,0 +1,101 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This sample demonstrates DefineSessionFlowFromPrompt, which creates a +// multi-turn conversational session flow backed by a .prompt file. The +// conversation loop (render prompt, call model, stream chunks, update history) +// is handled automatically. Compare with basic-session-flow which wires +// the same loop manually. +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + + aix "github.com/firebase/genkit/go/ai/x" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" +) + +type ChatPromptInput struct { + Personality string `json:"personality"` +} + +func main() { + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + chatPrompt := genkit.LookupDataPrompt[ChatPromptInput, string](g, "chat") + + chatFlow := genkit.DefineSessionFlowFromPrompt[struct{}]( + g, "chat", chatPrompt, ChatPromptInput{Personality: "a sarcastic pirate"}, + aix.WithSnapshotStore(aix.NewInMemorySnapshotStore[struct{}]()), + aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[struct{}]) bool { + return sc.Event == aix.SnapshotEventInvocationEnd || sc.TurnIndex%5 == 0 + }), + ) + + fmt.Println("Prompt Session Flow Chat (type 'quit' to exit)") + fmt.Println() + + conn, err := chatFlow.StreamBidi(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + reader := bufio.NewReader(os.Stdin) + for { + fmt.Print("> ") + input, _ := reader.ReadString('\n') + input = strings.TrimSpace(input) + + if input == "quit" || input == "exit" { + break + } + if input == "" { + continue + } + + if err := conn.SendText(input); err != nil { + fmt.Fprintf(os.Stderr, "Send error: %v\n", err) + break + } + + fmt.Println() + + for chunk, err := range conn.Receive() { + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + break + } + if chunk.Chunk != nil { + fmt.Print(chunk.Chunk.Text()) + } + if chunk.SnapshotCreated != "" { + fmt.Printf("\n[snapshot: %s]", chunk.SnapshotCreated) + } + if chunk.EndTurn { + fmt.Println() + fmt.Println() + break + } + } + } + + conn.Close() +} diff --git a/go/samples/prompt-session-flow/prompts/chat.prompt b/go/samples/prompt-session-flow/prompts/chat.prompt new file mode 100644 index 0000000000..6a78a99b07 --- /dev/null +++ b/go/samples/prompt-session-flow/prompts/chat.prompt @@ -0,0 +1,12 @@ +--- +model: googleai/gemini-3-flash-preview +config: + thinkingConfig: + thinkingBudget: 0 +input: + schema: + personality: string + default: + personality: a helpful assistant +--- +You are {{personality}}. Keep responses concise. From e94bca1855381993bfed0ce3005378115442ad71 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 14:34:32 -0800 Subject: [PATCH 010/141] removed stream type param --- go/ai/x/prompt_session_flow_test.go | 8 ++++---- go/ai/x/session_flow.go | 10 ++++------ go/genkit/genkit.go | 9 ++++----- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/go/ai/x/prompt_session_flow_test.go b/go/ai/x/prompt_session_flow_test.go index 76bbe7fee9..70dade4e8c 100644 --- a/go/ai/x/prompt_session_flow_test.go +++ b/go/ai/x/prompt_session_flow_test.go @@ -74,7 +74,7 @@ func TestPromptSessionFlow_Basic(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - sf := DefineSessionFlowFromPrompt[testStatus, testState]( + sf := DefineSessionFlowFromPrompt[testState]( reg, "promptFlow", prompt, nil, ) @@ -146,7 +146,7 @@ func TestPromptSessionFlow_PromptInputOverride(t *testing.T) { ai.WithPrompt("Hello {{name}}!"), ) - sf := DefineSessionFlowFromPrompt[testStatus, testState]( + sf := DefineSessionFlowFromPrompt[testState]( reg, "promptInputFlow", prompt, greetInput{Name: "default"}, ) @@ -223,7 +223,7 @@ func TestPromptSessionFlow_MultiTurnHistory(t *testing.T) { ai.WithSystem("system prompt"), ) - sf := DefineSessionFlowFromPrompt[testStatus, testState]( + sf := DefineSessionFlowFromPrompt[testState]( reg, "historyFlow", prompt, nil, ) @@ -299,7 +299,7 @@ func TestPromptSessionFlow_SnapshotPersistsPromptInput(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - sf := DefineSessionFlowFromPrompt[testStatus]( + sf := DefineSessionFlowFromPrompt[testState]( reg, "snapPromptFlow", prompt, nil, WithSnapshotStore(store), ) diff --git a/go/ai/x/session_flow.go b/go/ai/x/session_flow.go index ce83953703..06621014c0 100644 --- a/go/ai/x/session_flow.go +++ b/go/ai/x/session_flow.go @@ -689,27 +689,25 @@ type PromptRenderer[In any] interface { // // The defaultInput is used for prompt rendering unless overridden per // invocation via WithPromptInput. -func DefineSessionFlowFromPrompt[Stream, State, PromptIn any]( +func DefineSessionFlowFromPrompt[State, PromptIn any]( r api.Registry, name string, p PromptRenderer[PromptIn], defaultInput PromptIn, opts ...SessionFlowOption[State], -) *SessionFlow[Stream, State] { - fn := func(ctx context.Context, resp Responder[Stream], params *SessionFlowParams[State]) error { +) *SessionFlow[struct{}, State] { + fn := func(ctx context.Context, resp Responder[struct{}], params *SessionFlowParams[State]) error { return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { sess := params.Session // Resolve prompt input: session state override > default. - var promptInput PromptIn + promptInput := defaultInput if stored := sess.PromptInput(); stored != nil { typed, ok := stored.(PromptIn) if !ok { return fmt.Errorf("prompt input type mismatch: got %T, want %T", stored, promptInput) } promptInput = typed - } else { - promptInput = defaultInput } // Render the prompt template. diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index b873efc6e2..78298c15af 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -455,17 +455,16 @@ func DefineSessionFlow[Stream, State any]( // invocation via [aix.WithPromptInput]. // // Type parameters: -// - In: The prompt input type -// - Stream: Type for status updates sent via the responder // - State: Type for user-defined state in snapshots -func DefineSessionFlowFromPrompt[Stream, State, PromptIn any]( +// - PromptIn: The prompt input type (inferred from the PromptRenderer) +func DefineSessionFlowFromPrompt[State, PromptIn any]( g *Genkit, name string, p aix.PromptRenderer[PromptIn], defaultInput PromptIn, opts ...aix.SessionFlowOption[State], -) *aix.SessionFlow[Stream, State] { - return aix.DefineSessionFlowFromPrompt[Stream](g.reg, name, p, defaultInput, opts...) +) *aix.SessionFlow[struct{}, State] { + return aix.DefineSessionFlowFromPrompt(g.reg, name, p, defaultInput, opts...) } // Run executes the given function `fn` within the context of the current flow run, From ec1aa79c201663721b15634010f3e3734233fef1 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 14:36:03 -0800 Subject: [PATCH 011/141] Update action.go --- go/core/action.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index 9d5e4c31a8..125e44961f 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -531,13 +531,12 @@ type BidiConnection[In, Out, Stream any] struct { // Send sends an input message to the bidi action. // Returns an error if the connection is closed or the context is cancelled. -func (c *BidiConnection[In, Out, Stream]) Send(input In) error { - c.mu.Lock() - if c.closed { - c.mu.Unlock() - return NewError(FAILED_PRECONDITION, "connection is closed") - } - c.mu.Unlock() +func (c *BidiConnection[In, Out, Stream]) Send(input In) (err error) { + defer func() { + if r := recover(); r != nil { + err = NewError(FAILED_PRECONDITION, "connection is closed") + } + }() select { case c.inputCh <- input: From a77fa339a3e3dd9a33daaabddffc4b0627aa4f8c Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 9 Feb 2026 15:50:08 -0800 Subject: [PATCH 012/141] updates --- go/samples/basic-session-flow/main.go | 15 ++++++--------- go/samples/prompt-session-flow/main.go | 6 +++--- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/go/samples/basic-session-flow/main.go b/go/samples/basic-session-flow/main.go index 50a940c253..2e2fe6aa6e 100644 --- a/go/samples/basic-session-flow/main.go +++ b/go/samples/basic-session-flow/main.go @@ -35,12 +35,9 @@ func main() { ctx := context.Background() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - store := aix.NewInMemorySnapshotStore[struct{}]() - chatFlow := genkit.DefineSessionFlow(g, "chat", - func(ctx context.Context, resp aix.Responder[any], params *aix.SessionFlowParams[struct{}]) error { - sess := params.Session - return sess.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { + func(ctx context.Context, resp aix.Responder[any], params *aix.SessionFlowParams[any]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { for chunk, err := range genkit.GenerateStream(ctx, g, ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ @@ -48,13 +45,13 @@ func main() { }, })), ai.WithSystem("You are a helpful assistant. Keep responses concise."), - ai.WithMessages(sess.Messages()...), + ai.WithMessages(params.Session.Messages()...), ) { if err != nil { return err } if chunk.Done { - sess.AddMessages(chunk.Response.Message) + params.Session.AddMessages(chunk.Response.Message) break } resp.SendChunk(chunk.Chunk) @@ -63,8 +60,8 @@ func main() { return nil }) }, - aix.WithSnapshotStore(store), - aix.WithSnapshotCallback(aix.SnapshotOn[struct{}](aix.SnapshotEventTurnEnd)), + aix.WithSnapshotStore(aix.NewInMemorySnapshotStore[any]()), + aix.WithSnapshotCallback(aix.SnapshotOn[any](aix.SnapshotEventTurnEnd)), ) fmt.Println("Session Flow Chat (type 'quit' to exit)") diff --git a/go/samples/prompt-session-flow/main.go b/go/samples/prompt-session-flow/main.go index d05b4b8929..b85988c10d 100644 --- a/go/samples/prompt-session-flow/main.go +++ b/go/samples/prompt-session-flow/main.go @@ -41,10 +41,10 @@ func main() { chatPrompt := genkit.LookupDataPrompt[ChatPromptInput, string](g, "chat") - chatFlow := genkit.DefineSessionFlowFromPrompt[struct{}]( + chatFlow := genkit.DefineSessionFlowFromPrompt( g, "chat", chatPrompt, ChatPromptInput{Personality: "a sarcastic pirate"}, - aix.WithSnapshotStore(aix.NewInMemorySnapshotStore[struct{}]()), - aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[struct{}]) bool { + aix.WithSnapshotStore(aix.NewInMemorySnapshotStore[any]()), + aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[any]) bool { return sc.Event == aix.SnapshotEventInvocationEnd || sc.TurnIndex%5 == 0 }), ) From e83af30117413a418ab3ec7d49352b1b3899003b Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 17 Feb 2026 15:13:02 -0800 Subject: [PATCH 013/141] cleaned up API naming and behavior --- go/ai/x/{session_flow.go => agent_flow.go} | 518 +++++++++--------- ...ession_flow_test.go => agent_flow_test.go} | 176 +++--- go/ai/x/option.go | 36 +- ...sion_flow_test.go => prompt_agent_test.go} | 30 +- go/ai/x/snapshot.go | 26 +- go/core/api/action.go | 4 +- go/genkit/genkit.go | 32 +- .../main.go | 20 +- .../main.go | 16 +- .../prompts/chat.prompt | 0 10 files changed, 422 insertions(+), 436 deletions(-) rename go/ai/x/{session_flow.go => agent_flow.go} (58%) rename go/ai/x/{session_flow_test.go => agent_flow_test.go} (73%) rename go/ai/x/{prompt_session_flow_test.go => prompt_agent_test.go} (92%) rename go/samples/{basic-session-flow => custom-agent}/main.go (79%) rename go/samples/{prompt-session-flow => prompt-agent}/main.go (81%) rename go/samples/{prompt-session-flow => prompt-agent}/prompts/chat.prompt (100%) diff --git a/go/ai/x/session_flow.go b/go/ai/x/agent_flow.go similarity index 58% rename from go/ai/x/session_flow.go rename to go/ai/x/agent_flow.go index 06621014c0..900cef135d 100644 --- a/go/ai/x/session_flow.go +++ b/go/ai/x/agent_flow.go @@ -38,9 +38,9 @@ import ( oteltrace "go.opentelemetry.io/otel/trace" ) -// SessionFlowArtifact represents a named collection of parts produced during a session. +// AgentArtifact represents a named collection of parts produced during a session. // Examples: generated files, images, code snippets, diagrams, etc. -type SessionFlowArtifact struct { +type AgentArtifact struct { // Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). Name string `json:"name,omitempty"` // Parts contains the artifact content (text, media, etc.). @@ -49,28 +49,28 @@ type SessionFlowArtifact struct { Metadata map[string]any `json:"metadata,omitempty"` } -// SessionFlowInput is the input sent to a session flow during a conversation turn. -type SessionFlowInput struct { +// AgentFlowInput is the input sent to an agent flow during a conversation turn. +type AgentFlowInput struct { // Messages contains the user's input for this turn. Messages []*ai.Message `json:"messages,omitempty"` } -// SessionFlowInit is the input for starting a session flow invocation. +// AgentFlowInit is the input for starting an agent flow invocation. // Provide either SnapshotID (to load from store) or State (direct state). -type SessionFlowInit[State any] struct { +type AgentFlowInit[State any] struct { // SnapshotID loads state from a persisted snapshot. // Mutually exclusive with State. SnapshotID string `json:"snapshotId,omitempty"` // State provides direct state for the invocation. // Mutually exclusive with SnapshotID. State *SessionState[State] `json:"state,omitempty"` - // PromptInput overrides the default prompt input for this invocation. - // Used by prompt-backed session flows (DefineSessionFlowFromPrompt). - PromptInput any `json:"promptInput,omitempty"` + // InputVariables overrides the default input variables for this invocation. + // Used by agent flows that require input variables (DefinePromptAgent). + InputVariables any `json:"inputVariables,omitempty"` } -// SessionFlowOutput is the output when a session flow invocation completes. -type SessionFlowOutput[State any] struct { +// AgentFlowOutput is the output when an agent flow invocation completes. +type AgentFlowOutput[State any] struct { // SnapshotID is the ID of the snapshot created at the end of this invocation. // Empty if no snapshot was created (callback returned false or no store configured). SnapshotID string `json:"snapshotId,omitempty"` @@ -78,79 +78,34 @@ type SessionFlowOutput[State any] struct { State *SessionState[State] `json:"state"` } -// SessionFlowStreamChunk represents a single item in the session flow's output stream. +// AgentFlowStreamChunk represents a single item in the agent flow's output stream. // Multiple fields can be populated in a single chunk. -type SessionFlowStreamChunk[Stream any] struct { +type AgentFlowStreamChunk[Stream any] struct { // Chunk contains token-level generation data. Chunk *ai.ModelResponseChunk `json:"chunk,omitempty"` // Status contains user-defined structured status information. // The Stream type parameter defines the shape of this data. Status Stream `json:"status,omitempty"` // Artifact contains a newly produced artifact. - Artifact *SessionFlowArtifact `json:"artifact,omitempty"` - // SnapshotCreated contains the ID of a snapshot that was just persisted. - SnapshotCreated string `json:"snapshotCreated,omitempty"` - // EndTurn signals that the session flow has finished processing the current input. + Artifact *AgentArtifact `json:"artifact,omitempty"` + // SnapshotID contains the ID of a snapshot that was just persisted. + SnapshotID string `json:"snapshotId,omitempty"` + // EndTurn signals that the agent flow has finished processing the current input. // When true, the client should stop iterating and may send the next input. EndTurn bool `json:"endTurn,omitempty"` } // --- Session --- -// Session holds the working state during a session flow invocation. -// It is propagated through context and provides read/write access to state. +// Session holds conversation state and provides thread-safe read/write access to messages, +// input variables, custom state, and artifacts. type Session[State any] struct { mu sync.RWMutex state SessionState[State] - store SnapshotStore[State] - - snapshotCallback SnapshotCallback[State] - - // onEndTurn is set by the framework; triggers snapshot + EndTurn chunk. - onEndTurn func(ctx context.Context) - inCh <-chan *SessionFlowInput - - // Snapshot tracking - lastSnapshot *SessionSnapshot[State] - turnIndex int -} - -// Run loops over the input channel, calling fn for each turn. Each turn is -// wrapped in a trace span for observability. Input messages are automatically -// added to the session before fn is called. After fn returns successfully, an -// EndTurn chunk is sent and a snapshot check is triggered. -func (s *Session[State]) Run( - ctx context.Context, - fn func(ctx context.Context, input *SessionFlowInput) error, -) error { - for input := range s.inCh { - spanMeta := &tracing.SpanMetadata{ - Name: fmt.Sprintf("sessionFlow/turn/%d", s.turnIndex), - Type: "sessionFlowTurn", - Subtype: "sessionFlowTurn", - } - - _, err := tracing.RunInNewSpan(ctx, spanMeta, input, - func(ctx context.Context, input *SessionFlowInput) (struct{}, error) { - s.AddMessages(input.Messages...) - - if err := fn(ctx, input); err != nil { - return struct{}{}, err - } - - s.onEndTurn(ctx) - s.turnIndex++ - return struct{}{}, nil - }, - ) - if err != nil { - return err - } - } - return nil + store SessionStore[State] } -// State returns a copy of the current session flow state. +// State returns a copy of the current state. func (s *Session[State]) State() *SessionState[State] { s.mu.RLock() defer s.mu.RUnlock() @@ -211,17 +166,17 @@ func (s *Session[State]) PromptInput() any { } // Artifacts returns the current artifacts. -func (s *Session[State]) Artifacts() []*SessionFlowArtifact { +func (s *Session[State]) Artifacts() []*AgentArtifact { s.mu.RLock() defer s.mu.RUnlock() - arts := make([]*SessionFlowArtifact, len(s.state.Artifacts)) + arts := make([]*AgentArtifact, len(s.state.Artifacts)) copy(arts, s.state.Artifacts) return arts } // AddArtifacts adds artifacts to the session. If an artifact with the same // name already exists, it is replaced. -func (s *Session[State]) AddArtifacts(artifacts ...*SessionFlowArtifact) { +func (s *Session[State]) AddArtifacts(artifacts ...*AgentArtifact) { s.mu.Lock() defer s.mu.Unlock() for _, a := range artifacts { @@ -242,70 +197,151 @@ func (s *Session[State]) AddArtifacts(artifacts ...*SessionFlowArtifact) { } // SetArtifacts replaces the entire artifact list. -func (s *Session[State]) SetArtifacts(artifacts []*SessionFlowArtifact) { +func (s *Session[State]) SetArtifacts(artifacts []*AgentArtifact) { s.mu.Lock() defer s.mu.Unlock() s.state.Artifacts = artifacts } +// copyStateLocked returns a deep copy of the state. Caller must hold mu (read or write). +func (s *Session[State]) copyStateLocked() SessionState[State] { + bytes, err := json.Marshal(s.state) + if err != nil { + panic(fmt.Sprintf("agent flow: failed to marshal state: %v", err)) + } + var copied SessionState[State] + if err := json.Unmarshal(bytes, &copied); err != nil { + panic(fmt.Sprintf("agent flow: failed to unmarshal state: %v", err)) + } + return copied +} + +// --- Session context --- + +type sessionContextKey struct{} + +type sessionHolder struct { + session any +} + +// NewSessionContext returns a new context with the session attached. +func NewSessionContext[State any](ctx context.Context, s *Session[State]) context.Context { + return context.WithValue(ctx, sessionContextKey{}, &sessionHolder{session: s}) +} + +// SessionFromContext retrieves the current session from context. +// Returns nil if no session is in context or if the type doesn't match. +func SessionFromContext[State any](ctx context.Context) *Session[State] { + holder, ok := ctx.Value(sessionContextKey{}).(*sessionHolder) + if !ok || holder == nil { + return nil + } + session, ok := holder.session.(*Session[State]) + if !ok { + return nil + } + return session +} + +// --- AgentSession --- + +// AgentSession extends Session with agent-flow-specific functionality: +// turn management, snapshot persistence, and input channel handling. +type AgentSession[State any] struct { + *Session[State] + snapshotCallback SnapshotCallback[State] + onEndTurn func(ctx context.Context) + inCh <-chan *AgentFlowInput + lastSnapshot *SessionSnapshot[State] + turnIndex int +} + +// Run loops over the input channel, calling fn for each turn. Each turn is +// wrapped in a trace span for observability. Input messages are automatically +// added to the session before fn is called. After fn returns successfully, an +// EndTurn chunk is sent and a snapshot check is triggered. +func (a *AgentSession[State]) Run(ctx context.Context, fn func(ctx context.Context, input *AgentFlowInput) error) error { + for input := range a.inCh { + spanMeta := &tracing.SpanMetadata{ + Name: fmt.Sprintf("agentFlow/turn/%d", a.turnIndex), + Type: "agentFlowTurn", + Subtype: "agentFlowTurn", + } + + _, err := tracing.RunInNewSpan(ctx, spanMeta, input, + func(ctx context.Context, input *AgentFlowInput) (struct{}, error) { + a.AddMessages(input.Messages...) + + if err := fn(ctx, input); err != nil { + return struct{}{}, err + } + + a.onEndTurn(ctx) + a.turnIndex++ + return struct{}{}, nil + }, + ) + if err != nil { + return err + } + } + return nil +} + // maybeSnapshot creates a snapshot if conditions are met (store configured, // callback approves). Returns the snapshot ID or empty string. -func (s *Session[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { - if s.store == nil { +func (a *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { + if a.store == nil { return "" } - s.mu.RLock() - currentState := s.copyStateLocked() - turnIndex := s.turnIndex - s.mu.RUnlock() + a.mu.RLock() + currentState := a.copyStateLocked() + a.mu.RUnlock() - shouldSnapshot := true - if s.snapshotCallback != nil { + if a.snapshotCallback != nil { var prevState *SessionState[State] - if s.lastSnapshot != nil { - prevState = &s.lastSnapshot.State + if a.lastSnapshot != nil { + prevState = &a.lastSnapshot.State } - shouldSnapshot = s.snapshotCallback(ctx, &SnapshotContext[State]{ + if !a.snapshotCallback(ctx, &SnapshotContext[State]{ State: ¤tState, PrevState: prevState, - TurnIndex: turnIndex, + TurnIndex: a.turnIndex, Event: event, - }) - } - - if !shouldSnapshot { - return "" + }) { + return "" + } } snapshot := &SessionSnapshot[State]{ SnapshotID: uuid.New().String(), CreatedAt: time.Now(), - TurnIndex: turnIndex, + TurnIndex: a.turnIndex, Event: event, State: currentState, } - if s.lastSnapshot != nil { - snapshot.ParentID = s.lastSnapshot.SnapshotID + if a.lastSnapshot != nil { + snapshot.ParentID = a.lastSnapshot.SnapshotID } - if err := s.store.SaveSnapshot(ctx, snapshot); err != nil { - slog.Error("session flow: failed to save snapshot", "err", err) + if err := a.store.SaveSnapshot(ctx, snapshot); err != nil { + slog.Error("agent flow: failed to save snapshot", "err", err) return "" } // Set snapshotId in last message metadata. - s.mu.Lock() - if msgs := s.state.Messages; len(msgs) > 0 { + a.mu.Lock() + if msgs := a.state.Messages; len(msgs) > 0 { lastMsg := msgs[len(msgs)-1] if lastMsg.Metadata == nil { lastMsg.Metadata = make(map[string]any) } lastMsg.Metadata["snapshotId"] = snapshot.SnapshotID } - s.mu.Unlock() + a.mu.Unlock() - s.lastSnapshot = snapshot + a.lastSnapshot = snapshot // Record on OTel span. span := oteltrace.SpanFromContext(ctx) @@ -316,182 +352,141 @@ func (s *Session[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) return snapshot.SnapshotID } -// copyStateLocked returns a deep copy of the state. Caller must hold mu (read or write). -func (s *Session[State]) copyStateLocked() SessionState[State] { - bytes, err := json.Marshal(s.state) - if err != nil { - panic(fmt.Sprintf("session flow: failed to marshal state: %v", err)) - } - var copied SessionState[State] - if err := json.Unmarshal(bytes, &copied); err != nil { - panic(fmt.Sprintf("session flow: failed to unmarshal state: %v", err)) - } - return copied -} - -// --- Session context --- - -type sessionContextKey struct{} - -type sessionHolder struct { - session any -} - -// NewSessionContext returns a new context with the session attached. -func NewSessionContext[State any](ctx context.Context, s *Session[State]) context.Context { - return context.WithValue(ctx, sessionContextKey{}, &sessionHolder{session: s}) -} - -// SessionFromContext retrieves the current session from context. -// Returns nil if no session is in context or if the type doesn't match. -func SessionFromContext[State any](ctx context.Context) *Session[State] { - holder, ok := ctx.Value(sessionContextKey{}).(*sessionHolder) - if !ok || holder == nil { - return nil - } - session, ok := holder.session.(*Session[State]) - if !ok { - return nil - } - return session -} - // --- Responder --- -// Responder is the output channel for a session flow. Chunks sent through it -// are automatically inspected: if a chunk contains an artifact, it is added to -// the session before being forwarded to the client. -// -// Convenience methods are provided for common chunk types. -type Responder[Stream any] chan<- *SessionFlowStreamChunk[Stream] +// Responder is the output channel for an agent flow. Artifacts sent through +// it are automatically added to the session before being forwarded to the +// client. +type Responder[Stream any] chan<- *AgentFlowStreamChunk[Stream] // SendChunk sends a generation chunk (token-level streaming). func (r Responder[Stream]) SendChunk(chunk *ai.ModelResponseChunk) { - r <- &SessionFlowStreamChunk[Stream]{Chunk: chunk} + r <- &AgentFlowStreamChunk[Stream]{Chunk: chunk} } // SendStatus sends a user-defined status update. func (r Responder[Stream]) SendStatus(status Stream) { - r <- &SessionFlowStreamChunk[Stream]{Status: status} + r <- &AgentFlowStreamChunk[Stream]{Status: status} } // SendArtifact sends an artifact to the stream and adds it to the session. // If an artifact with the same name already exists in the session, it is replaced. -func (r Responder[Stream]) SendArtifact(artifact *SessionFlowArtifact) { - r <- &SessionFlowStreamChunk[Stream]{Artifact: artifact} -} - -// --- SessionFlowParams --- - -// SessionFlowParams contains the parameters passed to a session flow function. -type SessionFlowParams[State any] struct { - // Session provides access to the working state. - Session *Session[State] +func (r Responder[Stream]) SendArtifact(artifact *AgentArtifact) { + r <- &AgentFlowStreamChunk[Stream]{Artifact: artifact} } -// --- SessionFlowFunc --- +// --- AgentFlowFunc --- -// SessionFlowFunc is the function signature for session flows. +// AgentFlowFunc is the function signature for agent flows. // Type parameters: // - Stream: Type for status updates sent via the responder // - State: Type for user-defined state in snapshots -type SessionFlowFunc[Stream, State any] func( +type AgentFlowFunc[Stream, State any] func( ctx context.Context, resp Responder[Stream], - params *SessionFlowParams[State], + sess *AgentSession[State], ) error -// --- SessionFlow --- +// --- AgentFlow --- -// SessionFlow is a bidirectional streaming flow with automatic snapshot management. -type SessionFlow[Stream, State any] struct { - flow *core.Flow[*SessionFlowInput, *SessionFlowOutput[State], *SessionFlowStreamChunk[Stream], *SessionFlowInit[State]] - store SnapshotStore[State] +// AgentFlow is a bidirectional streaming flow with automatic snapshot management. +type AgentFlow[Stream, State any] struct { + flow *core.Flow[*AgentFlowInput, *AgentFlowOutput[State], *AgentFlowStreamChunk[Stream], *AgentFlowInit[State]] + store SessionStore[State] snapshotCallback SnapshotCallback[State] } -// DefineSessionFlow creates a SessionFlow with automatic snapshot management and registers it. -func DefineSessionFlow[Stream, State any]( +// DefineCustomAgent creates an AgentFlow with automatic snapshot management and registers it. +func DefineCustomAgent[Stream, State any]( r api.Registry, name string, - fn SessionFlowFunc[Stream, State], - opts ...SessionFlowOption[State], -) *SessionFlow[Stream, State] { - sfOpts := &sessionFlowOptions[State]{} + fn AgentFlowFunc[Stream, State], + opts ...AgentFlowOption[State], +) *AgentFlow[Stream, State] { + afOpts := &agentFlowOptions[State]{} for _, opt := range opts { - if err := opt.applySessionFlow(sfOpts); err != nil { - panic(fmt.Errorf("DefineSessionFlow %q: %w", name, err)) + if err := opt.applyAgentFlow(afOpts); err != nil { + panic(fmt.Errorf("DefineCustomAgent %q: %w", name, err)) } } - sf := &SessionFlow[Stream, State]{ - store: sfOpts.store, - snapshotCallback: sfOpts.callback, + af := &AgentFlow[Stream, State]{ + store: afOpts.store, + snapshotCallback: afOpts.callback, } bidiFn := func( ctx context.Context, - init *SessionFlowInit[State], - inCh <-chan *SessionFlowInput, - outCh chan<- *SessionFlowStreamChunk[Stream], - ) (*SessionFlowOutput[State], error) { - return sf.runWrapped(ctx, init, inCh, outCh, fn) + init *AgentFlowInit[State], + inCh <-chan *AgentFlowInput, + outCh chan<- *AgentFlowStreamChunk[Stream], + ) (*AgentFlowOutput[State], error) { + return af.runWrapped(ctx, init, inCh, outCh, fn) } - sf.flow = core.DefineBidiFlow(r, name, bidiFn) + af.flow = core.DefineBidiFlow(r, name, bidiFn) // Register snapshot store action for reflection API. - if sfOpts.store != nil { - registerSnapshotStoreAction(r, name, sfOpts.store) + if afOpts.store != nil { + registerSessionStoreAction(r, name, afOpts.store) } - return sf + return af } -// StreamBidi starts a new session flow invocation. -func (sf *SessionFlow[Stream, State]) StreamBidi( +// StreamBidi starts a new agent flow invocation. +func (af *AgentFlow[Stream, State]) StreamBidi( ctx context.Context, opts ...StreamBidiOption[State], -) (*SessionFlowConnection[Stream, State], error) { +) (*AgentFlowConnection[Stream, State], error) { sbOpts := &streamBidiOptions[State]{} for _, opt := range opts { if err := opt.applyStreamBidi(sbOpts); err != nil { - return nil, fmt.Errorf("SessionFlow.StreamBidi %q: %w", sf.flow.Name(), err) + return nil, fmt.Errorf("AgentFlow.StreamBidi %q: %w", af.flow.Name(), err) } } - init := &SessionFlowInit[State]{ - SnapshotID: sbOpts.snapshotID, - State: sbOpts.state, - PromptInput: sbOpts.promptInput, + init := &AgentFlowInit[State]{ + SnapshotID: sbOpts.snapshotID, + State: sbOpts.state, + InputVariables: sbOpts.promptInput, } - conn, err := sf.flow.StreamBidi(ctx, init) + conn, err := af.flow.StreamBidi(ctx, init) if err != nil { return nil, err } - return &SessionFlowConnection[Stream, State]{conn: conn}, nil + return &AgentFlowConnection[Stream, State]{conn: conn}, nil } // runWrapped is the BidiFunc implementation. It sets up the session, // responder, and wiring, then delegates to the user's function. -func (sf *SessionFlow[Stream, State]) runWrapped( +func (af *AgentFlow[Stream, State]) runWrapped( ctx context.Context, - init *SessionFlowInit[State], - inCh <-chan *SessionFlowInput, - outCh chan<- *SessionFlowStreamChunk[Stream], - fn SessionFlowFunc[Stream, State], -) (*SessionFlowOutput[State], error) { - session, err := newSessionFromInit(ctx, init, sf.store, sf.snapshotCallback) + init *AgentFlowInit[State], + inCh <-chan *AgentFlowInput, + outCh chan<- *AgentFlowStreamChunk[Stream], + fn AgentFlowFunc[Stream, State], +) (*AgentFlowOutput[State], error) { + session, snapshot, err := newSessionFromInit(ctx, init, af.store) if err != nil { return nil, err } - session.inCh = inCh ctx = NewSessionContext(ctx, session) + agentSess := &AgentSession[State]{ + Session: session, + snapshotCallback: af.snapshotCallback, + inCh: inCh, + lastSnapshot: snapshot, + } + if snapshot != nil { + agentSess.turnIndex = snapshot.TurnIndex + } + // Intermediary channel: intercepts artifacts before forwarding to outCh. - respCh := make(chan *SessionFlowStreamChunk[Stream]) + respCh := make(chan *AgentFlowStreamChunk[Stream]) var wg sync.WaitGroup wg.Add(1) go func() { @@ -506,19 +501,15 @@ func (sf *SessionFlow[Stream, State]) runWrapped( // Wire up onEndTurn: triggers snapshot + sends EndTurn chunk. // Writes through respCh to preserve ordering with user chunks. - session.onEndTurn = func(turnCtx context.Context) { - snapshotID := session.maybeSnapshot(turnCtx, SnapshotEventTurnEnd) + agentSess.onEndTurn = func(turnCtx context.Context) { + snapshotID := agentSess.maybeSnapshot(turnCtx, SnapshotEventTurnEnd) if snapshotID != "" { - respCh <- &SessionFlowStreamChunk[Stream]{SnapshotCreated: snapshotID} + respCh <- &AgentFlowStreamChunk[Stream]{SnapshotID: snapshotID} } - respCh <- &SessionFlowStreamChunk[Stream]{EndTurn: true} + respCh <- &AgentFlowStreamChunk[Stream]{EndTurn: true} } - params := &SessionFlowParams[State]{ - Session: session, - } - - fnErr := fn(ctx, Responder[Stream](respCh), params) + fnErr := fn(ctx, Responder[Stream](respCh), agentSess) close(respCh) wg.Wait() @@ -527,47 +518,47 @@ func (sf *SessionFlow[Stream, State]) runWrapped( } // Final snapshot at invocation end. - snapshotID := session.maybeSnapshot(ctx, SnapshotEventInvocationEnd) + snapshotID := agentSess.maybeSnapshot(ctx, SnapshotEventInvocationEnd) - return &SessionFlowOutput[State]{ + return &AgentFlowOutput[State]{ State: session.State(), SnapshotID: snapshotID, }, nil } -// newSessionFromInit creates a session from initialization data. +// newSessionFromInit creates a Session from initialization data. +// If resuming from a snapshot, the loaded snapshot is also returned. func newSessionFromInit[State any]( ctx context.Context, - init *SessionFlowInit[State], - store SnapshotStore[State], - cb SnapshotCallback[State], -) (*Session[State], error) { - s := &Session[State]{ - store: store, - snapshotCallback: cb, - } + init *AgentFlowInit[State], + store SessionStore[State], +) (*Session[State], *SessionSnapshot[State], error) { + s := &Session[State]{store: store} + var snapshot *SessionSnapshot[State] if init != nil { + if init.SnapshotID != "" && store == nil { + return nil, nil, core.NewError(core.FAILED_PRECONDITION, "snapshot ID %q provided but no session store configured", init.SnapshotID) + } if init.SnapshotID != "" && store != nil { - snapshot, err := store.GetSnapshot(ctx, init.SnapshotID) + var err error + snapshot, err = store.GetSnapshot(ctx, init.SnapshotID) if err != nil { - return nil, core.NewError(core.INTERNAL, "failed to load snapshot %q: %v", init.SnapshotID, err) + return nil, nil, core.NewError(core.INTERNAL, "failed to load snapshot %q: %v", init.SnapshotID, err) } if snapshot == nil { - return nil, core.NewError(core.NOT_FOUND, "snapshot %q not found", init.SnapshotID) + return nil, nil, core.NewError(core.NOT_FOUND, "snapshot %q not found", init.SnapshotID) } s.state = snapshot.State - s.lastSnapshot = snapshot - s.turnIndex = snapshot.TurnIndex } else if init.State != nil { s.state = *init.State } - if init.PromptInput != nil { - s.state.PromptInput = init.PromptInput + if init.InputVariables != nil { + s.state.PromptInput = init.InputVariables } } - return s, nil + return s, snapshot, nil } // --- Snapshot store reflection action --- @@ -576,25 +567,25 @@ type getSnapshotInput struct { SnapshotID string `json:"snapshotId"` } -func registerSnapshotStoreAction[State any](r api.Registry, flowName string, store SnapshotStore[State]) { - core.DefineAction(r, flowName+"/getSnapshot", api.ActionTypeSnapshotStore, nil, nil, +func registerSessionStoreAction[State any](r api.Registry, flowName string, store SessionStore[State]) { + core.DefineAction(r, flowName+"/getSnapshot", api.ActionTypeSessionStore, nil, nil, func(ctx context.Context, input getSnapshotInput) (*SessionSnapshot[State], error) { return store.GetSnapshot(ctx, input.SnapshotID) }, ) } -// --- SessionFlowConnection --- +// --- AgentFlowConnection --- -// SessionFlowConnection wraps BidiConnection with session flow-specific functionality. +// AgentFlowConnection wraps BidiConnection with agent flow-specific functionality. // It provides a Receive() iterator that supports multi-turn patterns: breaking out // of the iterator between turns does not cancel the underlying connection. -type SessionFlowConnection[Stream, State any] struct { - conn *core.BidiConnection[*SessionFlowInput, *SessionFlowOutput[State], *SessionFlowStreamChunk[Stream]] +type AgentFlowConnection[Stream, State any] struct { + conn *core.BidiConnection[*AgentFlowInput, *AgentFlowOutput[State], *AgentFlowStreamChunk[Stream]] // chunks buffers stream chunks from the underlying connection so that // breaking from Receive() between turns doesn't cancel the context. - chunks chan *SessionFlowStreamChunk[Stream] + chunks chan *AgentFlowStreamChunk[Stream] chunkErr error initOnce sync.Once } @@ -602,9 +593,9 @@ type SessionFlowConnection[Stream, State any] struct { // initReceiver starts a goroutine that drains the underlying BidiConnection's // Receive into a channel. This goroutine never breaks from the underlying // iterator, preventing context cancellation. -func (c *SessionFlowConnection[Stream, State]) initReceiver() { +func (c *AgentFlowConnection[Stream, State]) initReceiver() { c.initOnce.Do(func() { - c.chunks = make(chan *SessionFlowStreamChunk[Stream], 1) + c.chunks = make(chan *AgentFlowStreamChunk[Stream], 1) go func() { defer close(c.chunks) for chunk, err := range c.conn.Receive() { @@ -618,25 +609,25 @@ func (c *SessionFlowConnection[Stream, State]) initReceiver() { }) } -// Send sends a SessionFlowInput to the session flow. -func (c *SessionFlowConnection[Stream, State]) Send(input *SessionFlowInput) error { +// Send sends an AgentFlowInput to the agent flow. +func (c *AgentFlowConnection[Stream, State]) Send(input *AgentFlowInput) error { return c.conn.Send(input) } -// SendMessages sends messages to the session flow. -func (c *SessionFlowConnection[Stream, State]) SendMessages(messages ...*ai.Message) error { - return c.conn.Send(&SessionFlowInput{Messages: messages}) +// SendMessages sends messages to the agent flow. +func (c *AgentFlowConnection[Stream, State]) SendMessages(messages ...*ai.Message) error { + return c.conn.Send(&AgentFlowInput{Messages: messages}) } -// SendText sends a single user text message to the session flow. -func (c *SessionFlowConnection[Stream, State]) SendText(text string) error { - return c.conn.Send(&SessionFlowInput{ +// SendText sends a single user text message to the agent flow. +func (c *AgentFlowConnection[Stream, State]) SendText(text string) error { + return c.conn.Send(&AgentFlowInput{ Messages: []*ai.Message{ai.NewUserTextMessage(text)}, }) } // Close signals that no more inputs will be sent. -func (c *SessionFlowConnection[Stream, State]) Close() error { +func (c *AgentFlowConnection[Stream, State]) Close() error { return c.conn.Close() } @@ -644,15 +635,14 @@ func (c *SessionFlowConnection[Stream, State]) Close() error { // Unlike the underlying BidiConnection.Receive, breaking out of this iterator // does not cancel the connection. This enables multi-turn patterns where the // caller breaks on EndTurn, sends the next input, then calls Receive again. -func (c *SessionFlowConnection[Stream, State]) Receive() iter.Seq2[*SessionFlowStreamChunk[Stream], error] { +func (c *AgentFlowConnection[Stream, State]) Receive() iter.Seq2[*AgentFlowStreamChunk[Stream], error] { c.initReceiver() - return func(yield func(*SessionFlowStreamChunk[Stream], error) bool) { + return func(yield func(*AgentFlowStreamChunk[Stream], error) bool) { for { chunk, ok := <-c.chunks if !ok { if err := c.chunkErr; err != nil { - var zero *SessionFlowStreamChunk[Stream] - yield(zero, err) + yield(nil, err) } return } @@ -663,17 +653,17 @@ func (c *SessionFlowConnection[Stream, State]) Receive() iter.Seq2[*SessionFlowS } } -// Output returns the final response after the session flow completes. -func (c *SessionFlowConnection[Stream, State]) Output() (*SessionFlowOutput[State], error) { +// Output returns the final response after the agent flow completes. +func (c *AgentFlowConnection[Stream, State]) Output() (*AgentFlowOutput[State], error) { return c.conn.Output() } // Done returns a channel closed when the connection completes. -func (c *SessionFlowConnection[Stream, State]) Done() <-chan struct{} { +func (c *AgentFlowConnection[Stream, State]) Done() <-chan struct{} { return c.conn.Done() } -// --- Prompt-backed SessionFlow --- +// --- Prompt-backed AgentFlow --- // PromptRenderer renders a prompt with typed input into GenerateActionOptions. // This interface is satisfied by both ai.Prompt (with In=any) and @@ -682,24 +672,22 @@ type PromptRenderer[In any] interface { Render(ctx context.Context, input In) (*ai.GenerateActionOptions, error) } -// DefineSessionFlowFromPrompt creates a prompt-backed SessionFlow with an +// DefinePromptAgent creates a prompt-backed AgentFlow with an // automatic conversation loop. Each turn renders the prompt, appends // conversation history, calls GenerateWithRequest, streams chunks to the // client, and adds the model response to the session. // // The defaultInput is used for prompt rendering unless overridden per // invocation via WithPromptInput. -func DefineSessionFlowFromPrompt[State, PromptIn any]( +func DefinePromptAgent[State, PromptIn any]( r api.Registry, name string, p PromptRenderer[PromptIn], defaultInput PromptIn, - opts ...SessionFlowOption[State], -) *SessionFlow[struct{}, State] { - fn := func(ctx context.Context, resp Responder[struct{}], params *SessionFlowParams[State]) error { - return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { - sess := params.Session - + opts ...AgentFlowOption[State], +) *AgentFlow[struct{}, State] { + fn := func(ctx context.Context, resp Responder[struct{}], sess *AgentSession[State]) error { + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { // Resolve prompt input: session state override > default. promptInput := defaultInput if stored := sess.PromptInput(); stored != nil { @@ -739,5 +727,5 @@ func DefineSessionFlowFromPrompt[State, PromptIn any]( }) } - return DefineSessionFlow(r, name, fn, opts...) + return DefineCustomAgent(r, name, fn, opts...) } diff --git a/go/ai/x/session_flow_test.go b/go/ai/x/agent_flow_test.go similarity index 73% rename from go/ai/x/session_flow_test.go rename to go/ai/x/agent_flow_test.go index d5722863fd..8e877a5820 100644 --- a/go/ai/x/session_flow_test.go +++ b/go/ai/x/agent_flow_test.go @@ -39,14 +39,13 @@ func newTestRegistry(t *testing.T) *registry.Registry { return registry.New() } -func TestSessionFlow_BasicMultiTurn(t *testing.T) { +func TestAgentFlow_BasicMultiTurn(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - sf := DefineSessionFlow(reg, "basicFlow", - func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { - return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { - sess := params.Session + af := DefineCustomAgent(reg, "basicFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { resp.SendStatus(testStatus{Phase: "generating"}) // Echo back the user's message. if len(input.Messages) > 0 { @@ -63,7 +62,7 @@ func TestSessionFlow_BasicMultiTurn(t *testing.T) { }, ) - conn, err := sf.StreamBidi(ctx) + conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } @@ -115,15 +114,14 @@ func TestSessionFlow_BasicMultiTurn(t *testing.T) { } } -func TestSessionFlow_WithSnapshotStore(t *testing.T) { +func TestAgentFlow_WithSessionStore(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - store := NewInMemorySnapshotStore[testState]() + store := NewInMemorySessionStore[testState]() - sf := DefineSessionFlow(reg, "snapshotFlow", - func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { - return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { - sess := params.Session + af := DefineCustomAgent(reg, "snapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -134,10 +132,10 @@ func TestSessionFlow_WithSnapshotStore(t *testing.T) { return nil }) }, - WithSnapshotStore[testState](store), + WithSessionStore[testState](store), ) - conn, err := sf.StreamBidi(ctx) + conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } @@ -149,8 +147,8 @@ func TestSessionFlow_WithSnapshotStore(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.SnapshotCreated != "" { - snapshotIDs = append(snapshotIDs, chunk.SnapshotCreated) + if chunk.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, chunk.SnapshotID) } if chunk.EndTurn { break @@ -189,15 +187,14 @@ func TestSessionFlow_WithSnapshotStore(t *testing.T) { } } -func TestSessionFlow_ResumeFromSnapshot(t *testing.T) { +func TestAgentFlow_ResumeFromSnapshot(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - store := NewInMemorySnapshotStore[testState]() + store := NewInMemorySessionStore[testState]() - sf := DefineSessionFlow(reg, "resumeFlow", - func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { - return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { - sess := params.Session + af := DefineCustomAgent(reg, "resumeFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -208,11 +205,11 @@ func TestSessionFlow_ResumeFromSnapshot(t *testing.T) { return nil }) }, - WithSnapshotStore[testState](store), + WithSessionStore[testState](store), ) // First invocation: create a snapshot. - conn1, err := sf.StreamBidi(ctx) + conn1, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } @@ -235,7 +232,7 @@ func TestSessionFlow_ResumeFromSnapshot(t *testing.T) { } // Second invocation: resume from snapshot. - conn2, err := sf.StreamBidi(ctx, WithSnapshotID[testState](resp1.SnapshotID)) + conn2, err := af.StreamBidi(ctx, WithSnapshotID[testState](resp1.SnapshotID)) if err != nil { t.Fatalf("StreamBidi with snapshot failed: %v", err) } @@ -280,14 +277,13 @@ func TestSessionFlow_ResumeFromSnapshot(t *testing.T) { } } -func TestSessionFlow_ClientManagedState(t *testing.T) { +func TestAgentFlow_ClientManagedState(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - sf := DefineSessionFlow(reg, "clientStateFlow", - func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { - return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { - sess := params.Session + af := DefineCustomAgent(reg, "clientStateFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -309,7 +305,7 @@ func TestSessionFlow_ClientManagedState(t *testing.T) { Custom: testState{Counter: 5}, } - conn, err := sf.StreamBidi(ctx, WithState(clientState)) + conn, err := af.StreamBidi(ctx, WithState(clientState)) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } @@ -344,28 +340,27 @@ func TestSessionFlow_ClientManagedState(t *testing.T) { } } -func TestSessionFlow_Artifacts(t *testing.T) { +func TestAgentFlow_Artifacts(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - sf := DefineSessionFlow(reg, "artifactFlow", - func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { - return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { - sess := params.Session + af := DefineCustomAgent(reg, "artifactFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { - resp.SendArtifact(&SessionFlowArtifact{ + resp.SendArtifact(&AgentArtifact{ Name: "code.go", Parts: []*ai.Part{ai.NewTextPart("package main")}, }) // Replace artifact with same name. - resp.SendArtifact(&SessionFlowArtifact{ + resp.SendArtifact(&AgentArtifact{ Name: "code.go", Parts: []*ai.Part{ai.NewTextPart("package main\nfunc main() {}")}, }) // Add another artifact. - resp.SendArtifact(&SessionFlowArtifact{ + resp.SendArtifact(&AgentArtifact{ Name: "readme.md", Parts: []*ai.Part{ai.NewTextPart("# README")}, }) @@ -376,13 +371,13 @@ func TestSessionFlow_Artifacts(t *testing.T) { }, ) - conn, err := sf.StreamBidi(ctx) + conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } conn.SendText("generate code") - var receivedArtifacts []*SessionFlowArtifact + var receivedArtifacts []*AgentArtifact for chunk, err := range conn.Receive() { if err != nil { t.Fatalf("Receive error: %v", err) @@ -411,17 +406,16 @@ func TestSessionFlow_Artifacts(t *testing.T) { } } -func TestSessionFlow_SnapshotCallback(t *testing.T) { +func TestAgentFlow_SnapshotCallback(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - store := NewInMemorySnapshotStore[testState]() + store := NewInMemorySessionStore[testState]() // Only snapshot on even turns. callbackCalls := 0 - sf := DefineSessionFlow(reg, "callbackFlow", - func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { - return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { - sess := params.Session + af := DefineCustomAgent(reg, "callbackFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -430,14 +424,14 @@ func TestSessionFlow_SnapshotCallback(t *testing.T) { return nil }) }, - WithSnapshotStore[testState](store), + WithSessionStore[testState](store), WithSnapshotCallback(func(ctx context.Context, sc *SnapshotContext[testState]) bool { callbackCalls++ return sc.TurnIndex%2 == 0 // only snapshot on even turns }), ) - conn, err := sf.StreamBidi(ctx) + conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } @@ -449,8 +443,8 @@ func TestSessionFlow_SnapshotCallback(t *testing.T) { if err != nil { t.Fatalf("Receive error on turn %d: %v", i, err) } - if chunk.SnapshotCreated != "" { - snapshotIDs = append(snapshotIDs, chunk.SnapshotCreated) + if chunk.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, chunk.SnapshotID) } if chunk.EndTurn { break @@ -471,19 +465,19 @@ func TestSessionFlow_SnapshotCallback(t *testing.T) { } } -func TestSessionFlow_SendMessages(t *testing.T) { +func TestAgentFlow_SendMessages(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - sf := DefineSessionFlow(reg, "sendMsgsFlow", - func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { - return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "sendMsgsFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { return nil }) }, ) - conn, err := sf.StreamBidi(ctx) + conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } @@ -517,31 +511,31 @@ func TestSessionFlow_SendMessages(t *testing.T) { } } -func TestSessionFlow_SessionContext(t *testing.T) { +func TestAgentFlow_SessionContext(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) var retrievedCounter int - sf := DefineSessionFlow(reg, "contextFlow", - func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { - return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "contextFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { // Session should be retrievable from context. - sess := SessionFromContext[testState](ctx) - if sess == nil { + ctxSess := SessionFromContext[testState](ctx) + if ctxSess == nil { t.Error("expected session from context") return nil } - sess.UpdateCustom(func(s testState) testState { + ctxSess.UpdateCustom(func(s testState) testState { s.Counter = 42 return s }) - retrievedCounter = sess.Custom().Counter + retrievedCounter = ctxSess.Custom().Counter return nil }) }, ) - conn, err := sf.StreamBidi(ctx) + conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } @@ -563,19 +557,19 @@ func TestSessionFlow_SessionContext(t *testing.T) { } } -func TestSessionFlow_ErrorInTurn(t *testing.T) { +func TestAgentFlow_ErrorInTurn(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - sf := DefineSessionFlow(reg, "errorFlow", - func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { - return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "errorFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { return fmt.Errorf("turn failed") }) }, ) - conn, err := sf.StreamBidi(ctx) + conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } @@ -589,14 +583,13 @@ func TestSessionFlow_ErrorInTurn(t *testing.T) { } } -func TestSessionFlow_SetMessages(t *testing.T) { +func TestAgentFlow_SetMessages(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - sf := DefineSessionFlow(reg, "setMsgsFlow", - func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { - return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { - sess := params.Session + af := DefineCustomAgent(reg, "setMsgsFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { // Replace all messages with just one. sess.SetMessages([]*ai.Message{ai.NewModelTextMessage("replaced")}) return nil @@ -604,7 +597,7 @@ func TestSessionFlow_SetMessages(t *testing.T) { }, ) - conn, err := sf.StreamBidi(ctx) + conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } @@ -631,23 +624,22 @@ func TestSessionFlow_SetMessages(t *testing.T) { } } -func TestSessionFlow_SnapshotIDInMessageMetadata(t *testing.T) { +func TestAgentFlow_SnapshotIDInMessageMetadata(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - store := NewInMemorySnapshotStore[testState]() + store := NewInMemorySessionStore[testState]() - sf := DefineSessionFlow(reg, "metadataFlow", - func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { - return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { - sess := params.Session + af := DefineCustomAgent(reg, "metadataFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) return nil }) }, - WithSnapshotStore[testState](store), + WithSessionStore[testState](store), ) - conn, err := sf.StreamBidi(ctx) + conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } @@ -682,9 +674,9 @@ func TestSessionFlow_SnapshotIDInMessageMetadata(t *testing.T) { } } -func TestInMemorySnapshotStore(t *testing.T) { +func TestInMemorySessionStore(t *testing.T) { ctx := context.Background() - store := NewInMemorySnapshotStore[testState]() + store := NewInMemorySessionStore[testState]() // Get non-existent. snap, err := store.GetSnapshot(ctx, "nonexistent") @@ -726,20 +718,20 @@ func TestInMemorySnapshotStore(t *testing.T) { } } -func TestSessionFlow_SnapshotStoreReflectionAction(t *testing.T) { +func TestAgentFlow_SessionStoreReflectionAction(t *testing.T) { _ = context.Background() reg := newTestRegistry(t) - store := NewInMemorySnapshotStore[testState]() + store := NewInMemorySessionStore[testState]() - DefineSessionFlow(reg, "reflectFlow", - func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + DefineCustomAgent(reg, "reflectFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { return nil }, - WithSnapshotStore[testState](store), + WithSessionStore[testState](store), ) // The getSnapshot action should be registered. - action := reg.LookupAction("/snapshot-store/reflectFlow/getSnapshot") + action := reg.LookupAction("/session-store/reflectFlow/getSnapshot") if action == nil { t.Fatal("expected getSnapshot action to be registered") } diff --git a/go/ai/x/option.go b/go/ai/x/option.go index d7380b4b83..eaa456123b 100644 --- a/go/ai/x/option.go +++ b/go/ai/x/option.go @@ -18,22 +18,22 @@ package aix import "errors" -// --- SessionFlowOption --- +// --- AgentFlowOption --- -// SessionFlowOption configures a SessionFlow. -type SessionFlowOption[State any] interface { - applySessionFlow(*sessionFlowOptions[State]) error +// AgentFlowOption configures an AgentFlow. +type AgentFlowOption[State any] interface { + applyAgentFlow(*agentFlowOptions[State]) error } -type sessionFlowOptions[State any] struct { - store SnapshotStore[State] +type agentFlowOptions[State any] struct { + store SessionStore[State] callback SnapshotCallback[State] } -func (o *sessionFlowOptions[State]) applySessionFlow(opts *sessionFlowOptions[State]) error { +func (o *agentFlowOptions[State]) applyAgentFlow(opts *agentFlowOptions[State]) error { if o.store != nil { if opts.store != nil { - return errors.New("cannot set snapshot store more than once (WithSnapshotStore)") + return errors.New("cannot set session store more than once (WithSessionStore)") } opts.store = o.store } @@ -46,15 +46,15 @@ func (o *sessionFlowOptions[State]) applySessionFlow(opts *sessionFlowOptions[St return nil } -// WithSnapshotStore sets the store for persisting snapshots. -func WithSnapshotStore[State any](store SnapshotStore[State]) SessionFlowOption[State] { - return &sessionFlowOptions[State]{store: store} +// WithSessionStore sets the store for persisting snapshots. +func WithSessionStore[State any](store SessionStore[State]) AgentFlowOption[State] { + return &agentFlowOptions[State]{store: store} } // WithSnapshotCallback configures when snapshots are created. // If not provided and a store is configured, snapshots are always created. -func WithSnapshotCallback[State any](cb SnapshotCallback[State]) SessionFlowOption[State] { - return &sessionFlowOptions[State]{callback: cb} +func WithSnapshotCallback[State any](cb SnapshotCallback[State]) AgentFlowOption[State] { + return &agentFlowOptions[State]{callback: cb} } // --- StreamBidiOption --- @@ -75,12 +75,18 @@ func (o *streamBidiOptions[State]) applyStreamBidi(opts *streamBidiOptions[State if opts.state != nil { return errors.New("cannot set state more than once (WithState)") } + if opts.snapshotID != "" { + return errors.New("WithState and WithSnapshotID are mutually exclusive") + } opts.state = o.state } if o.snapshotID != "" { if opts.snapshotID != "" { return errors.New("cannot set snapshot ID more than once (WithSnapshotID)") } + if opts.state != nil { + return errors.New("WithSnapshotID and WithState are mutually exclusive") + } opts.snapshotID = o.snapshotID } if o.promptInput != nil { @@ -104,8 +110,8 @@ func WithSnapshotID[State any](id string) StreamBidiOption[State] { return &streamBidiOptions[State]{snapshotID: id} } -// WithPromptInput overrides the default prompt input for a prompt-backed session flow. -// Used with DefineSessionFlowFromPrompt to customize the prompt rendering per invocation. +// WithPromptInput overrides the default prompt input for a prompt-backed agent flow. +// Used with DefinePromptAgent to customize the prompt rendering per invocation. func WithPromptInput[State any](input any) StreamBidiOption[State] { return &streamBidiOptions[State]{promptInput: input} } diff --git a/go/ai/x/prompt_session_flow_test.go b/go/ai/x/prompt_agent_test.go similarity index 92% rename from go/ai/x/prompt_session_flow_test.go rename to go/ai/x/prompt_agent_test.go index 70dade4e8c..1de3043e1b 100644 --- a/go/ai/x/prompt_session_flow_test.go +++ b/go/ai/x/prompt_agent_test.go @@ -65,7 +65,7 @@ func setupPromptTestRegistry(t *testing.T) *registry.Registry { return reg } -func TestPromptSessionFlow_Basic(t *testing.T) { +func TestPromptAgent_Basic(t *testing.T) { ctx := context.Background() reg := setupPromptTestRegistry(t) @@ -74,11 +74,11 @@ func TestPromptSessionFlow_Basic(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - sf := DefineSessionFlowFromPrompt[testState]( + af := DefinePromptAgent[testState]( reg, "promptFlow", prompt, nil, ) - conn, err := sf.StreamBidi(ctx) + conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } @@ -133,7 +133,7 @@ func TestPromptSessionFlow_Basic(t *testing.T) { } } -func TestPromptSessionFlow_PromptInputOverride(t *testing.T) { +func TestPromptAgent_PromptInputOverride(t *testing.T) { ctx := context.Background() reg := setupPromptTestRegistry(t) @@ -146,12 +146,12 @@ func TestPromptSessionFlow_PromptInputOverride(t *testing.T) { ai.WithPrompt("Hello {{name}}!"), ) - sf := DefineSessionFlowFromPrompt[testState]( + af := DefinePromptAgent[testState]( reg, "promptInputFlow", prompt, greetInput{Name: "default"}, ) // Use WithPromptInput to override. - conn, err := sf.StreamBidi(ctx, + conn, err := af.StreamBidi(ctx, WithPromptInput[testState](greetInput{Name: "override"}), ) if err != nil { @@ -194,7 +194,7 @@ func TestPromptSessionFlow_PromptInputOverride(t *testing.T) { } } -func TestPromptSessionFlow_MultiTurnHistory(t *testing.T) { +func TestPromptAgent_MultiTurnHistory(t *testing.T) { ctx := context.Background() reg := setupPromptTestRegistry(t) @@ -223,11 +223,11 @@ func TestPromptSessionFlow_MultiTurnHistory(t *testing.T) { ai.WithSystem("system prompt"), ) - sf := DefineSessionFlowFromPrompt[testState]( + af := DefinePromptAgent[testState]( reg, "historyFlow", prompt, nil, ) - conn, err := sf.StreamBidi(ctx) + conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } @@ -289,23 +289,23 @@ func TestPromptSessionFlow_MultiTurnHistory(t *testing.T) { } } -func TestPromptSessionFlow_SnapshotPersistsPromptInput(t *testing.T) { +func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { ctx := context.Background() reg := setupPromptTestRegistry(t) - store := NewInMemorySnapshotStore[testState]() + store := NewInMemorySessionStore[testState]() prompt := ai.DefinePrompt(reg, "snapPrompt", ai.WithModelName("test/echo"), ai.WithSystem("You are a test assistant."), ) - sf := DefineSessionFlowFromPrompt[testState]( + af := DefinePromptAgent( reg, "snapPromptFlow", prompt, nil, - WithSnapshotStore(store), + WithSessionStore(store), ) // Start with prompt input. - conn, err := sf.StreamBidi(ctx, + conn, err := af.StreamBidi(ctx, WithPromptInput[testState](map[string]any{"key": "value"}), ) if err != nil { @@ -342,7 +342,7 @@ func TestPromptSessionFlow_SnapshotPersistsPromptInput(t *testing.T) { } // Resume from snapshot — the PromptInput should be preserved. - conn2, err := sf.StreamBidi(ctx, WithSnapshotID[testState](resp.SnapshotID)) + conn2, err := af.StreamBidi(ctx, WithSnapshotID[testState](resp.SnapshotID)) if err != nil { t.Fatalf("StreamBidi with snapshot failed: %v", err) } diff --git a/go/ai/x/snapshot.go b/go/ai/x/snapshot.go index e9411aa32f..864327fe76 100644 --- a/go/ai/x/snapshot.go +++ b/go/ai/x/snapshot.go @@ -34,8 +34,8 @@ type SessionState[State any] struct { // Custom is the user-defined state associated with this conversation. Custom State `json:"custom,omitempty"` // Artifacts are named collections of parts produced during the conversation. - Artifacts []*SessionFlowArtifact `json:"artifacts,omitempty"` - // PromptInput is the input used for prompt rendering in prompt-backed session flows. + Artifacts []*AgentArtifact `json:"artifacts,omitempty"` + // PromptInput is the input used for prompt rendering in prompt-backed agent flows. // Stored as any to support type-erased persistence across snapshot boundaries. PromptInput any `json:"promptInput,omitempty"` } @@ -82,29 +82,29 @@ type SnapshotContext[State any] struct { // If not provided and a store is configured, snapshots are always created. type SnapshotCallback[State any] = func(ctx context.Context, sc *SnapshotContext[State]) bool -// SnapshotStore persists and retrieves snapshots. -type SnapshotStore[State any] interface { +// SessionStore persists and retrieves snapshots. +type SessionStore[State any] interface { // GetSnapshot retrieves a snapshot by ID. Returns nil if not found. GetSnapshot(ctx context.Context, snapshotID string) (*SessionSnapshot[State], error) // SaveSnapshot persists a snapshot. SaveSnapshot(ctx context.Context, snapshot *SessionSnapshot[State]) error } -// InMemorySnapshotStore provides a thread-safe in-memory snapshot store. -type InMemorySnapshotStore[State any] struct { +// InMemorySessionStore provides a thread-safe in-memory snapshot store. +type InMemorySessionStore[State any] struct { snapshots map[string]*SessionSnapshot[State] mu sync.RWMutex } -// NewInMemorySnapshotStore creates a new in-memory snapshot store. -func NewInMemorySnapshotStore[State any]() *InMemorySnapshotStore[State] { - return &InMemorySnapshotStore[State]{ +// NewInMemorySessionStore creates a new in-memory snapshot store. +func NewInMemorySessionStore[State any]() *InMemorySessionStore[State] { + return &InMemorySessionStore[State]{ snapshots: make(map[string]*SessionSnapshot[State]), } } // GetSnapshot retrieves a snapshot by ID. Returns nil if not found. -func (s *InMemorySnapshotStore[State]) GetSnapshot(_ context.Context, snapshotID string) (*SessionSnapshot[State], error) { +func (s *InMemorySessionStore[State]) GetSnapshot(_ context.Context, snapshotID string) (*SessionSnapshot[State], error) { s.mu.RLock() defer s.mu.RUnlock() @@ -121,7 +121,7 @@ func (s *InMemorySnapshotStore[State]) GetSnapshot(_ context.Context, snapshotID } // SaveSnapshot persists a snapshot. -func (s *InMemorySnapshotStore[State]) SaveSnapshot(_ context.Context, snapshot *SessionSnapshot[State]) error { +func (s *InMemorySessionStore[State]) SaveSnapshot(_ context.Context, snapshot *SessionSnapshot[State]) error { s.mu.Lock() defer s.mu.Unlock() @@ -150,8 +150,8 @@ func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[Sta } // SnapshotOn returns a SnapshotCallback that only allows snapshots for the -// specified events. For example, SnapshotOn[MyState](TurnEnd) will skip the -// invocation-end snapshot. +// specified events. For example, SnapshotOn[MyState](SnapshotEventTurnEnd) +// will skip the invocation-end snapshot. func SnapshotOn[State any](events ...SnapshotEvent) SnapshotCallback[State] { set := make(map[SnapshotEvent]struct{}, len(events)) for _, e := range events { diff --git a/go/core/api/action.go b/go/core/api/action.go index 3150fe88ae..e79ba558f0 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -64,8 +64,8 @@ const ( ActionTypeCustom ActionType = "custom" ActionTypeCheckOperation ActionType = "check-operation" ActionTypeCancelOperation ActionType = "cancel-operation" - ActionTypeSessionFlow ActionType = "session-flow" - ActionTypeSnapshotStore ActionType = "snapshot-store" + ActionTypeAgentFlow ActionType = "agent-flow" + ActionTypeSessionStore ActionType = "session-store" ) // ActionDesc is a descriptor of an action. diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 78298c15af..77b182e1ad 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -408,10 +408,10 @@ func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.B return core.DefineBidiFlow(g.reg, name, fn) } -// DefineSessionFlow creates a SessionFlow with automatic snapshot management +// DefineCustomAgent creates an AgentFlow with automatic snapshot management // and registers it as a flow action. // -// A SessionFlow is a stateful, multi-turn conversational flow with automatic +// An AgentFlow is a stateful, multi-turn conversational flow with automatic // snapshot persistence and turn semantics. It builds on bidirectional streaming // to enable ongoing conversations with managed state. // @@ -429,25 +429,25 @@ func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.B // Phase string `json:"phase"` // } // -// chatFlow := genkit.DefineSessionFlow(g, "chatFlow", -// func(ctx context.Context, resp aix.Responder[ChatStatus], params *aix.SessionFlowParams[ChatState]) error { -// return params.Session.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { +// chatFlow := genkit.DefineCustomAgent(g, "chatFlow", +// func(ctx context.Context, resp aix.Responder[ChatStatus], sess *aix.AgentSession[ChatState]) error { +// return sess.Run(ctx, func(ctx context.Context, input *aix.AgentFlowInput) error { // // ... handle each turn ... // return nil // }) // }, -// aix.WithSnapshotStore(store), +// aix.WithSessionStore(store), // ) -func DefineSessionFlow[Stream, State any]( +func DefineCustomAgent[Stream, State any]( g *Genkit, name string, - fn aix.SessionFlowFunc[Stream, State], - opts ...aix.SessionFlowOption[State], -) *aix.SessionFlow[Stream, State] { - return aix.DefineSessionFlow(g.reg, name, fn, opts...) + fn aix.AgentFlowFunc[Stream, State], + opts ...aix.AgentFlowOption[State], +) *aix.AgentFlow[Stream, State] { + return aix.DefineCustomAgent(g.reg, name, fn, opts...) } -// DefineSessionFlowFromPrompt creates a prompt-backed SessionFlow with an +// DefinePromptAgent creates a prompt-backed AgentFlow with an // automatic conversation loop. Each turn renders the prompt, appends // conversation history, calls the model with streaming, and updates session state. // @@ -457,14 +457,14 @@ func DefineSessionFlow[Stream, State any]( // Type parameters: // - State: Type for user-defined state in snapshots // - PromptIn: The prompt input type (inferred from the PromptRenderer) -func DefineSessionFlowFromPrompt[State, PromptIn any]( +func DefinePromptAgent[State, PromptIn any]( g *Genkit, name string, p aix.PromptRenderer[PromptIn], defaultInput PromptIn, - opts ...aix.SessionFlowOption[State], -) *aix.SessionFlow[struct{}, State] { - return aix.DefineSessionFlowFromPrompt(g.reg, name, p, defaultInput, opts...) + opts ...aix.AgentFlowOption[State], +) *aix.AgentFlow[struct{}, State] { + return aix.DefinePromptAgent(g.reg, name, p, defaultInput, opts...) } // Run executes the given function `fn` within the context of the current flow run, diff --git a/go/samples/basic-session-flow/main.go b/go/samples/custom-agent/main.go similarity index 79% rename from go/samples/basic-session-flow/main.go rename to go/samples/custom-agent/main.go index 2e2fe6aa6e..b7b73c51dc 100644 --- a/go/samples/basic-session-flow/main.go +++ b/go/samples/custom-agent/main.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This sample demonstrates the SessionFlow API for multi-turn conversation +// This sample demonstrates the AgentFlow API for multi-turn conversation // with token-level streaming. It runs a CLI REPL where conversation history // is managed automatically by the session. package main @@ -35,9 +35,9 @@ func main() { ctx := context.Background() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - chatFlow := genkit.DefineSessionFlow(g, "chat", - func(ctx context.Context, resp aix.Responder[any], params *aix.SessionFlowParams[any]) error { - return params.Session.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { + chatFlow := genkit.DefineCustomAgent(g, "chat", + func(ctx context.Context, resp aix.Responder[any], sess *aix.AgentSession[any]) error { + return sess.Run(ctx, func(ctx context.Context, input *aix.AgentFlowInput) error { for chunk, err := range genkit.GenerateStream(ctx, g, ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ @@ -45,13 +45,13 @@ func main() { }, })), ai.WithSystem("You are a helpful assistant. Keep responses concise."), - ai.WithMessages(params.Session.Messages()...), + ai.WithMessages(sess.Messages()...), ) { if err != nil { return err } if chunk.Done { - params.Session.AddMessages(chunk.Response.Message) + sess.AddMessages(chunk.Response.Message) break } resp.SendChunk(chunk.Chunk) @@ -60,11 +60,11 @@ func main() { return nil }) }, - aix.WithSnapshotStore(aix.NewInMemorySnapshotStore[any]()), + aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), aix.WithSnapshotCallback(aix.SnapshotOn[any](aix.SnapshotEventTurnEnd)), ) - fmt.Println("Session Flow Chat (type 'quit' to exit)") + fmt.Println("Agent Flow Chat (type 'quit' to exit)") fmt.Println() conn, err := chatFlow.StreamBidi(ctx) @@ -101,8 +101,8 @@ func main() { if chunk.Chunk != nil { fmt.Print(chunk.Chunk.Text()) } - if chunk.SnapshotCreated != "" { - fmt.Printf("\n[snapshot: %s]", chunk.SnapshotCreated) + if chunk.SnapshotID != "" { + fmt.Printf("\n[snapshot: %s]", chunk.SnapshotID) } if chunk.EndTurn { fmt.Println() diff --git a/go/samples/prompt-session-flow/main.go b/go/samples/prompt-agent/main.go similarity index 81% rename from go/samples/prompt-session-flow/main.go rename to go/samples/prompt-agent/main.go index b85988c10d..95e83eed7c 100644 --- a/go/samples/prompt-session-flow/main.go +++ b/go/samples/prompt-agent/main.go @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This sample demonstrates DefineSessionFlowFromPrompt, which creates a -// multi-turn conversational session flow backed by a .prompt file. The +// This sample demonstrates DefinePromptAgent, which creates a +// multi-turn conversational agent flow backed by a .prompt file. The // conversation loop (render prompt, call model, stream chunks, update history) -// is handled automatically. Compare with basic-session-flow which wires +// is handled automatically. Compare with custom-agent which wires // the same loop manually. package main @@ -41,15 +41,15 @@ func main() { chatPrompt := genkit.LookupDataPrompt[ChatPromptInput, string](g, "chat") - chatFlow := genkit.DefineSessionFlowFromPrompt( + chatFlow := genkit.DefinePromptAgent( g, "chat", chatPrompt, ChatPromptInput{Personality: "a sarcastic pirate"}, - aix.WithSnapshotStore(aix.NewInMemorySnapshotStore[any]()), + aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[any]) bool { return sc.Event == aix.SnapshotEventInvocationEnd || sc.TurnIndex%5 == 0 }), ) - fmt.Println("Prompt Session Flow Chat (type 'quit' to exit)") + fmt.Println("Prompt Agent Chat (type 'quit' to exit)") fmt.Println() conn, err := chatFlow.StreamBidi(ctx) @@ -86,8 +86,8 @@ func main() { if chunk.Chunk != nil { fmt.Print(chunk.Chunk.Text()) } - if chunk.SnapshotCreated != "" { - fmt.Printf("\n[snapshot: %s]", chunk.SnapshotCreated) + if chunk.SnapshotID != "" { + fmt.Printf("\n[snapshot: %s]", chunk.SnapshotID) } if chunk.EndTurn { fmt.Println() diff --git a/go/samples/prompt-session-flow/prompts/chat.prompt b/go/samples/prompt-agent/prompts/chat.prompt similarity index 100% rename from go/samples/prompt-session-flow/prompts/chat.prompt rename to go/samples/prompt-agent/prompts/chat.prompt From db371022f53443daedd7ebe635fa1baa87ebc969 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 17 Feb 2026 15:16:00 -0800 Subject: [PATCH 014/141] Update action.go --- go/core/api/action.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go/core/api/action.go b/go/core/api/action.go index e79ba558f0..a243474bee 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -64,8 +64,8 @@ const ( ActionTypeCustom ActionType = "custom" ActionTypeCheckOperation ActionType = "check-operation" ActionTypeCancelOperation ActionType = "cancel-operation" - ActionTypeAgentFlow ActionType = "agent-flow" - ActionTypeSessionStore ActionType = "session-store" + ActionTypeAgentFlow ActionType = "agent-flow" + ActionTypeSessionStore ActionType = "session-store" ) // ActionDesc is a descriptor of an action. From f4c4ec1a0f61311c8c20304433974a246b0ab866 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 17 Feb 2026 18:35:20 -0800 Subject: [PATCH 015/141] added stream capturing to output --- go/ai/x/agent_flow.go | 330 ++++++------------ go/ai/x/agent_flow_test.go | 504 +++++++++++++++++++++++++++- go/ai/x/prompt_agent_test.go | 375 --------------------- go/ai/x/{snapshot.go => session.go} | 159 ++++++++- go/core/flow.go | 2 +- 5 files changed, 755 insertions(+), 615 deletions(-) delete mode 100644 go/ai/x/prompt_agent_test.go rename go/ai/x/{snapshot.go => session.go} (55%) diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent_flow.go index 900cef135d..90ce23fc4b 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent_flow.go @@ -22,7 +22,6 @@ package aix import ( "context" - "encoding/json" "fmt" "iter" "log/slog" @@ -95,154 +94,6 @@ type AgentFlowStreamChunk[Stream any] struct { EndTurn bool `json:"endTurn,omitempty"` } -// --- Session --- - -// Session holds conversation state and provides thread-safe read/write access to messages, -// input variables, custom state, and artifacts. -type Session[State any] struct { - mu sync.RWMutex - state SessionState[State] - store SessionStore[State] -} - -// State returns a copy of the current state. -func (s *Session[State]) State() *SessionState[State] { - s.mu.RLock() - defer s.mu.RUnlock() - copied := s.copyStateLocked() - return &copied -} - -// Messages returns the current conversation history. -func (s *Session[State]) Messages() []*ai.Message { - s.mu.RLock() - defer s.mu.RUnlock() - msgs := make([]*ai.Message, len(s.state.Messages)) - copy(msgs, s.state.Messages) - return msgs -} - -// AddMessages appends messages to the conversation history. -func (s *Session[State]) AddMessages(messages ...*ai.Message) { - s.mu.Lock() - defer s.mu.Unlock() - s.state.Messages = append(s.state.Messages, messages...) -} - -// SetMessages replaces the entire conversation history. -func (s *Session[State]) SetMessages(messages []*ai.Message) { - s.mu.Lock() - defer s.mu.Unlock() - s.state.Messages = messages -} - -// Custom returns the current user-defined custom state. -func (s *Session[State]) Custom() State { - s.mu.RLock() - defer s.mu.RUnlock() - return s.state.Custom -} - -// SetCustom updates the user-defined custom state. -func (s *Session[State]) SetCustom(custom State) { - s.mu.Lock() - defer s.mu.Unlock() - s.state.Custom = custom -} - -// UpdateCustom atomically reads the current custom state, applies the given -// function, and writes the result back. -func (s *Session[State]) UpdateCustom(fn func(State) State) { - s.mu.Lock() - defer s.mu.Unlock() - s.state.Custom = fn(s.state.Custom) -} - -// PromptInput returns the prompt input stored in the session state. -func (s *Session[State]) PromptInput() any { - s.mu.RLock() - defer s.mu.RUnlock() - return s.state.PromptInput -} - -// Artifacts returns the current artifacts. -func (s *Session[State]) Artifacts() []*AgentArtifact { - s.mu.RLock() - defer s.mu.RUnlock() - arts := make([]*AgentArtifact, len(s.state.Artifacts)) - copy(arts, s.state.Artifacts) - return arts -} - -// AddArtifacts adds artifacts to the session. If an artifact with the same -// name already exists, it is replaced. -func (s *Session[State]) AddArtifacts(artifacts ...*AgentArtifact) { - s.mu.Lock() - defer s.mu.Unlock() - for _, a := range artifacts { - replaced := false - if a.Name != "" { - for i, existing := range s.state.Artifacts { - if existing.Name == a.Name { - s.state.Artifacts[i] = a - replaced = true - break - } - } - } - if !replaced { - s.state.Artifacts = append(s.state.Artifacts, a) - } - } -} - -// SetArtifacts replaces the entire artifact list. -func (s *Session[State]) SetArtifacts(artifacts []*AgentArtifact) { - s.mu.Lock() - defer s.mu.Unlock() - s.state.Artifacts = artifacts -} - -// copyStateLocked returns a deep copy of the state. Caller must hold mu (read or write). -func (s *Session[State]) copyStateLocked() SessionState[State] { - bytes, err := json.Marshal(s.state) - if err != nil { - panic(fmt.Sprintf("agent flow: failed to marshal state: %v", err)) - } - var copied SessionState[State] - if err := json.Unmarshal(bytes, &copied); err != nil { - panic(fmt.Sprintf("agent flow: failed to unmarshal state: %v", err)) - } - return copied -} - -// --- Session context --- - -type sessionContextKey struct{} - -type sessionHolder struct { - session any -} - -// NewSessionContext returns a new context with the session attached. -func NewSessionContext[State any](ctx context.Context, s *Session[State]) context.Context { - return context.WithValue(ctx, sessionContextKey{}, &sessionHolder{session: s}) -} - -// SessionFromContext retrieves the current session from context. -// Returns nil if no session is in context or if the type doesn't match. -func SessionFromContext[State any](ctx context.Context) *Session[State] { - holder, ok := ctx.Value(sessionContextKey{}).(*sessionHolder) - if !ok || holder == nil { - return nil - } - session, ok := holder.session.(*Session[State]) - if !ok { - return nil - } - return session -} - // --- AgentSession --- // AgentSession extends Session with agent-flow-specific functionality: @@ -254,6 +105,10 @@ type AgentSession[State any] struct { inCh <-chan *AgentFlowInput lastSnapshot *SessionSnapshot[State] turnIndex int + // collectTurnOutput returns the accumulated stream chunks for the current + // turn and resets the accumulator. Set by runWrapped; nil if no collection + // is configured (e.g., standalone usage). + collectTurnOutput func() any } // Run loops over the input channel, calling fn for each turn. Each turn is @@ -269,16 +124,20 @@ func (a *AgentSession[State]) Run(ctx context.Context, fn func(ctx context.Conte } _, err := tracing.RunInNewSpan(ctx, spanMeta, input, - func(ctx context.Context, input *AgentFlowInput) (struct{}, error) { + func(ctx context.Context, input *AgentFlowInput) (any, error) { a.AddMessages(input.Messages...) if err := fn(ctx, input); err != nil { - return struct{}{}, err + return nil, err } a.onEndTurn(ctx) a.turnIndex++ - return struct{}{}, nil + + if a.collectTurnOutput != nil { + return a.collectTurnOutput(), nil + } + return nil, nil }, ) if err != nil { @@ -375,19 +234,13 @@ func (r Responder[Stream]) SendArtifact(artifact *AgentArtifact) { r <- &AgentFlowStreamChunk[Stream]{Artifact: artifact} } -// --- AgentFlowFunc --- +// --- AgentFlow --- // AgentFlowFunc is the function signature for agent flows. // Type parameters: // - Stream: Type for status updates sent via the responder // - State: Type for user-defined state in snapshots -type AgentFlowFunc[Stream, State any] func( - ctx context.Context, - resp Responder[Stream], - sess *AgentSession[State], -) error - -// --- AgentFlow --- +type AgentFlowFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *AgentSession[State]) error // AgentFlow is a bidirectional streaming flow with automatic snapshot management. type AgentFlow[Stream, State any] struct { @@ -434,6 +287,71 @@ func DefineCustomAgent[Stream, State any]( return af } +// PromptRenderer renders a prompt with typed input into GenerateActionOptions. +// This interface is satisfied by both ai.Prompt (with In=any) and +// *ai.DataPrompt[In, Out]. +type PromptRenderer[In any] interface { + Render(ctx context.Context, input In) (*ai.GenerateActionOptions, error) +} + +// DefinePromptAgent creates a prompt-backed AgentFlow with an +// automatic conversation loop. Each turn renders the prompt, appends +// conversation history, calls GenerateWithRequest, streams chunks to the +// client, and adds the model response to the session. +// +// The defaultInput is used for prompt rendering unless overridden per +// invocation via WithPromptInput. +func DefinePromptAgent[State, PromptIn any]( + r api.Registry, + name string, + p PromptRenderer[PromptIn], + defaultInput PromptIn, + opts ...AgentFlowOption[State], +) *AgentFlow[any, State] { + fn := func(ctx context.Context, resp Responder[any], sess *AgentSession[State]) error { + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + // Resolve prompt input: session state override > default. + promptInput := defaultInput + if stored := sess.InputVariables(); stored != nil { + typed, ok := stored.(PromptIn) + if !ok { + return fmt.Errorf("prompt input type mismatch: got %T, want %T", stored, promptInput) + } + promptInput = typed + } + + // Render the prompt template. + actionOpts, err := p.Render(ctx, promptInput) + if err != nil { + return fmt.Errorf("prompt render: %w", err) + } + + // Append conversation history after the prompt-rendered messages. + actionOpts.Messages = append(actionOpts.Messages, sess.Messages()...) + + // Call the model with streaming. + modelResp, err := ai.GenerateWithRequest(ctx, r, actionOpts, nil, + func(ctx context.Context, chunk *ai.ModelResponseChunk) error { + resp.SendChunk(chunk) + return nil + }, + ) + if err != nil { + return fmt.Errorf("generate: %w", err) + } + + // Add the model response message to session history. + if modelResp.Message != nil { + sess.AddMessages(modelResp.Message) + } + + return nil + }) + } + + return DefineCustomAgent(r, name, fn, opts...) +} + // StreamBidi starts a new agent flow invocation. func (af *AgentFlow[Stream, State]) StreamBidi( ctx context.Context, @@ -485,7 +403,22 @@ func (af *AgentFlow[Stream, State]) runWrapped( agentSess.turnIndex = snapshot.TurnIndex } - // Intermediary channel: intercepts artifacts before forwarding to outCh. + // Turn output accumulator: collects content chunks per turn for span output. + var ( + turnMu sync.Mutex + turnChunks []*AgentFlowStreamChunk[Stream] + ) + + agentSess.collectTurnOutput = func() any { + turnMu.Lock() + defer turnMu.Unlock() + result := turnChunks + turnChunks = nil + return result + } + + // Intermediary channel: intercepts artifacts, accumulates turn output, + // and forwards to outCh. respCh := make(chan *AgentFlowStreamChunk[Stream]) var wg sync.WaitGroup wg.Add(1) @@ -495,6 +428,12 @@ func (af *AgentFlow[Stream, State]) runWrapped( if chunk.Artifact != nil { session.AddArtifacts(chunk.Artifact) } + // Accumulate content chunks (exclude control signals from onEndTurn). + if !chunk.EndTurn && chunk.SnapshotID == "" { + turnMu.Lock() + turnChunks = append(turnChunks, chunk) + turnMu.Unlock() + } outCh <- chunk } }() @@ -554,7 +493,7 @@ func newSessionFromInit[State any]( s.state = *init.State } if init.InputVariables != nil { - s.state.PromptInput = init.InputVariables + s.state.InputVariables = init.InputVariables } } @@ -662,70 +601,3 @@ func (c *AgentFlowConnection[Stream, State]) Output() (*AgentFlowOutput[State], func (c *AgentFlowConnection[Stream, State]) Done() <-chan struct{} { return c.conn.Done() } - -// --- Prompt-backed AgentFlow --- - -// PromptRenderer renders a prompt with typed input into GenerateActionOptions. -// This interface is satisfied by both ai.Prompt (with In=any) and -// *ai.DataPrompt[In, Out]. -type PromptRenderer[In any] interface { - Render(ctx context.Context, input In) (*ai.GenerateActionOptions, error) -} - -// DefinePromptAgent creates a prompt-backed AgentFlow with an -// automatic conversation loop. Each turn renders the prompt, appends -// conversation history, calls GenerateWithRequest, streams chunks to the -// client, and adds the model response to the session. -// -// The defaultInput is used for prompt rendering unless overridden per -// invocation via WithPromptInput. -func DefinePromptAgent[State, PromptIn any]( - r api.Registry, - name string, - p PromptRenderer[PromptIn], - defaultInput PromptIn, - opts ...AgentFlowOption[State], -) *AgentFlow[struct{}, State] { - fn := func(ctx context.Context, resp Responder[struct{}], sess *AgentSession[State]) error { - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { - // Resolve prompt input: session state override > default. - promptInput := defaultInput - if stored := sess.PromptInput(); stored != nil { - typed, ok := stored.(PromptIn) - if !ok { - return fmt.Errorf("prompt input type mismatch: got %T, want %T", stored, promptInput) - } - promptInput = typed - } - - // Render the prompt template. - actionOpts, err := p.Render(ctx, promptInput) - if err != nil { - return fmt.Errorf("prompt render: %w", err) - } - - // Append conversation history after the prompt-rendered messages. - actionOpts.Messages = append(actionOpts.Messages, sess.Messages()...) - - // Call the model with streaming. - modelResp, err := ai.GenerateWithRequest(ctx, r, actionOpts, nil, - func(ctx context.Context, chunk *ai.ModelResponseChunk) error { - resp.SendChunk(chunk) - return nil - }, - ) - if err != nil { - return fmt.Errorf("generate: %w", err) - } - - // Add the model response message to session history. - if modelResp.Message != nil { - sess.AddMessages(modelResp.Message) - } - - return nil - }) - } - - return DefineCustomAgent(r, name, fn, opts...) -} diff --git a/go/ai/x/agent_flow_test.go b/go/ai/x/agent_flow_test.go index 8e877a5820..5238a54cb6 100644 --- a/go/ai/x/agent_flow_test.go +++ b/go/ai/x/agent_flow_test.go @@ -19,6 +19,7 @@ package aix import ( "context" "fmt" + "strings" "testing" "github.com/firebase/genkit/go/ai" @@ -132,7 +133,7 @@ func TestAgentFlow_WithSessionStore(t *testing.T) { return nil }) }, - WithSessionStore[testState](store), + WithSessionStore(store), ) conn, err := af.StreamBidi(ctx) @@ -205,7 +206,7 @@ func TestAgentFlow_ResumeFromSnapshot(t *testing.T) { return nil }) }, - WithSessionStore[testState](store), + WithSessionStore(store), ) // First invocation: create a snapshot. @@ -424,7 +425,7 @@ func TestAgentFlow_SnapshotCallback(t *testing.T) { return nil }) }, - WithSessionStore[testState](store), + WithSessionStore(store), WithSnapshotCallback(func(ctx context.Context, sc *SnapshotContext[testState]) bool { callbackCalls++ return sc.TurnIndex%2 == 0 // only snapshot on even turns @@ -636,7 +637,7 @@ func TestAgentFlow_SnapshotIDInMessageMetadata(t *testing.T) { return nil }) }, - WithSessionStore[testState](store), + WithSessionStore(store), ) conn, err := af.StreamBidi(ctx) @@ -718,6 +719,150 @@ func TestInMemorySessionStore(t *testing.T) { } } +func TestAgentFlow_TurnSpanOutput(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + var capturedOutputs []any + + af := DefineCustomAgent(reg, "turnOutputFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + // Wrap collectTurnOutput to capture what each turn produces. + originalCollect := sess.collectTurnOutput + sess.collectTurnOutput = func() any { + output := originalCollect() + capturedOutputs = append(capturedOutputs, output) + return output + } + + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + resp.SendStatus(testStatus{Phase: "thinking"}) + resp.SendChunk(&ai.ModelResponseChunk{ + Content: []*ai.Part{ai.NewTextPart("reply")}, + }) + resp.SendArtifact(&AgentArtifact{ + Name: "out.txt", + Parts: []*ai.Part{ai.NewTextPart("content")}, + }) + sess.AddMessages(ai.NewModelTextMessage("reply")) + return nil + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Two turns. + for turn := range 2 { + if err := conn.SendText(fmt.Sprintf("turn %d", turn)); err != nil { + t.Fatalf("SendText failed on turn %d: %v", turn, err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error on turn %d: %v", turn, err) + } + if chunk.EndTurn { + break + } + } + } + + conn.Close() + if _, err := conn.Output(); err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Should have captured output for each turn. + if len(capturedOutputs) != 2 { + t.Fatalf("expected 2 captured outputs, got %d", len(capturedOutputs)) + } + + for i, output := range capturedOutputs { + chunks, ok := output.([]*AgentFlowStreamChunk[testStatus]) + if !ok { + t.Fatalf("turn %d: expected []*AgentFlowStreamChunk[testStatus], got %T", i, output) + } + // 3 content chunks per turn: status + model chunk + artifact. + if len(chunks) != 3 { + t.Errorf("turn %d: expected 3 chunks, got %d", i, len(chunks)) + } + for j, chunk := range chunks { + if chunk.EndTurn { + t.Errorf("turn %d, chunk %d: EndTurn should not be in turn output", i, j) + } + if chunk.SnapshotID != "" { + t.Errorf("turn %d, chunk %d: SnapshotID should not be in turn output", i, j) + } + } + } +} + +func TestAgentFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + var capturedOutputs []any + + af := DefineCustomAgent(reg, "turnOutputSnapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + originalCollect := sess.collectTurnOutput + sess.collectTurnOutput = func() any { + output := originalCollect() + capturedOutputs = append(capturedOutputs, output) + return output + } + + return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + resp.SendStatus(testStatus{Phase: "working"}) + sess.AddMessages(ai.NewModelTextMessage("reply")) + return nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("hello") + var sawSnapshot bool + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.SnapshotID != "" { + sawSnapshot = true + } + if chunk.EndTurn { + break + } + } + conn.Close() + conn.Output() + + if !sawSnapshot { + t.Fatal("expected a snapshot chunk on the stream") + } + + // Turn output should contain only the status chunk, not the snapshot/endTurn. + if len(capturedOutputs) != 1 { + t.Fatalf("expected 1 captured output, got %d", len(capturedOutputs)) + } + chunks := capturedOutputs[0].([]*AgentFlowStreamChunk[testStatus]) + if len(chunks) != 1 { + t.Errorf("expected 1 content chunk, got %d", len(chunks)) + } + if chunks[0].Status.Phase != "working" { + t.Errorf("expected status phase 'working', got %q", chunks[0].Status.Phase) + } +} + func TestAgentFlow_SessionStoreReflectionAction(t *testing.T) { _ = context.Background() reg := newTestRegistry(t) @@ -727,7 +872,7 @@ func TestAgentFlow_SessionStoreReflectionAction(t *testing.T) { func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { return nil }, - WithSessionStore[testState](store), + WithSessionStore(store), ) // The getSnapshot action should be registered. @@ -736,3 +881,352 @@ func TestAgentFlow_SessionStoreReflectionAction(t *testing.T) { t.Fatal("expected getSnapshot action to be registered") } } + +// setupPromptTestRegistry creates a registry with an echo model and generate action. +func setupPromptTestRegistry(t *testing.T) *registry.Registry { + t.Helper() + reg := registry.New() + ctx := context.Background() + + ai.ConfigureFormats(reg) + ai.DefineModel(reg, "test/echo", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + // Echo back the last user message text. + var text string + for i := len(req.Messages) - 1; i >= 0; i-- { + if req.Messages[i].Role == ai.RoleUser { + text = req.Messages[i].Text() + break + } + } + if text == "" { + text = "no input" + } + + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage("echo: " + text), + } + + if cb != nil { + if err := cb(ctx, &ai.ModelResponseChunk{ + Content: resp.Message.Content, + }); err != nil { + return nil, err + } + } + + return resp, nil + }, + ) + ai.DefineGenerateAction(ctx, reg) + return reg +} + +func TestPromptAgent_Basic(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + prompt := ai.DefinePrompt(reg, "testPrompt", + ai.WithModelName("test/echo"), + ai.WithSystem("You are a test assistant."), + ) + + af := DefinePromptAgent[testState]( + reg, "promptFlow", prompt, nil, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Turn 1. + if err := conn.SendText("hello"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + + var gotChunk bool + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.Chunk != nil { + gotChunk = true + } + if chunk.EndTurn { + break + } + } + if !gotChunk { + t.Error("expected at least one streaming chunk") + } + + // Turn 2. + if err := conn.SendText("world"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // 2 user messages + 2 model replies = 4. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + for i, m := range response.State.Messages { + t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) + } + } +} + +func TestPromptAgent_PromptInputOverride(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + type greetInput struct { + Name string `json:"name"` + } + + prompt := ai.DefineDataPrompt[greetInput, string](reg, "greetPrompt", + ai.WithModelName("test/echo"), + ai.WithPrompt("Hello {{name}}!"), + ) + + af := DefinePromptAgent[testState]( + reg, "promptInputFlow", prompt, greetInput{Name: "default"}, + ) + + // Use WithPromptInput to override. + conn, err := af.StreamBidi(ctx, + WithPromptInput[testState](greetInput{Name: "override"}), + ) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Verify the override was stored in session state. + if response.State.InputVariables == nil { + t.Fatal("expected PromptInput in state") + } + + // The model echoes the last user message, which is "hi". + // But the prompt was rendered with "override" so "Hello override!" should appear + // in the messages sent to the model (verified via the echo). + // We primarily verify the state was set correctly. + inputMap, ok := response.State.InputVariables.(map[string]any) + if !ok { + t.Fatalf("expected PromptInput to be map[string]any, got %T", response.State.InputVariables) + } + if name, _ := inputMap["name"].(string); name != "override" { + t.Errorf("expected PromptInput name='override', got %q", name) + } +} + +func TestPromptAgent_MultiTurnHistory(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + // Use a model that echoes all message count so we can verify history grows. + ai.DefineModel(reg, "test/history", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + // Count total messages received (includes prompt-rendered + history). + var parts []string + for _, m := range req.Messages { + parts = append(parts, string(m.Role)+":"+m.Text()) + } + text := strings.Join(parts, "|") + + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage(text), + } + if cb != nil { + cb(ctx, &ai.ModelResponseChunk{Content: resp.Message.Content}) + } + return resp, nil + }, + ) + + prompt := ai.DefinePrompt(reg, "historyPrompt", + ai.WithModelName("test/history"), + ai.WithSystem("system prompt"), + ) + + af := DefinePromptAgent[testState]( + reg, "historyFlow", prompt, nil, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Turn 1. + conn.SendText("turn1") + var turn1Response string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.Chunk != nil { + turn1Response += chunk.Chunk.Text() + } + if chunk.EndTurn { + break + } + } + + // Turn 1 should have: system message + user message "turn1" (2 messages total from prompt + history). + // The system message comes from the prompt, "turn1" from session history. + if !strings.Contains(turn1Response, "turn1") { + t.Errorf("turn1 response should contain 'turn1', got: %s", turn1Response) + } + + // Turn 2. + conn.SendText("turn2") + var turn2Response string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.Chunk != nil { + turn2Response += chunk.Chunk.Text() + } + if chunk.EndTurn { + break + } + } + + // Turn 2 should have: system + turn1 user + turn1 model reply + turn2 user (4 messages from prompt + history). + if !strings.Contains(turn2Response, "turn1") || !strings.Contains(turn2Response, "turn2") { + t.Errorf("turn2 response should contain both 'turn1' and 'turn2', got: %s", turn2Response) + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Session should have: turn1 user + turn1 model + turn2 user + turn2 model = 4 messages. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages in session, got %d", got) + for i, m := range response.State.Messages { + t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) + } + } +} + +func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + prompt := ai.DefinePrompt(reg, "snapPrompt", + ai.WithModelName("test/echo"), + ai.WithSystem("You are a test assistant."), + ) + + af := DefinePromptAgent( + reg, "snapPromptFlow", prompt, nil, + WithSessionStore(store), + ) + + // Start with prompt input. + conn, err := af.StreamBidi(ctx, + WithPromptInput[testState](map[string]any{"key": "value"}), + ) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("hello") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + resp, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + if resp.SnapshotID == "" { + t.Fatal("expected snapshot ID") + } + + // Verify the snapshot contains PromptInput. + snap, err := store.GetSnapshot(ctx, resp.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap.State.InputVariables == nil { + t.Error("expected PromptInput in snapshot state") + } + + // Resume from snapshot — the PromptInput should be preserved. + conn2, err := af.StreamBidi(ctx, WithSnapshotID[testState](resp.SnapshotID)) + if err != nil { + t.Fatalf("StreamBidi with snapshot failed: %v", err) + } + + conn2.SendText("continued") + for chunk, err := range conn2.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn2.Close() + + resp2, err := conn2.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Should have messages from both invocations. + if got := len(resp2.State.Messages); got != 4 { + t.Errorf("expected 4 messages after resume, got %d", got) + } + + // PromptInput should still be present. + if resp2.State.InputVariables == nil { + t.Error("expected PromptInput preserved after resume") + } +} diff --git a/go/ai/x/prompt_agent_test.go b/go/ai/x/prompt_agent_test.go deleted file mode 100644 index 1de3043e1b..0000000000 --- a/go/ai/x/prompt_agent_test.go +++ /dev/null @@ -1,375 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 - -package aix - -import ( - "context" - "strings" - "testing" - - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/internal/registry" -) - -// setupPromptTestRegistry creates a registry with an echo model and generate action. -func setupPromptTestRegistry(t *testing.T) *registry.Registry { - t.Helper() - reg := registry.New() - ctx := context.Background() - - ai.ConfigureFormats(reg) - ai.DefineModel(reg, "test/echo", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true}}, - func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { - // Echo back the last user message text. - var text string - for i := len(req.Messages) - 1; i >= 0; i-- { - if req.Messages[i].Role == ai.RoleUser { - text = req.Messages[i].Text() - break - } - } - if text == "" { - text = "no input" - } - - resp := &ai.ModelResponse{ - Message: ai.NewModelTextMessage("echo: " + text), - } - - if cb != nil { - if err := cb(ctx, &ai.ModelResponseChunk{ - Content: resp.Message.Content, - }); err != nil { - return nil, err - } - } - - return resp, nil - }, - ) - ai.DefineGenerateAction(ctx, reg) - return reg -} - -func TestPromptAgent_Basic(t *testing.T) { - ctx := context.Background() - reg := setupPromptTestRegistry(t) - - prompt := ai.DefinePrompt(reg, "testPrompt", - ai.WithModelName("test/echo"), - ai.WithSystem("You are a test assistant."), - ) - - af := DefinePromptAgent[testState]( - reg, "promptFlow", prompt, nil, - ) - - conn, err := af.StreamBidi(ctx) - if err != nil { - t.Fatalf("StreamBidi failed: %v", err) - } - - // Turn 1. - if err := conn.SendText("hello"); err != nil { - t.Fatalf("SendText failed: %v", err) - } - - var gotChunk bool - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.Chunk != nil { - gotChunk = true - } - if chunk.EndTurn { - break - } - } - if !gotChunk { - t.Error("expected at least one streaming chunk") - } - - // Turn 2. - if err := conn.SendText("world"); err != nil { - t.Fatalf("SendText failed: %v", err) - } - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.EndTurn { - break - } - } - - conn.Close() - - response, err := conn.Output() - if err != nil { - t.Fatalf("Output failed: %v", err) - } - - // 2 user messages + 2 model replies = 4. - if got := len(response.State.Messages); got != 4 { - t.Errorf("expected 4 messages, got %d", got) - for i, m := range response.State.Messages { - t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) - } - } -} - -func TestPromptAgent_PromptInputOverride(t *testing.T) { - ctx := context.Background() - reg := setupPromptTestRegistry(t) - - type greetInput struct { - Name string `json:"name"` - } - - prompt := ai.DefineDataPrompt[greetInput, string](reg, "greetPrompt", - ai.WithModelName("test/echo"), - ai.WithPrompt("Hello {{name}}!"), - ) - - af := DefinePromptAgent[testState]( - reg, "promptInputFlow", prompt, greetInput{Name: "default"}, - ) - - // Use WithPromptInput to override. - conn, err := af.StreamBidi(ctx, - WithPromptInput[testState](greetInput{Name: "override"}), - ) - if err != nil { - t.Fatalf("StreamBidi failed: %v", err) - } - - if err := conn.SendText("hi"); err != nil { - t.Fatalf("SendText failed: %v", err) - } - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.EndTurn { - break - } - } - conn.Close() - - response, err := conn.Output() - if err != nil { - t.Fatalf("Output failed: %v", err) - } - - // Verify the override was stored in session state. - if response.State.PromptInput == nil { - t.Fatal("expected PromptInput in state") - } - - // The model echoes the last user message, which is "hi". - // But the prompt was rendered with "override" so "Hello override!" should appear - // in the messages sent to the model (verified via the echo). - // We primarily verify the state was set correctly. - inputMap, ok := response.State.PromptInput.(map[string]any) - if !ok { - t.Fatalf("expected PromptInput to be map[string]any, got %T", response.State.PromptInput) - } - if name, _ := inputMap["name"].(string); name != "override" { - t.Errorf("expected PromptInput name='override', got %q", name) - } -} - -func TestPromptAgent_MultiTurnHistory(t *testing.T) { - ctx := context.Background() - reg := setupPromptTestRegistry(t) - - // Use a model that echoes all message count so we can verify history grows. - ai.DefineModel(reg, "test/history", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true}}, - func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { - // Count total messages received (includes prompt-rendered + history). - var parts []string - for _, m := range req.Messages { - parts = append(parts, string(m.Role)+":"+m.Text()) - } - text := strings.Join(parts, "|") - - resp := &ai.ModelResponse{ - Message: ai.NewModelTextMessage(text), - } - if cb != nil { - cb(ctx, &ai.ModelResponseChunk{Content: resp.Message.Content}) - } - return resp, nil - }, - ) - - prompt := ai.DefinePrompt(reg, "historyPrompt", - ai.WithModelName("test/history"), - ai.WithSystem("system prompt"), - ) - - af := DefinePromptAgent[testState]( - reg, "historyFlow", prompt, nil, - ) - - conn, err := af.StreamBidi(ctx) - if err != nil { - t.Fatalf("StreamBidi failed: %v", err) - } - - // Turn 1. - conn.SendText("turn1") - var turn1Response string - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.Chunk != nil { - turn1Response += chunk.Chunk.Text() - } - if chunk.EndTurn { - break - } - } - - // Turn 1 should have: system message + user message "turn1" (2 messages total from prompt + history). - // The system message comes from the prompt, "turn1" from session history. - if !strings.Contains(turn1Response, "turn1") { - t.Errorf("turn1 response should contain 'turn1', got: %s", turn1Response) - } - - // Turn 2. - conn.SendText("turn2") - var turn2Response string - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.Chunk != nil { - turn2Response += chunk.Chunk.Text() - } - if chunk.EndTurn { - break - } - } - - // Turn 2 should have: system + turn1 user + turn1 model reply + turn2 user (4 messages from prompt + history). - if !strings.Contains(turn2Response, "turn1") || !strings.Contains(turn2Response, "turn2") { - t.Errorf("turn2 response should contain both 'turn1' and 'turn2', got: %s", turn2Response) - } - - conn.Close() - - response, err := conn.Output() - if err != nil { - t.Fatalf("Output failed: %v", err) - } - - // Session should have: turn1 user + turn1 model + turn2 user + turn2 model = 4 messages. - if got := len(response.State.Messages); got != 4 { - t.Errorf("expected 4 messages in session, got %d", got) - for i, m := range response.State.Messages { - t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) - } - } -} - -func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { - ctx := context.Background() - reg := setupPromptTestRegistry(t) - store := NewInMemorySessionStore[testState]() - - prompt := ai.DefinePrompt(reg, "snapPrompt", - ai.WithModelName("test/echo"), - ai.WithSystem("You are a test assistant."), - ) - - af := DefinePromptAgent( - reg, "snapPromptFlow", prompt, nil, - WithSessionStore(store), - ) - - // Start with prompt input. - conn, err := af.StreamBidi(ctx, - WithPromptInput[testState](map[string]any{"key": "value"}), - ) - if err != nil { - t.Fatalf("StreamBidi failed: %v", err) - } - - conn.SendText("hello") - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.EndTurn { - break - } - } - conn.Close() - - resp, err := conn.Output() - if err != nil { - t.Fatalf("Output failed: %v", err) - } - - if resp.SnapshotID == "" { - t.Fatal("expected snapshot ID") - } - - // Verify the snapshot contains PromptInput. - snap, err := store.GetSnapshot(ctx, resp.SnapshotID) - if err != nil { - t.Fatalf("GetSnapshot failed: %v", err) - } - if snap.State.PromptInput == nil { - t.Error("expected PromptInput in snapshot state") - } - - // Resume from snapshot — the PromptInput should be preserved. - conn2, err := af.StreamBidi(ctx, WithSnapshotID[testState](resp.SnapshotID)) - if err != nil { - t.Fatalf("StreamBidi with snapshot failed: %v", err) - } - - conn2.SendText("continued") - for chunk, err := range conn2.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.EndTurn { - break - } - } - conn2.Close() - - resp2, err := conn2.Output() - if err != nil { - t.Fatalf("Output failed: %v", err) - } - - // Should have messages from both invocations. - if got := len(resp2.State.Messages); got != 4 { - t.Errorf("expected 4 messages after resume, got %d", got) - } - - // PromptInput should still be present. - if resp2.State.PromptInput == nil { - t.Error("expected PromptInput preserved after resume") - } -} diff --git a/go/ai/x/snapshot.go b/go/ai/x/session.go similarity index 55% rename from go/ai/x/snapshot.go rename to go/ai/x/session.go index 864327fe76..17522eacdb 100644 --- a/go/ai/x/snapshot.go +++ b/go/ai/x/session.go @@ -19,6 +19,7 @@ package aix import ( "context" "encoding/json" + "fmt" "sync" "time" @@ -35,9 +36,9 @@ type SessionState[State any] struct { Custom State `json:"custom,omitempty"` // Artifacts are named collections of parts produced during the conversation. Artifacts []*AgentArtifact `json:"artifacts,omitempty"` - // PromptInput is the input used for prompt rendering in prompt-backed agent flows. - // Stored as any to support type-erased persistence across snapshot boundaries. - PromptInput any `json:"promptInput,omitempty"` + // InputVariables is the input used for agent flows that require input variables + // (e.g. prompt-backed agent flows). + InputVariables any `json:"inputVariables,omitempty"` } // SnapshotEvent identifies what triggered a snapshot. @@ -150,8 +151,8 @@ func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[Sta } // SnapshotOn returns a SnapshotCallback that only allows snapshots for the -// specified events. For example, SnapshotOn[MyState](SnapshotEventTurnEnd) -// will skip the invocation-end snapshot. +// specified events. For example, SnapshotOn[MyState](TurnEnd) will skip the +// invocation-end snapshot. func SnapshotOn[State any](events ...SnapshotEvent) SnapshotCallback[State] { set := make(map[SnapshotEvent]struct{}, len(events)) for _, e := range events { @@ -162,3 +163,151 @@ func SnapshotOn[State any](events ...SnapshotEvent) SnapshotCallback[State] { return ok } } + +// --- Session --- + +// Session holds conversation state and provides thread-safe read/write access to messages, +// input variables, custom state, and artifacts. +type Session[State any] struct { + mu sync.RWMutex + state SessionState[State] + store SessionStore[State] +} + +// State returns a copy of the current state. +func (s *Session[State]) State() *SessionState[State] { + s.mu.RLock() + defer s.mu.RUnlock() + copied := s.copyStateLocked() + return &copied +} + +// Messages returns the current conversation history. +func (s *Session[State]) Messages() []*ai.Message { + s.mu.RLock() + defer s.mu.RUnlock() + msgs := make([]*ai.Message, len(s.state.Messages)) + copy(msgs, s.state.Messages) + return msgs +} + +// AddMessages appends messages to the conversation history. +func (s *Session[State]) AddMessages(messages ...*ai.Message) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Messages = append(s.state.Messages, messages...) +} + +// SetMessages replaces the entire conversation history. +func (s *Session[State]) SetMessages(messages []*ai.Message) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Messages = messages +} + +// Custom returns the current user-defined custom state. +func (s *Session[State]) Custom() State { + s.mu.RLock() + defer s.mu.RUnlock() + return s.state.Custom +} + +// SetCustom updates the user-defined custom state. +func (s *Session[State]) SetCustom(custom State) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Custom = custom +} + +// UpdateCustom atomically reads the current custom state, applies the given +// function, and writes the result back. +func (s *Session[State]) UpdateCustom(fn func(State) State) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Custom = fn(s.state.Custom) +} + +// InputVariables returns the prompt input stored in the session state. +func (s *Session[State]) InputVariables() any { + s.mu.RLock() + defer s.mu.RUnlock() + return s.state.InputVariables +} + +// Artifacts returns the current artifacts. +func (s *Session[State]) Artifacts() []*AgentArtifact { + s.mu.RLock() + defer s.mu.RUnlock() + arts := make([]*AgentArtifact, len(s.state.Artifacts)) + copy(arts, s.state.Artifacts) + return arts +} + +// AddArtifacts adds artifacts to the session. If an artifact with the same +// name already exists, it is replaced. +func (s *Session[State]) AddArtifacts(artifacts ...*AgentArtifact) { + s.mu.Lock() + defer s.mu.Unlock() + for _, a := range artifacts { + replaced := false + if a.Name != "" { + for i, existing := range s.state.Artifacts { + if existing.Name == a.Name { + s.state.Artifacts[i] = a + replaced = true + break + } + } + } + if !replaced { + s.state.Artifacts = append(s.state.Artifacts, a) + } + } +} + +// SetArtifacts replaces the entire artifact list. +func (s *Session[State]) SetArtifacts(artifacts []*AgentArtifact) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Artifacts = artifacts +} + +// copyStateLocked returns a deep copy of the state. Caller must hold mu (read or write). +func (s *Session[State]) copyStateLocked() SessionState[State] { + bytes, err := json.Marshal(s.state) + if err != nil { + panic(fmt.Sprintf("agent flow: failed to marshal state: %v", err)) + } + var copied SessionState[State] + if err := json.Unmarshal(bytes, &copied); err != nil { + panic(fmt.Sprintf("agent flow: failed to unmarshal state: %v", err)) + } + return copied +} + +// --- Session context --- + +type sessionContextKey struct{} + +type sessionHolder struct { + session any +} + +// NewSessionContext returns a new context with the session attached. +func NewSessionContext[State any](ctx context.Context, s *Session[State]) context.Context { + return context.WithValue(ctx, sessionContextKey{}, &sessionHolder{session: s}) +} + +// SessionFromContext retrieves the current session from context. +// Returns nil if no session is in context or if the type doesn't match. +func SessionFromContext[State any](ctx context.Context) *Session[State] { + holder, ok := ctx.Value(sessionContextKey{}).(*sessionHolder) + if !ok || holder == nil { + return nil + } + session, ok := holder.session.(*Session[State]) + if !ok { + return nil + } + return session +} diff --git a/go/core/flow.go b/go/core/flow.go index c173a0306c..b220c30479 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -93,7 +93,7 @@ func NewBidiFlow[In, Out, Stream, Init any](name string, fn BidiFunc[In, Out, St // DefineBidiFlow creates a bidirectional streaming Flow that runs fn, and registers it as an action. // Flow context is injected so that [Run] works inside the bidi function. func DefineBidiFlow[In, Out, Stream, Init any](r api.Registry, name string, fn BidiFunc[In, Out, Stream, Init]) *Flow[In, Out, Stream, Init] { - f := NewBidiFlow[In, Out, Stream, Init](name, fn) + f := NewBidiFlow(name, fn) f.Register(r) return f } From 08b09e4906f40afb830775168bee86f110910500 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 17 Feb 2026 18:58:00 -0800 Subject: [PATCH 016/141] stream out interrupt chunks --- go/ai/x/agent_flow.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent_flow.go index 90ce23fc4b..190e68f75a 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent_flow.go @@ -345,6 +345,15 @@ func DefinePromptAgent[State, PromptIn any]( sess.AddMessages(modelResp.Message) } + // If generation was interrupted, stream the interrupted message + // so the client can see the tool request parts with interrupt metadata. + if modelResp.FinishReason == ai.FinishReasonInterrupted && modelResp.Message != nil { + resp.SendChunk(&ai.ModelResponseChunk{ + Content: modelResp.Message.Content, + Role: modelResp.Message.Role, + }) + } + return nil }) } From c2f55ab55a09f5b1967360dc8673370b434a480d Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 17 Feb 2026 19:01:51 -0800 Subject: [PATCH 017/141] Update genkit.go --- go/genkit/genkit.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 77b182e1ad..ad08089f25 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -463,7 +463,7 @@ func DefinePromptAgent[State, PromptIn any]( p aix.PromptRenderer[PromptIn], defaultInput PromptIn, opts ...aix.AgentFlowOption[State], -) *aix.AgentFlow[struct{}, State] { +) *aix.AgentFlow[any, State] { return aix.DefinePromptAgent(g.reg, name, p, defaultInput, opts...) } From 30b4afd54994a0781c75fb318c2bb4dd69d26833 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 07:32:40 -0800 Subject: [PATCH 018/141] Update agent_flow.go --- go/ai/x/agent_flow.go | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent_flow.go index 190e68f75a..c8d3660800 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent_flow.go @@ -63,9 +63,6 @@ type AgentFlowInit[State any] struct { // State provides direct state for the invocation. // Mutually exclusive with SnapshotID. State *SessionState[State] `json:"state,omitempty"` - // InputVariables overrides the default input variables for this invocation. - // Used by agent flows that require input variables (DefinePromptAgent). - InputVariables any `json:"inputVariables,omitempty"` } // AgentFlowOutput is the output when an agent flow invocation completes. @@ -374,9 +371,8 @@ func (af *AgentFlow[Stream, State]) StreamBidi( } init := &AgentFlowInit[State]{ - SnapshotID: sbOpts.snapshotID, - State: sbOpts.state, - InputVariables: sbOpts.promptInput, + SnapshotID: sbOpts.snapshotID, + State: sbOpts.state, } conn, err := af.flow.StreamBidi(ctx, init) @@ -501,9 +497,6 @@ func newSessionFromInit[State any]( } else if init.State != nil { s.state = *init.State } - if init.InputVariables != nil { - s.state.InputVariables = init.InputVariables - } } return s, snapshot, nil From 22be8140c9859f0037b23c974002df5e5a5b6496 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 08:45:32 -0800 Subject: [PATCH 019/141] removed get snapshot action --- go/ai/x/agent_flow.go | 27 --------------------------- go/ai/x/agent_flow_test.go | 19 ------------------- go/core/api/action.go | 1 - 3 files changed, 47 deletions(-) diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent_flow.go index c8d3660800..df3e1065bd 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent_flow.go @@ -33,8 +33,6 @@ import ( "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/core/tracing" "github.com/google/uuid" - "go.opentelemetry.io/otel/attribute" - oteltrace "go.opentelemetry.io/otel/trace" ) // AgentArtifact represents a named collection of parts produced during a session. @@ -199,12 +197,6 @@ func (a *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotE a.lastSnapshot = snapshot - // Record on OTel span. - span := oteltrace.SpanFromContext(ctx) - span.SetAttributes( - attribute.String("genkit:metadata:snapshotId", snapshot.SnapshotID), - ) - return snapshot.SnapshotID } @@ -276,11 +268,6 @@ func DefineCustomAgent[Stream, State any]( af.flow = core.DefineBidiFlow(r, name, bidiFn) - // Register snapshot store action for reflection API. - if afOpts.store != nil { - registerSessionStoreAction(r, name, afOpts.store) - } - return af } @@ -502,20 +489,6 @@ func newSessionFromInit[State any]( return s, snapshot, nil } -// --- Snapshot store reflection action --- - -type getSnapshotInput struct { - SnapshotID string `json:"snapshotId"` -} - -func registerSessionStoreAction[State any](r api.Registry, flowName string, store SessionStore[State]) { - core.DefineAction(r, flowName+"/getSnapshot", api.ActionTypeSessionStore, nil, nil, - func(ctx context.Context, input getSnapshotInput) (*SessionSnapshot[State], error) { - return store.GetSnapshot(ctx, input.SnapshotID) - }, - ) -} - // --- AgentFlowConnection --- // AgentFlowConnection wraps BidiConnection with agent flow-specific functionality. diff --git a/go/ai/x/agent_flow_test.go b/go/ai/x/agent_flow_test.go index 5238a54cb6..804427ebdc 100644 --- a/go/ai/x/agent_flow_test.go +++ b/go/ai/x/agent_flow_test.go @@ -863,25 +863,6 @@ func TestAgentFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { } } -func TestAgentFlow_SessionStoreReflectionAction(t *testing.T) { - _ = context.Background() - reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() - - DefineCustomAgent(reg, "reflectFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { - return nil - }, - WithSessionStore(store), - ) - - // The getSnapshot action should be registered. - action := reg.LookupAction("/session-store/reflectFlow/getSnapshot") - if action == nil { - t.Fatal("expected getSnapshot action to be registered") - } -} - // setupPromptTestRegistry creates a registry with an echo model and generate action. func setupPromptTestRegistry(t *testing.T) *registry.Registry { t.Helper() diff --git a/go/core/api/action.go b/go/core/api/action.go index a243474bee..704fb1b9f0 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -65,7 +65,6 @@ const ( ActionTypeCheckOperation ActionType = "check-operation" ActionTypeCancelOperation ActionType = "cancel-operation" ActionTypeAgentFlow ActionType = "agent-flow" - ActionTypeSessionStore ActionType = "session-store" ) // ActionDesc is a descriptor of an action. From a009a1b7a3f89091421fd2fddd57aedc2169b072 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 09:34:50 -0800 Subject: [PATCH 020/141] tagged prompt messages and excluded them --- go/ai/x/agent_flow.go | 42 +++++++++--- go/ai/x/agent_flow_test.go | 131 ++++++++++++++++++++++++++++++++++--- go/ai/x/session.go | 10 +-- 3 files changed, 161 insertions(+), 22 deletions(-) diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent_flow.go index df3e1065bd..7215f222ef 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent_flow.go @@ -35,9 +35,9 @@ import ( "github.com/google/uuid" ) -// AgentArtifact represents a named collection of parts produced during a session. +// Artifact represents a named collection of parts produced during a session. // Examples: generated files, images, code snippets, diagrams, etc. -type AgentArtifact struct { +type Artifact struct { // Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). Name string `json:"name,omitempty"` // Parts contains the artifact content (text, media, etc.). @@ -81,7 +81,7 @@ type AgentFlowStreamChunk[Stream any] struct { // The Stream type parameter defines the shape of this data. Status Stream `json:"status,omitempty"` // Artifact contains a newly produced artifact. - Artifact *AgentArtifact `json:"artifact,omitempty"` + Artifact *Artifact `json:"artifact,omitempty"` // SnapshotID contains the ID of a snapshot that was just persisted. SnapshotID string `json:"snapshotId,omitempty"` // EndTurn signals that the agent flow has finished processing the current input. @@ -219,7 +219,7 @@ func (r Responder[Stream]) SendStatus(status Stream) { // SendArtifact sends an artifact to the stream and adds it to the session. // If an artifact with the same name already exists in the session, it is replaced. -func (r Responder[Stream]) SendArtifact(artifact *AgentArtifact) { +func (r Responder[Stream]) SendArtifact(artifact *Artifact) { r <- &AgentFlowStreamChunk[Stream]{Artifact: artifact} } @@ -278,6 +278,10 @@ type PromptRenderer[In any] interface { Render(ctx context.Context, input In) (*ai.GenerateActionOptions, error) } +// promptMessageKey is the metadata key used to tag prompt-rendered messages +// so they can be excluded from session history after generation. +const promptMessageKey = "_genkit_prompt" + // DefinePromptAgent creates a prompt-backed AgentFlow with an // automatic conversation loop. Each turn renders the prompt, appends // conversation history, calls GenerateWithRequest, streams chunks to the @@ -305,16 +309,25 @@ func DefinePromptAgent[State, PromptIn any]( } // Render the prompt template. - actionOpts, err := p.Render(ctx, promptInput) + genOpts, err := p.Render(ctx, promptInput) if err != nil { return fmt.Errorf("prompt render: %w", err) } + // Tag prompt-rendered messages so we can exclude them from + // session history after generation. + for _, m := range genOpts.Messages { + if m.Metadata == nil { + m.Metadata = make(map[string]any) + } + m.Metadata[promptMessageKey] = true + } + // Append conversation history after the prompt-rendered messages. - actionOpts.Messages = append(actionOpts.Messages, sess.Messages()...) + genOpts.Messages = append(genOpts.Messages, sess.Messages()...) // Call the model with streaming. - modelResp, err := ai.GenerateWithRequest(ctx, r, actionOpts, nil, + modelResp, err := ai.GenerateWithRequest(ctx, r, genOpts, nil, func(ctx context.Context, chunk *ai.ModelResponseChunk) error { resp.SendChunk(chunk) return nil @@ -324,8 +337,19 @@ func DefinePromptAgent[State, PromptIn any]( return fmt.Errorf("generate: %w", err) } - // Add the model response message to session history. - if modelResp.Message != nil { + // Replace session messages with the full history minus prompt + // messages. This captures intermediate tool call/response + // messages from the tool loop, not just the final response. + if modelResp.Request != nil { + var msgs []*ai.Message + for _, m := range modelResp.History() { + if m.Metadata != nil && m.Metadata[promptMessageKey] == true { + continue + } + msgs = append(msgs, m) + } + sess.SetMessages(msgs) + } else if modelResp.Message != nil { sess.AddMessages(modelResp.Message) } diff --git a/go/ai/x/agent_flow_test.go b/go/ai/x/agent_flow_test.go index 804427ebdc..540ac91d63 100644 --- a/go/ai/x/agent_flow_test.go +++ b/go/ai/x/agent_flow_test.go @@ -349,19 +349,19 @@ func TestAgentFlow_Artifacts(t *testing.T) { func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { - resp.SendArtifact(&AgentArtifact{ + resp.SendArtifact(&Artifact{ Name: "code.go", Parts: []*ai.Part{ai.NewTextPart("package main")}, }) // Replace artifact with same name. - resp.SendArtifact(&AgentArtifact{ + resp.SendArtifact(&Artifact{ Name: "code.go", Parts: []*ai.Part{ai.NewTextPart("package main\nfunc main() {}")}, }) // Add another artifact. - resp.SendArtifact(&AgentArtifact{ + resp.SendArtifact(&Artifact{ Name: "readme.md", Parts: []*ai.Part{ai.NewTextPart("# README")}, }) @@ -378,7 +378,7 @@ func TestAgentFlow_Artifacts(t *testing.T) { } conn.SendText("generate code") - var receivedArtifacts []*AgentArtifact + var receivedArtifacts []*Artifact for chunk, err := range conn.Receive() { if err != nil { t.Fatalf("Receive error: %v", err) @@ -740,7 +740,7 @@ func TestAgentFlow_TurnSpanOutput(t *testing.T) { resp.SendChunk(&ai.ModelResponseChunk{ Content: []*ai.Part{ai.NewTextPart("reply")}, }) - resp.SendArtifact(&AgentArtifact{ + resp.SendArtifact(&Artifact{ Name: "out.txt", Parts: []*ai.Part{ai.NewTextPart("content")}, }) @@ -870,7 +870,7 @@ func setupPromptTestRegistry(t *testing.T) *registry.Registry { ctx := context.Background() ai.ConfigureFormats(reg) - ai.DefineModel(reg, "test/echo", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true}}, + ai.DefineModel(reg, "test/echo", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true}}, func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { // Echo back the last user message text. var text string @@ -885,6 +885,7 @@ func setupPromptTestRegistry(t *testing.T) *registry.Registry { } resp := &ai.ModelResponse{ + Request: req, Message: ai.NewModelTextMessage("echo: " + text), } @@ -1037,7 +1038,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { reg := setupPromptTestRegistry(t) // Use a model that echoes all message count so we can verify history grows. - ai.DefineModel(reg, "test/history", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true}}, + ai.DefineModel(reg, "test/history", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true}}, func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { // Count total messages received (includes prompt-rendered + history). var parts []string @@ -1047,6 +1048,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { text := strings.Join(parts, "|") resp := &ai.ModelResponse{ + Request: req, Message: ai.NewModelTextMessage(text), } if cb != nil { @@ -1176,7 +1178,7 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { t.Fatalf("GetSnapshot failed: %v", err) } if snap.State.InputVariables == nil { - t.Error("expected PromptInput in snapshot state") + t.Error("expected InputVariables in snapshot state") } // Resume from snapshot — the PromptInput should be preserved. @@ -1211,3 +1213,116 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { t.Error("expected PromptInput preserved after resume") } } + +func TestPromptAgent_ToolLoopMessages(t *testing.T) { + ctx := context.Background() + reg := registry.New() + ai.ConfigureFormats(reg) + + // Define a tool that the model will call. + ai.DefineTool(reg, "greet", "returns a greeting", + func(ctx *ai.ToolContext, input struct { + Name string `json:"name"` + }) (string, error) { + return "hello " + input.Name, nil + }, + ) + + // Model that requests a tool call on the first call, then returns + // a final text response once it sees the tool result. + ai.DefineModel(reg, "test/toolmodel", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true, Tools: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + // Check if we already got a tool response. + for _, msg := range req.Messages { + for _, p := range msg.Content { + if p.IsToolResponse() { + resp := &ai.ModelResponse{ + Request: req, + Message: ai.NewModelTextMessage("done: " + fmt.Sprintf("%v", p.ToolResponse.Output)), + } + if cb != nil { + cb(ctx, &ai.ModelResponseChunk{Content: resp.Message.Content}) + } + return resp, nil + } + } + } + // First call: request the tool. + resp := &ai.ModelResponse{ + Request: req, + Message: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewToolRequestPart(&ai.ToolRequest{ + Name: "greet", + Input: map[string]any{"name": "world"}, + })}, + }, + } + return resp, nil + }, + ) + ai.DefineGenerateAction(ctx, reg) + + prompt := ai.DefinePrompt(reg, "toolPrompt", + ai.WithModelName("test/toolmodel"), + ai.WithSystem("You are a test assistant."), + ai.WithTools(ai.ToolName("greet")), + ) + + af := DefinePromptAgent[testState](reg, "toolFlow", prompt, nil) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("go") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Session should contain: + // 1. user message ("go") + // 2. model tool-call message + // 3. tool response message + // 4. final model text response + msgs := response.State.Messages + if got := len(msgs); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + for i, m := range msgs { + t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) + } + t.FailNow() + } + + if msgs[0].Role != ai.RoleUser { + t.Errorf("msg[0] role = %s, want user", msgs[0].Role) + } + hasToolReq := false + for _, p := range msgs[1].Content { + if p.IsToolRequest() { + hasToolReq = true + break + } + } + if msgs[1].Role != ai.RoleModel || !hasToolReq { + t.Errorf("msg[1] should be a model tool-call message") + } + if msgs[2].Role != ai.RoleTool { + t.Errorf("msg[2] role = %s, want tool", msgs[2].Role) + } + if msgs[3].Role != ai.RoleModel || !strings.Contains(msgs[3].Text(), "done:") { + t.Errorf("msg[3] should be final model response, got role=%s text=%s", msgs[3].Role, msgs[3].Text()) + } +} diff --git a/go/ai/x/session.go b/go/ai/x/session.go index 17522eacdb..80df8efb98 100644 --- a/go/ai/x/session.go +++ b/go/ai/x/session.go @@ -35,7 +35,7 @@ type SessionState[State any] struct { // Custom is the user-defined state associated with this conversation. Custom State `json:"custom,omitempty"` // Artifacts are named collections of parts produced during the conversation. - Artifacts []*AgentArtifact `json:"artifacts,omitempty"` + Artifacts []*Artifact `json:"artifacts,omitempty"` // InputVariables is the input used for agent flows that require input variables // (e.g. prompt-backed agent flows). InputVariables any `json:"inputVariables,omitempty"` @@ -235,17 +235,17 @@ func (s *Session[State]) InputVariables() any { } // Artifacts returns the current artifacts. -func (s *Session[State]) Artifacts() []*AgentArtifact { +func (s *Session[State]) Artifacts() []*Artifact { s.mu.RLock() defer s.mu.RUnlock() - arts := make([]*AgentArtifact, len(s.state.Artifacts)) + arts := make([]*Artifact, len(s.state.Artifacts)) copy(arts, s.state.Artifacts) return arts } // AddArtifacts adds artifacts to the session. If an artifact with the same // name already exists, it is replaced. -func (s *Session[State]) AddArtifacts(artifacts ...*AgentArtifact) { +func (s *Session[State]) AddArtifacts(artifacts ...*Artifact) { s.mu.Lock() defer s.mu.Unlock() for _, a := range artifacts { @@ -266,7 +266,7 @@ func (s *Session[State]) AddArtifacts(artifacts ...*AgentArtifact) { } // SetArtifacts replaces the entire artifact list. -func (s *Session[State]) SetArtifacts(artifacts []*AgentArtifact) { +func (s *Session[State]) SetArtifacts(artifacts []*Artifact) { s.mu.Lock() defer s.mu.Unlock() s.state.Artifacts = artifacts From ea742d9cd7cbf938ec42ffdc09cb30bb92e7121a Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 09:40:31 -0800 Subject: [PATCH 021/141] fixed PromptInput -> InputVariables --- go/ai/x/agent_flow.go | 6 ++++++ go/ai/x/agent_flow_test.go | 4 ++-- go/ai/x/option.go | 6 +++--- go/genkit/genkit.go | 2 +- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent_flow.go index 7215f222ef..346e4b09b4 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent_flow.go @@ -385,6 +385,12 @@ func (af *AgentFlow[Stream, State]) StreamBidi( SnapshotID: sbOpts.snapshotID, State: sbOpts.state, } + if sbOpts.promptInput != nil { + if init.State == nil { + init.State = &SessionState[State]{} + } + init.State.InputVariables = sbOpts.promptInput + } conn, err := af.flow.StreamBidi(ctx, init) if err != nil { diff --git a/go/ai/x/agent_flow_test.go b/go/ai/x/agent_flow_test.go index 540ac91d63..6aea5f8e0e 100644 --- a/go/ai/x/agent_flow_test.go +++ b/go/ai/x/agent_flow_test.go @@ -991,7 +991,7 @@ func TestPromptAgent_PromptInputOverride(t *testing.T) { // Use WithPromptInput to override. conn, err := af.StreamBidi(ctx, - WithPromptInput[testState](greetInput{Name: "override"}), + WithInputVariables[testState](greetInput{Name: "override"}), ) if err != nil { t.Fatalf("StreamBidi failed: %v", err) @@ -1146,7 +1146,7 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { // Start with prompt input. conn, err := af.StreamBidi(ctx, - WithPromptInput[testState](map[string]any{"key": "value"}), + WithInputVariables[testState](map[string]any{"key": "value"}), ) if err != nil { t.Fatalf("StreamBidi failed: %v", err) diff --git a/go/ai/x/option.go b/go/ai/x/option.go index eaa456123b..f082068b08 100644 --- a/go/ai/x/option.go +++ b/go/ai/x/option.go @@ -110,8 +110,8 @@ func WithSnapshotID[State any](id string) StreamBidiOption[State] { return &streamBidiOptions[State]{snapshotID: id} } -// WithPromptInput overrides the default prompt input for a prompt-backed agent flow. -// Used with DefinePromptAgent to customize the prompt rendering per invocation. -func WithPromptInput[State any](input any) StreamBidiOption[State] { +// WithInputVariables overrides the default input variables for a prompt-backed agent flow. +// Used with DefinePromptAgent to customize the input variables per invocation. +func WithInputVariables[State any](input any) StreamBidiOption[State] { return &streamBidiOptions[State]{promptInput: input} } diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index ad08089f25..c2ad893eef 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -452,7 +452,7 @@ func DefineCustomAgent[Stream, State any]( // conversation history, calls the model with streaming, and updates session state. // // The defaultInput is used for prompt rendering unless overridden per -// invocation via [aix.WithPromptInput]. +// invocation via [aix.WithInputVariables]. // // Type parameters: // - State: Type for user-defined state in snapshots From 2ae4bb8e616775b9e99080490a638e3b8ebef4ae Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 10:05:48 -0800 Subject: [PATCH 022/141] added AgentFlowResult to output final artifacts --- go/ai/x/agent_flow.go | 227 ++++++++++++++++++++----------------- go/ai/x/agent_flow_test.go | 111 +++++++++--------- 2 files changed, 182 insertions(+), 156 deletions(-) diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent_flow.go index 346e4b09b4..f28773af5b 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent_flow.go @@ -63,13 +63,28 @@ type AgentFlowInit[State any] struct { State *SessionState[State] `json:"state,omitempty"` } +// AgentFlowResult is the return value from an AgentFlowFunc. +// It contains the user-specified outputs of the agent invocation. +type AgentFlowResult struct { + // Message is the last model response message from the conversation. + Message *ai.Message `json:"message,omitempty"` + // Artifacts contains artifacts produced during the session. + Artifacts []*Artifact `json:"artifacts,omitempty"` +} + // AgentFlowOutput is the output when an agent flow invocation completes. +// It wraps AgentFlowResult with framework-managed fields. type AgentFlowOutput[State any] struct { // SnapshotID is the ID of the snapshot created at the end of this invocation. // Empty if no snapshot was created (callback returned false or no store configured). SnapshotID string `json:"snapshotId,omitempty"` + // Message is the last model response message from the conversation. + Message *ai.Message `json:"message,omitempty"` + // Artifacts contains artifacts produced during the session. + Artifacts []*Artifact `json:"artifacts,omitempty"` // State contains the final conversation state. - State *SessionState[State] `json:"state"` + // Only populated when state is client-managed (no store configured). + State *SessionState[State] `json:"state,omitempty"` } // AgentFlowStreamChunk represents a single item in the agent flow's output stream. @@ -101,8 +116,8 @@ type AgentSession[State any] struct { lastSnapshot *SessionSnapshot[State] turnIndex int // collectTurnOutput returns the accumulated stream chunks for the current - // turn and resets the accumulator. Set by runWrapped; nil if no collection - // is configured (e.g., standalone usage). + // turn and resets the accumulator. Set by DefineCustomAgent; nil if no + // collection is configured (e.g., standalone usage). collectTurnOutput func() any } @@ -229,13 +244,11 @@ func (r Responder[Stream]) SendArtifact(artifact *Artifact) { // Type parameters: // - Stream: Type for status updates sent via the responder // - State: Type for user-defined state in snapshots -type AgentFlowFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *AgentSession[State]) error +type AgentFlowFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *AgentSession[State]) (*AgentFlowResult, error) // AgentFlow is a bidirectional streaming flow with automatic snapshot management. type AgentFlow[Stream, State any] struct { - flow *core.Flow[*AgentFlowInput, *AgentFlowOutput[State], *AgentFlowStreamChunk[Stream], *AgentFlowInit[State]] - store SessionStore[State] - snapshotCallback SnapshotCallback[State] + flow *core.Flow[*AgentFlowInput, *AgentFlowOutput[State], *AgentFlowStreamChunk[Stream], *AgentFlowInit[State]] } // DefineCustomAgent creates an AgentFlow with automatic snapshot management and registers it. @@ -252,23 +265,104 @@ func DefineCustomAgent[Stream, State any]( } } - af := &AgentFlow[Stream, State]{ - store: afOpts.store, - snapshotCallback: afOpts.callback, - } + store := afOpts.store + snapshotCallback := afOpts.callback - bidiFn := func( + flow := core.DefineBidiFlow(r, name, func( ctx context.Context, init *AgentFlowInit[State], inCh <-chan *AgentFlowInput, outCh chan<- *AgentFlowStreamChunk[Stream], ) (*AgentFlowOutput[State], error) { - return af.runWrapped(ctx, init, inCh, outCh, fn) - } + session, snapshot, err := newSessionFromInit(ctx, init, store) + if err != nil { + return nil, err + } + ctx = NewSessionContext(ctx, session) - af.flow = core.DefineBidiFlow(r, name, bidiFn) + agentSess := &AgentSession[State]{ + Session: session, + snapshotCallback: snapshotCallback, + inCh: inCh, + lastSnapshot: snapshot, + } + if snapshot != nil { + agentSess.turnIndex = snapshot.TurnIndex + } + + // Turn output accumulator: collects content chunks per turn for span output. + var ( + turnMu sync.Mutex + turnChunks []*AgentFlowStreamChunk[Stream] + ) - return af + agentSess.collectTurnOutput = func() any { + turnMu.Lock() + defer turnMu.Unlock() + result := turnChunks + turnChunks = nil + return result + } + + // Intermediary channel: intercepts artifacts, accumulates turn output, + // and forwards to outCh. + respCh := make(chan *AgentFlowStreamChunk[Stream]) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for chunk := range respCh { + if chunk.Artifact != nil { + session.AddArtifacts(chunk.Artifact) + } + // Accumulate content chunks (exclude control signals from onEndTurn). + if !chunk.EndTurn && chunk.SnapshotID == "" { + turnMu.Lock() + turnChunks = append(turnChunks, chunk) + turnMu.Unlock() + } + outCh <- chunk + } + }() + + // Wire up onEndTurn: triggers snapshot + sends EndTurn chunk. + // Writes through respCh to preserve ordering with user chunks. + agentSess.onEndTurn = func(turnCtx context.Context) { + snapshotID := agentSess.maybeSnapshot(turnCtx, SnapshotEventTurnEnd) + if snapshotID != "" { + respCh <- &AgentFlowStreamChunk[Stream]{SnapshotID: snapshotID} + } + respCh <- &AgentFlowStreamChunk[Stream]{EndTurn: true} + } + + result, fnErr := fn(ctx, Responder[Stream](respCh), agentSess) + close(respCh) + wg.Wait() + + if fnErr != nil { + return nil, fnErr + } + + // Final snapshot at invocation end. + snapshotID := agentSess.maybeSnapshot(ctx, SnapshotEventInvocationEnd) + + out := &AgentFlowOutput[State]{ + SnapshotID: snapshotID, + } + if result != nil { + out.Message = result.Message + out.Artifacts = result.Artifacts + } + + // Only include full state when client-managed (no store). + if store == nil { + out.State = session.State() + } + + return out, nil + }) + + return &AgentFlow[Stream, State]{flow: flow} } // PromptRenderer renders a prompt with typed input into GenerateActionOptions. @@ -296,8 +390,9 @@ func DefinePromptAgent[State, PromptIn any]( defaultInput PromptIn, opts ...AgentFlowOption[State], ) *AgentFlow[any, State] { - fn := func(ctx context.Context, resp Responder[any], sess *AgentSession[State]) error { - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + fn := func(ctx context.Context, resp Responder[any], sess *AgentSession[State]) (*AgentFlowResult, error) { + var lastModelMessage *ai.Message + err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { // Resolve prompt input: session state override > default. promptInput := defaultInput if stored := sess.InputVariables(); stored != nil { @@ -337,6 +432,8 @@ func DefinePromptAgent[State, PromptIn any]( return fmt.Errorf("generate: %w", err) } + lastModelMessage = modelResp.Message + // Replace session messages with the full history minus prompt // messages. This captures intermediate tool call/response // messages from the tool loop, not just the final response. @@ -364,6 +461,13 @@ func DefinePromptAgent[State, PromptIn any]( return nil }) + if err != nil { + return nil, err + } + return &AgentFlowResult{ + Message: lastModelMessage, + Artifacts: sess.Artifacts(), + }, nil } return DefineCustomAgent(r, name, fn, opts...) @@ -400,93 +504,6 @@ func (af *AgentFlow[Stream, State]) StreamBidi( return &AgentFlowConnection[Stream, State]{conn: conn}, nil } -// runWrapped is the BidiFunc implementation. It sets up the session, -// responder, and wiring, then delegates to the user's function. -func (af *AgentFlow[Stream, State]) runWrapped( - ctx context.Context, - init *AgentFlowInit[State], - inCh <-chan *AgentFlowInput, - outCh chan<- *AgentFlowStreamChunk[Stream], - fn AgentFlowFunc[Stream, State], -) (*AgentFlowOutput[State], error) { - session, snapshot, err := newSessionFromInit(ctx, init, af.store) - if err != nil { - return nil, err - } - ctx = NewSessionContext(ctx, session) - - agentSess := &AgentSession[State]{ - Session: session, - snapshotCallback: af.snapshotCallback, - inCh: inCh, - lastSnapshot: snapshot, - } - if snapshot != nil { - agentSess.turnIndex = snapshot.TurnIndex - } - - // Turn output accumulator: collects content chunks per turn for span output. - var ( - turnMu sync.Mutex - turnChunks []*AgentFlowStreamChunk[Stream] - ) - - agentSess.collectTurnOutput = func() any { - turnMu.Lock() - defer turnMu.Unlock() - result := turnChunks - turnChunks = nil - return result - } - - // Intermediary channel: intercepts artifacts, accumulates turn output, - // and forwards to outCh. - respCh := make(chan *AgentFlowStreamChunk[Stream]) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - for chunk := range respCh { - if chunk.Artifact != nil { - session.AddArtifacts(chunk.Artifact) - } - // Accumulate content chunks (exclude control signals from onEndTurn). - if !chunk.EndTurn && chunk.SnapshotID == "" { - turnMu.Lock() - turnChunks = append(turnChunks, chunk) - turnMu.Unlock() - } - outCh <- chunk - } - }() - - // Wire up onEndTurn: triggers snapshot + sends EndTurn chunk. - // Writes through respCh to preserve ordering with user chunks. - agentSess.onEndTurn = func(turnCtx context.Context) { - snapshotID := agentSess.maybeSnapshot(turnCtx, SnapshotEventTurnEnd) - if snapshotID != "" { - respCh <- &AgentFlowStreamChunk[Stream]{SnapshotID: snapshotID} - } - respCh <- &AgentFlowStreamChunk[Stream]{EndTurn: true} - } - - fnErr := fn(ctx, Responder[Stream](respCh), agentSess) - close(respCh) - wg.Wait() - - if fnErr != nil { - return nil, fnErr - } - - // Final snapshot at invocation end. - snapshotID := agentSess.maybeSnapshot(ctx, SnapshotEventInvocationEnd) - - return &AgentFlowOutput[State]{ - State: session.State(), - SnapshotID: snapshotID, - }, nil -} - // newSessionFromInit creates a Session from initialization data. // If resuming from a snapshot, the loaded snapshot is also returned. func newSessionFromInit[State any]( diff --git a/go/ai/x/agent_flow_test.go b/go/ai/x/agent_flow_test.go index 6aea5f8e0e..92d287be78 100644 --- a/go/ai/x/agent_flow_test.go +++ b/go/ai/x/agent_flow_test.go @@ -45,8 +45,8 @@ func TestAgentFlow_BasicMultiTurn(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "basicFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { resp.SendStatus(testStatus{Phase: "generating"}) // Echo back the user's message. if len(input.Messages) > 0 { @@ -121,8 +121,8 @@ func TestAgentFlow_WithSessionStore(t *testing.T) { store := NewInMemorySessionStore[testState]() af := DefineCustomAgent(reg, "snapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -194,8 +194,8 @@ func TestAgentFlow_ResumeFromSnapshot(t *testing.T) { store := NewInMemorySessionStore[testState]() af := DefineCustomAgent(reg, "resumeFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -252,16 +252,6 @@ func TestAgentFlow_ResumeFromSnapshot(t *testing.T) { t.Fatalf("Output failed: %v", err) } - // Should have messages from both invocations: - // first: user + reply (2) + second: user + reply (2) = 4. - if got := len(resp2.State.Messages); got != 4 { - t.Errorf("expected 4 messages after resume, got %d", got) - } - // Counter should be 2 (1 from first + 1 from second). - if got := resp2.State.Custom.Counter; got != 2 { - t.Errorf("expected counter=2, got %d", got) - } - // The new snapshot should reference the previous as parent. if resp2.SnapshotID == "" { t.Fatal("expected snapshot ID from second invocation") @@ -270,6 +260,16 @@ func TestAgentFlow_ResumeFromSnapshot(t *testing.T) { if err != nil { t.Fatalf("GetSnapshot failed: %v", err) } + + // Should have messages from both invocations: + // first: user + reply (2) + second: user + reply (2) = 4. + if got := len(snap2.State.Messages); got != 4 { + t.Errorf("expected 4 messages after resume, got %d", got) + } + // Counter should be 2 (1 from first + 1 from second). + if got := snap2.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2, got %d", got) + } // The parent chain: snap2's parent is a turn-end snapshot from the second invocation, // which itself has a parent from the first invocation's final snapshot. // We just verify that the parent chain exists (not empty). @@ -283,8 +283,8 @@ func TestAgentFlow_ClientManagedState(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "clientStateFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -346,8 +346,8 @@ func TestAgentFlow_Artifacts(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "artifactFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { resp.SendArtifact(&Artifact{ Name: "code.go", @@ -369,6 +369,10 @@ func TestAgentFlow_Artifacts(t *testing.T) { sess.AddMessages(ai.NewModelTextMessage("done")) return nil }) + if err != nil { + return nil, err + } + return &AgentFlowResult{Artifacts: sess.Artifacts()}, nil }, ) @@ -401,9 +405,9 @@ func TestAgentFlow_Artifacts(t *testing.T) { t.Fatalf("Output failed: %v", err) } - // Session should have 2 unique artifacts (code.go was replaced). - if got := len(response.State.Artifacts); got != 2 { - t.Errorf("expected 2 artifacts in state, got %d", got) + // Output should have 2 unique artifacts (code.go was replaced). + if got := len(response.Artifacts); got != 2 { + t.Errorf("expected 2 artifacts, got %d", got) } } @@ -415,8 +419,8 @@ func TestAgentFlow_SnapshotCallback(t *testing.T) { // Only snapshot on even turns. callbackCalls := 0 af := DefineCustomAgent(reg, "callbackFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -471,8 +475,8 @@ func TestAgentFlow_SendMessages(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "sendMsgsFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { return nil }) }, @@ -518,8 +522,8 @@ func TestAgentFlow_SessionContext(t *testing.T) { var retrievedCounter int af := DefineCustomAgent(reg, "contextFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { // Session should be retrievable from context. ctxSess := SessionFromContext[testState](ctx) if ctxSess == nil { @@ -563,8 +567,8 @@ func TestAgentFlow_ErrorInTurn(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "errorFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { return fmt.Errorf("turn failed") }) }, @@ -589,8 +593,8 @@ func TestAgentFlow_SetMessages(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "setMsgsFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { // Replace all messages with just one. sess.SetMessages([]*ai.Message{ai.NewModelTextMessage("replaced")}) return nil @@ -631,11 +635,16 @@ func TestAgentFlow_SnapshotIDInMessageMetadata(t *testing.T) { store := NewInMemorySessionStore[testState]() af := DefineCustomAgent(reg, "metadataFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) return nil }) + if err != nil { + return nil, err + } + msgs := sess.Messages() + return &AgentFlowResult{Message: msgs[len(msgs)-1]}, nil }, WithSessionStore(store), ) @@ -661,16 +670,14 @@ func TestAgentFlow_SnapshotIDInMessageMetadata(t *testing.T) { t.Fatalf("Output failed: %v", err) } - // The last message should have snapshotId in its metadata. - msgs := response.State.Messages - if len(msgs) == 0 { - t.Fatal("expected messages in response") + // The last model message should have snapshotId in its metadata. + if response.Message == nil { + t.Fatal("expected Message in response") } - lastMsg := msgs[len(msgs)-1] - if lastMsg.Metadata == nil { + if response.Message.Metadata == nil { t.Fatal("expected metadata on last message") } - if _, ok := lastMsg.Metadata["snapshotId"]; !ok { + if _, ok := response.Message.Metadata["snapshotId"]; !ok { t.Error("expected snapshotId in last message metadata") } } @@ -726,7 +733,7 @@ func TestAgentFlow_TurnSpanOutput(t *testing.T) { var capturedOutputs []any af := DefineCustomAgent(reg, "turnOutputFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { // Wrap collectTurnOutput to capture what each turn produces. originalCollect := sess.collectTurnOutput sess.collectTurnOutput = func() any { @@ -735,7 +742,7 @@ func TestAgentFlow_TurnSpanOutput(t *testing.T) { return output } - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { resp.SendStatus(testStatus{Phase: "thinking"}) resp.SendChunk(&ai.ModelResponseChunk{ Content: []*ai.Part{ai.NewTextPart("reply")}, @@ -808,7 +815,7 @@ func TestAgentFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { var capturedOutputs []any af := DefineCustomAgent(reg, "turnOutputSnapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) error { + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { originalCollect := sess.collectTurnOutput sess.collectTurnOutput = func() any { output := originalCollect() @@ -816,7 +823,7 @@ func TestAgentFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { return output } - return sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { resp.SendStatus(testStatus{Phase: "working"}) sess.AddMessages(ai.NewModelTextMessage("reply")) return nil @@ -1203,13 +1210,15 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { t.Fatalf("Output failed: %v", err) } - // Should have messages from both invocations. - if got := len(resp2.State.Messages); got != 4 { + // Verify state via snapshot (server-managed state). + snap2, err := store.GetSnapshot(ctx, resp2.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if got := len(snap2.State.Messages); got != 4 { t.Errorf("expected 4 messages after resume, got %d", got) } - - // PromptInput should still be present. - if resp2.State.InputVariables == nil { + if snap2.State.InputVariables == nil { t.Error("expected PromptInput preserved after resume") } } From 787f61aae5eb17a0da9c5c27fb6f8533f8935cc3 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 10:07:19 -0800 Subject: [PATCH 023/141] Update agent_flow.go --- go/ai/x/agent_flow.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent_flow.go index f28773af5b..5472d7982d 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent_flow.go @@ -78,13 +78,13 @@ type AgentFlowOutput[State any] struct { // SnapshotID is the ID of the snapshot created at the end of this invocation. // Empty if no snapshot was created (callback returned false or no store configured). SnapshotID string `json:"snapshotId,omitempty"` + // State contains the final conversation state. + // Only populated when state is client-managed (no store configured). + State *SessionState[State] `json:"state,omitempty"` // Message is the last model response message from the conversation. Message *ai.Message `json:"message,omitempty"` // Artifacts contains artifacts produced during the session. Artifacts []*Artifact `json:"artifacts,omitempty"` - // State contains the final conversation state. - // Only populated when state is client-managed (no store configured). - State *SessionState[State] `json:"state,omitempty"` } // AgentFlowStreamChunk represents a single item in the agent flow's output stream. From d1281d588aaa9e48dfd3d28a3dab2f2692214854 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 10:17:07 -0800 Subject: [PATCH 024/141] removed turn index from snapshot --- go/ai/x/agent_flow.go | 4 ---- go/ai/x/agent_flow_test.go | 4 ---- go/ai/x/session.go | 4 +--- 3 files changed, 1 insertion(+), 11 deletions(-) diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent_flow.go index 5472d7982d..eba958ffb5 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent_flow.go @@ -186,7 +186,6 @@ func (a *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotE snapshot := &SessionSnapshot[State]{ SnapshotID: uuid.New().String(), CreatedAt: time.Now(), - TurnIndex: a.turnIndex, Event: event, State: currentState, } @@ -286,9 +285,6 @@ func DefineCustomAgent[Stream, State any]( inCh: inCh, lastSnapshot: snapshot, } - if snapshot != nil { - agentSess.turnIndex = snapshot.TurnIndex - } // Turn output accumulator: collects content chunks per turn for span output. var ( diff --git a/go/ai/x/agent_flow_test.go b/go/ai/x/agent_flow_test.go index 92d287be78..235559c888 100644 --- a/go/ai/x/agent_flow_test.go +++ b/go/ai/x/agent_flow_test.go @@ -171,9 +171,6 @@ func TestAgentFlow_WithSessionStore(t *testing.T) { if snap.State.Custom.Counter != 1 { t.Errorf("expected counter=1 in snapshot, got %d", snap.State.Custom.Counter) } - if snap.TurnIndex != 0 { - t.Errorf("expected turnIndex=0, got %d", snap.TurnIndex) - } conn.Close() @@ -698,7 +695,6 @@ func TestInMemorySessionStore(t *testing.T) { // Save and retrieve. snapshot := &SessionSnapshot[testState]{ SnapshotID: "snap-1", - TurnIndex: 0, State: SessionState[testState]{ Custom: testState{Counter: 1}, }, diff --git a/go/ai/x/session.go b/go/ai/x/session.go index 80df8efb98..9b59036f31 100644 --- a/go/ai/x/session.go +++ b/go/ai/x/session.go @@ -59,8 +59,6 @@ type SessionSnapshot[State any] struct { ParentID string `json:"parentId,omitempty"` // CreatedAt is when the snapshot was created. CreatedAt time.Time `json:"createdAt"` - // TurnIndex is the turn number when this snapshot was created (0-indexed). - TurnIndex int `json:"turnIndex"` // Event is what triggered this snapshot. Event SnapshotEvent `json:"event"` // State is the actual conversation state. @@ -73,7 +71,7 @@ type SnapshotContext[State any] struct { State *SessionState[State] // PrevState is the state at the last snapshot, or nil if none exists. PrevState *SessionState[State] - // TurnIndex is the current turn number. + // TurnIndex is the turn number in the current invocation. TurnIndex int // Event is what triggered this snapshot check. Event SnapshotEvent From d6cae449d780812cebf5ff71eb2d8dbc09ede510 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 10:28:14 -0800 Subject: [PATCH 025/141] exposed InputCh and TurnIndex on AgentSession --- go/ai/x/agent_flow.go | 39 +++++++++++++++++++-------------- go/ai/x/agent_flow_test.go | 10 ++++----- go/samples/prompt-agent/main.go | 4 ++-- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent_flow.go index eba958ffb5..a2dd70e66e 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent_flow.go @@ -90,8 +90,8 @@ type AgentFlowOutput[State any] struct { // AgentFlowStreamChunk represents a single item in the agent flow's output stream. // Multiple fields can be populated in a single chunk. type AgentFlowStreamChunk[Stream any] struct { - // Chunk contains token-level generation data. - Chunk *ai.ModelResponseChunk `json:"chunk,omitempty"` + // ModelChunk contains generation tokens from the model. + ModelChunk *ai.ModelResponseChunk `json:"modelChunk,omitempty"` // Status contains user-defined structured status information. // The Stream type parameter defines the shape of this data. Status Stream `json:"status,omitempty"` @@ -110,14 +110,21 @@ type AgentFlowStreamChunk[Stream any] struct { // turn management, snapshot persistence, and input channel handling. type AgentSession[State any] struct { *Session[State] - snapshotCallback SnapshotCallback[State] - onEndTurn func(ctx context.Context) - inCh <-chan *AgentFlowInput - lastSnapshot *SessionSnapshot[State] - turnIndex int - // collectTurnOutput returns the accumulated stream chunks for the current - // turn and resets the accumulator. Set by DefineCustomAgent; nil if no - // collection is configured (e.g., standalone usage). + + // InputCh is the channel that delivers per-turn inputs from the client. + // It is consumed automatically by [AgentSession.Run], but is exposed + // for advanced use cases that need direct access to the input stream + // (e.g., custom turn loops or fan-out patterns). + InputCh <-chan *AgentFlowInput + // TurnIndex is the zero-based index of the current conversation turn. + // It is incremented automatically by [AgentSession.Run], but is exposed + // for advanced use cases that need to track or manipulate turn ordering + // directly. + TurnIndex int + + snapshotCallback SnapshotCallback[State] + onEndTurn func(ctx context.Context) + lastSnapshot *SessionSnapshot[State] collectTurnOutput func() any } @@ -126,9 +133,9 @@ type AgentSession[State any] struct { // added to the session before fn is called. After fn returns successfully, an // EndTurn chunk is sent and a snapshot check is triggered. func (a *AgentSession[State]) Run(ctx context.Context, fn func(ctx context.Context, input *AgentFlowInput) error) error { - for input := range a.inCh { + for input := range a.InputCh { spanMeta := &tracing.SpanMetadata{ - Name: fmt.Sprintf("agentFlow/turn/%d", a.turnIndex), + Name: fmt.Sprintf("agentFlow/turn/%d", a.TurnIndex), Type: "agentFlowTurn", Subtype: "agentFlowTurn", } @@ -142,7 +149,7 @@ func (a *AgentSession[State]) Run(ctx context.Context, fn func(ctx context.Conte } a.onEndTurn(ctx) - a.turnIndex++ + a.TurnIndex++ if a.collectTurnOutput != nil { return a.collectTurnOutput(), nil @@ -176,7 +183,7 @@ func (a *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotE if !a.snapshotCallback(ctx, &SnapshotContext[State]{ State: ¤tState, PrevState: prevState, - TurnIndex: a.turnIndex, + TurnIndex: a.TurnIndex, Event: event, }) { return "" @@ -223,7 +230,7 @@ type Responder[Stream any] chan<- *AgentFlowStreamChunk[Stream] // SendChunk sends a generation chunk (token-level streaming). func (r Responder[Stream]) SendChunk(chunk *ai.ModelResponseChunk) { - r <- &AgentFlowStreamChunk[Stream]{Chunk: chunk} + r <- &AgentFlowStreamChunk[Stream]{ModelChunk: chunk} } // SendStatus sends a user-defined status update. @@ -282,7 +289,7 @@ func DefineCustomAgent[Stream, State any]( agentSess := &AgentSession[State]{ Session: session, snapshotCallback: snapshotCallback, - inCh: inCh, + InputCh: inCh, lastSnapshot: snapshot, } diff --git a/go/ai/x/agent_flow_test.go b/go/ai/x/agent_flow_test.go index 235559c888..ad0b195c75 100644 --- a/go/ai/x/agent_flow_test.go +++ b/go/ai/x/agent_flow_test.go @@ -935,7 +935,7 @@ func TestPromptAgent_Basic(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.Chunk != nil { + if chunk.ModelChunk != nil { gotChunk = true } if chunk.EndTurn { @@ -1082,8 +1082,8 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.Chunk != nil { - turn1Response += chunk.Chunk.Text() + if chunk.ModelChunk != nil { + turn1Response += chunk.ModelChunk.Text() } if chunk.EndTurn { break @@ -1103,8 +1103,8 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.Chunk != nil { - turn2Response += chunk.Chunk.Text() + if chunk.ModelChunk != nil { + turn2Response += chunk.ModelChunk.Text() } if chunk.EndTurn { break diff --git a/go/samples/prompt-agent/main.go b/go/samples/prompt-agent/main.go index 95e83eed7c..dd45cc2f16 100644 --- a/go/samples/prompt-agent/main.go +++ b/go/samples/prompt-agent/main.go @@ -83,8 +83,8 @@ func main() { fmt.Fprintf(os.Stderr, "Error: %v\n", err) break } - if chunk.Chunk != nil { - fmt.Print(chunk.Chunk.Text()) + if chunk.ModelChunk != nil { + fmt.Print(chunk.ModelChunk.Text()) } if chunk.SnapshotID != "" { fmt.Printf("\n[snapshot: %s]", chunk.SnapshotID) From 10bbd03201cbbb98357771a136929cb81900bfe2 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 10:38:18 -0800 Subject: [PATCH 026/141] improvements to API --- go/ai/x/agent_flow.go | 8 ++++---- go/ai/x/agent_flow_test.go | 2 +- go/ai/x/option.go | 19 ++++++++++++++++++- go/ai/x/session.go | 13 ------------- go/samples/custom-agent/main.go | 20 +++++++++++++------- 5 files changed, 36 insertions(+), 26 deletions(-) diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent_flow.go index a2dd70e66e..055e170569 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent_flow.go @@ -228,8 +228,8 @@ func (a *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotE // client. type Responder[Stream any] chan<- *AgentFlowStreamChunk[Stream] -// SendChunk sends a generation chunk (token-level streaming). -func (r Responder[Stream]) SendChunk(chunk *ai.ModelResponseChunk) { +// SendModelChunk sends a generation chunk (token-level streaming). +func (r Responder[Stream]) SendModelChunk(chunk *ai.ModelResponseChunk) { r <- &AgentFlowStreamChunk[Stream]{ModelChunk: chunk} } @@ -427,7 +427,7 @@ func DefinePromptAgent[State, PromptIn any]( // Call the model with streaming. modelResp, err := ai.GenerateWithRequest(ctx, r, genOpts, nil, func(ctx context.Context, chunk *ai.ModelResponseChunk) error { - resp.SendChunk(chunk) + resp.SendModelChunk(chunk) return nil }, ) @@ -456,7 +456,7 @@ func DefinePromptAgent[State, PromptIn any]( // If generation was interrupted, stream the interrupted message // so the client can see the tool request parts with interrupt metadata. if modelResp.FinishReason == ai.FinishReasonInterrupted && modelResp.Message != nil { - resp.SendChunk(&ai.ModelResponseChunk{ + resp.SendModelChunk(&ai.ModelResponseChunk{ Content: modelResp.Message.Content, Role: modelResp.Message.Role, }) diff --git a/go/ai/x/agent_flow_test.go b/go/ai/x/agent_flow_test.go index ad0b195c75..5762713254 100644 --- a/go/ai/x/agent_flow_test.go +++ b/go/ai/x/agent_flow_test.go @@ -740,7 +740,7 @@ func TestAgentFlow_TurnSpanOutput(t *testing.T) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { resp.SendStatus(testStatus{Phase: "thinking"}) - resp.SendChunk(&ai.ModelResponseChunk{ + resp.SendModelChunk(&ai.ModelResponseChunk{ Content: []*ai.Part{ai.NewTextPart("reply")}, }) resp.SendArtifact(&Artifact{ diff --git a/go/ai/x/option.go b/go/ai/x/option.go index f082068b08..9ef6152566 100644 --- a/go/ai/x/option.go +++ b/go/ai/x/option.go @@ -16,7 +16,10 @@ package aix -import "errors" +import ( + "context" + "errors" +) // --- AgentFlowOption --- @@ -57,6 +60,20 @@ func WithSnapshotCallback[State any](cb SnapshotCallback[State]) AgentFlowOption return &agentFlowOptions[State]{callback: cb} } +// WithSnapshotOn configures snapshots to be created only for the specified events. +// For example, WithSnapshotOn[MyState](SnapshotEventTurnEnd) skips the +// invocation-end snapshot. +func WithSnapshotOn[State any](events ...SnapshotEvent) AgentFlowOption[State] { + set := make(map[SnapshotEvent]struct{}, len(events)) + for _, e := range events { + set[e] = struct{}{} + } + return WithSnapshotCallback[State](func(_ context.Context, sc *SnapshotContext[State]) bool { + _, ok := set[sc.Event] + return ok + }) +} + // --- StreamBidiOption --- // StreamBidiOption configures a StreamBidi call. diff --git a/go/ai/x/session.go b/go/ai/x/session.go index 9b59036f31..d30685a61d 100644 --- a/go/ai/x/session.go +++ b/go/ai/x/session.go @@ -148,19 +148,6 @@ func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[Sta return &copied, nil } -// SnapshotOn returns a SnapshotCallback that only allows snapshots for the -// specified events. For example, SnapshotOn[MyState](TurnEnd) will skip the -// invocation-end snapshot. -func SnapshotOn[State any](events ...SnapshotEvent) SnapshotCallback[State] { - set := make(map[SnapshotEvent]struct{}, len(events)) - for _, e := range events { - set[e] = struct{}{} - } - return func(_ context.Context, sc *SnapshotContext[State]) bool { - _, ok := set[sc.Event] - return ok - } -} // --- Session --- diff --git a/go/samples/custom-agent/main.go b/go/samples/custom-agent/main.go index b7b73c51dc..0fdde46871 100644 --- a/go/samples/custom-agent/main.go +++ b/go/samples/custom-agent/main.go @@ -36,8 +36,9 @@ func main() { g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) chatFlow := genkit.DefineCustomAgent(g, "chat", - func(ctx context.Context, resp aix.Responder[any], sess *aix.AgentSession[any]) error { - return sess.Run(ctx, func(ctx context.Context, input *aix.AgentFlowInput) error { + func(ctx context.Context, resp aix.Responder[any], sess *aix.AgentSession[any]) (*aix.AgentFlowResult, error) { + var lastMessage *ai.Message + err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentFlowInput) error { for chunk, err := range genkit.GenerateStream(ctx, g, ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ @@ -51,17 +52,22 @@ func main() { return err } if chunk.Done { - sess.AddMessages(chunk.Response.Message) + lastMessage = chunk.Response.Message + sess.AddMessages(lastMessage) break } - resp.SendChunk(chunk.Chunk) + resp.SendModelChunk(chunk.Chunk) } return nil }) + if err != nil { + return nil, err + } + return &aix.AgentFlowResult{Message: lastMessage}, nil }, aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), - aix.WithSnapshotCallback(aix.SnapshotOn[any](aix.SnapshotEventTurnEnd)), + aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), ) fmt.Println("Agent Flow Chat (type 'quit' to exit)") @@ -98,8 +104,8 @@ func main() { fmt.Fprintf(os.Stderr, "Error: %v\n", err) break } - if chunk.Chunk != nil { - fmt.Print(chunk.Chunk.Text()) + if chunk.ModelChunk != nil { + fmt.Print(chunk.ModelChunk.Text()) } if chunk.SnapshotID != "" { fmt.Printf("\n[snapshot: %s]", chunk.SnapshotID) From a687eb6763f6450b7419a5e20dcdff743b443a3d Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 10:54:42 -0800 Subject: [PATCH 027/141] minor fixes --- go/ai/x/agent_flow.go | 9 ++++++--- go/ai/x/session.go | 4 ++-- go/core/x/session/session.go | 8 ++++---- go/core/x/session/session_test.go | 4 ++-- go/plugins/googlegenai/googleai_live_test.go | 2 +- 5 files changed, 15 insertions(+), 12 deletions(-) diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent_flow.go index 055e170569..dfe6d4a75a 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent_flow.go @@ -24,13 +24,13 @@ import ( "context" "fmt" "iter" - "log/slog" "sync" "time" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" "github.com/google/uuid" ) @@ -201,7 +201,7 @@ func (a *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotE } if err := a.store.SaveSnapshot(ctx, snapshot); err != nil { - slog.Error("agent flow: failed to save snapshot", "err", err) + logger.FromContext(ctx).Error("agent flow: failed to save snapshot", "err", err) return "" } @@ -401,7 +401,7 @@ func DefinePromptAgent[State, PromptIn any]( if stored := sess.InputVariables(); stored != nil { typed, ok := stored.(PromptIn) if !ok { - return fmt.Errorf("prompt input type mismatch: got %T, want %T", stored, promptInput) + return core.NewError(core.INVALID_ARGUMENT, "prompt input type mismatch: got %T, want %T", stored, promptInput) } promptInput = typed } @@ -518,6 +518,9 @@ func newSessionFromInit[State any]( var snapshot *SessionSnapshot[State] if init != nil { + if init.SnapshotID != "" && init.State != nil { + return nil, nil, core.NewError(core.INVALID_ARGUMENT, "snapshotId and state are mutually exclusive") + } if init.SnapshotID != "" && store == nil { return nil, nil, core.NewError(core.FAILED_PRECONDITION, "snapshot ID %q provided but no session store configured", init.SnapshotID) } diff --git a/go/ai/x/session.go b/go/ai/x/session.go index d30685a61d..c516e5ac05 100644 --- a/go/ai/x/session.go +++ b/go/ai/x/session.go @@ -139,11 +139,11 @@ func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[Sta } bytes, err := json.Marshal(snap) if err != nil { - return nil, err + return nil, fmt.Errorf("copy snapshot: marshal: %w", err) } var copied SessionSnapshot[State] if err := json.Unmarshal(bytes, &copied); err != nil { - return nil, err + return nil, fmt.Errorf("copy snapshot: unmarshal: %w", err) } return &copied, nil } diff --git a/go/core/x/session/session.go b/go/core/x/session/session.go index 8a9f0387f9..01db9e8562 100644 --- a/go/core/x/session/session.go +++ b/go/core/x/session/session.go @@ -165,7 +165,7 @@ func New[S any](ctx context.Context, opts ...Option[S]) (*Session[S], error) { func Load[S any](ctx context.Context, store Store[S], sessionID string) (*Session[S], error) { data, err := store.Get(ctx, sessionID) if err != nil { - return nil, err + return nil, fmt.Errorf("session.Load: %w", err) } if data == nil { return nil, &NotFoundError{SessionID: sessionID} @@ -221,7 +221,7 @@ func (s *Session[S]) UpdateState(ctx context.Context, state S) error { State: state, } if err := s.store.Save(ctx, s.id, data); err != nil { - return err + return fmt.Errorf("session.UpdateState: %w", err) } } @@ -346,12 +346,12 @@ func copyData[S any](data *Data[S]) (*Data[S], error) { bytes, err := json.Marshal(data) if err != nil { - return nil, err + return nil, fmt.Errorf("copy session data: marshal: %w", err) } var copied Data[S] if err := json.Unmarshal(bytes, &copied); err != nil { - return nil, err + return nil, fmt.Errorf("copy session data: unmarshal: %w", err) } return &copied, nil diff --git a/go/core/x/session/session_test.go b/go/core/x/session/session_test.go index b34100a44c..55c25d76d1 100644 --- a/go/core/x/session/session_test.go +++ b/go/core/x/session/session_test.go @@ -776,7 +776,7 @@ func TestSession_UpdateState_StoreError(t *testing.T) { if err == nil { t.Fatal("Expected error from failing store") } - if err != expectedErr { - t.Errorf("Expected error %v, got %v", expectedErr, err) + if !errors.Is(err, expectedErr) { + t.Errorf("Expected error wrapping %v, got %v", expectedErr, err) } } diff --git a/go/plugins/googlegenai/googleai_live_test.go b/go/plugins/googlegenai/googleai_live_test.go index 783eccd239..f65b42ff56 100644 --- a/go/plugins/googlegenai/googleai_live_test.go +++ b/go/plugins/googlegenai/googleai_live_test.go @@ -70,7 +70,7 @@ func TestGoogleAILive(t *testing.T) { genkit.WithPlugins(&googlegenai.GoogleAI{APIKey: apiKey}), ) - embedder := googlegenai.GoogleAIEmbedder(g, "embedding-001") + embedder := googlegenai.GoogleAIEmbedder(g, "gemini-embedding-001") gablorkenTool := genkit.DefineTool(g, "gablorken", "use this tool when the user asks to calculate a gablorken, carefuly inspect the user input to determine which value from the prompt corresponds to the input structure", func(ctx *ai.ToolContext, input struct { From 7c535c465543edd3d6c9eb9684918289b00a7116 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 10:57:27 -0800 Subject: [PATCH 028/141] Update session.go --- go/ai/x/session.go | 1 - 1 file changed, 1 deletion(-) diff --git a/go/ai/x/session.go b/go/ai/x/session.go index c516e5ac05..6df97a7c67 100644 --- a/go/ai/x/session.go +++ b/go/ai/x/session.go @@ -148,7 +148,6 @@ func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[Sta return &copied, nil } - // --- Session --- // Session holds conversation state and provides thread-safe read/write access to messages, From 26f8f7d10e1e59d248689d3644830f94df42732f Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 11:30:32 -0800 Subject: [PATCH 029/141] added shared schemas for agent types --- genkit-tools/common/src/types/agent.ts | 118 ++++++++++ genkit-tools/genkit-schema.json | 134 +++++++++++ genkit-tools/scripts/schema-exporter.ts | 1 + go/ai/generate.go | 21 ++ go/ai/x/agent_flow.go | 86 ++----- go/ai/x/gen.go | 122 ++++++++++ go/ai/x/session.go | 25 -- go/core/schemas.config | 219 ++++++++++++++++++ .../cmd/jsonschemagen/jsonschemagen.go | 51 +++- 9 files changed, 671 insertions(+), 106 deletions(-) create mode 100644 genkit-tools/common/src/types/agent.ts create mode 100644 go/ai/x/gen.go diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts new file mode 100644 index 0000000000..fb8c2ce9ec --- /dev/null +++ b/genkit-tools/common/src/types/agent.ts @@ -0,0 +1,118 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { z } from 'zod'; +import { MessageSchema, ModelResponseChunkSchema } from './model'; +import { PartSchema } from './parts'; + +/** + * Zod schema for an artifact produced during a session. + */ +export const ArtifactSchema = z.object({ + /** Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). */ + name: z.string().optional(), + /** Parts contains the artifact content (text, media, etc.). */ + parts: z.array(PartSchema), + /** Metadata contains additional artifact-specific data. */ + metadata: z.record(z.any()).optional(), +}); +export type Artifact = z.infer; + +/** + * Zod schema for snapshot event. + */ +export const SnapshotEventSchema = z.enum(['turnEnd', 'invocationEnd']); +export type SnapshotEvent = z.infer; + +/** + * Zod schema for session state. + */ +export const SessionStateSchema = z.object({ + /** Conversation history (user/model exchanges). */ + messages: z.array(MessageSchema).optional(), + /** User-defined state associated with this conversation. */ + custom: z.any().optional(), + /** Named collections of parts produced during the conversation. */ + artifacts: z.array(ArtifactSchema).optional(), + /** Input used for agent flows that require input variables. */ + inputVariables: z.any().optional(), +}); +export type SessionState = z.infer; + +/** + * Zod schema for agent flow input (per-turn). + */ +export const AgentFlowInputSchema = z.object({ + /** User's input messages for this turn. */ + messages: z.array(MessageSchema).optional(), + /** Tool request parts to re-execute interrupted tools. */ + toolRestarts: z.array(PartSchema).optional(), +}); +export type AgentFlowInput = z.infer; + +/** + * Zod schema for agent flow initialization. + */ +export const AgentFlowInitSchema = z.object({ + /** Loads state from a persisted snapshot. Mutually exclusive with state. */ + snapshotId: z.string().optional(), + /** Direct state for the invocation. Mutually exclusive with snapshotId. */ + state: SessionStateSchema.optional(), +}); +export type AgentFlowInit = z.infer; + +/** + * Zod schema for agent flow result. + */ +export const AgentFlowResultSchema = z.object({ + /** Last model response message from the conversation. */ + message: MessageSchema.optional(), + /** Artifacts produced during the session. */ + artifacts: z.array(ArtifactSchema).optional(), +}); +export type AgentFlowResult = z.infer; + +/** + * Zod schema for agent flow output. + */ +export const AgentFlowOutputSchema = z.object({ + /** ID of the snapshot created at the end of this invocation. */ + snapshotId: z.string().optional(), + /** Final conversation state (only when client-managed). */ + state: SessionStateSchema.optional(), + /** Last model response message from the conversation. */ + message: MessageSchema.optional(), + /** Artifacts produced during the session. */ + artifacts: z.array(ArtifactSchema).optional(), +}); +export type AgentFlowOutput = z.infer; + +/** + * Zod schema for agent flow stream chunk. + */ +export const AgentFlowStreamChunkSchema = z.object({ + /** Generation tokens from the model. */ + modelChunk: ModelResponseChunkSchema.optional(), + /** User-defined structured status information. */ + status: z.any().optional(), + /** A newly produced artifact. */ + artifact: ArtifactSchema.optional(), + /** ID of a snapshot that was just persisted. */ + snapshotId: z.string().optional(), + /** Signals that the agent flow has finished processing the current input. */ + endTurn: z.boolean().optional(), +}); +export type AgentFlowStreamChunk = z.infer; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 26cc4fbf4f..bae17a88ae 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -1,6 +1,140 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "$defs": { + "AgentFlowInit": { + "type": "object", + "properties": { + "snapshotId": { + "type": "string" + }, + "state": { + "$ref": "#/$defs/SessionState" + } + }, + "additionalProperties": false + }, + "AgentFlowInput": { + "type": "object", + "properties": { + "messages": { + "type": "array", + "items": { + "$ref": "#/$defs/Message" + } + }, + "toolRestarts": { + "type": "array", + "items": { + "$ref": "#/$defs/Part" + } + } + }, + "additionalProperties": false + }, + "AgentFlowOutput": { + "type": "object", + "properties": { + "snapshotId": { + "type": "string" + }, + "state": { + "$ref": "#/$defs/SessionState" + }, + "message": { + "$ref": "#/$defs/Message" + }, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/$defs/Artifact" + } + } + }, + "additionalProperties": false + }, + "AgentFlowResult": { + "type": "object", + "properties": { + "message": { + "$ref": "#/$defs/Message" + }, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/$defs/Artifact" + } + } + }, + "additionalProperties": false + }, + "AgentFlowStreamChunk": { + "type": "object", + "properties": { + "modelChunk": { + "$ref": "#/$defs/ModelResponseChunk" + }, + "status": {}, + "artifact": { + "$ref": "#/$defs/Artifact" + }, + "snapshotId": { + "type": "string" + }, + "endTurn": { + "type": "boolean" + } + }, + "additionalProperties": false + }, + "Artifact": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "parts": { + "type": "array", + "items": { + "$ref": "#/$defs/Part" + } + }, + "metadata": { + "type": "object", + "additionalProperties": {} + } + }, + "required": [ + "parts" + ], + "additionalProperties": false + }, + "SessionState": { + "type": "object", + "properties": { + "messages": { + "type": "array", + "items": { + "$ref": "#/$defs/Message" + } + }, + "custom": {}, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/$defs/Artifact" + } + }, + "inputVariables": {} + }, + "additionalProperties": false + }, + "SnapshotEvent": { + "type": "string", + "enum": [ + "turnEnd", + "invocationEnd" + ] + }, "DocumentData": { "type": "object", "properties": { diff --git a/genkit-tools/scripts/schema-exporter.ts b/genkit-tools/scripts/schema-exporter.ts index 48df79b56a..e12e44fc8e 100644 --- a/genkit-tools/scripts/schema-exporter.ts +++ b/genkit-tools/scripts/schema-exporter.ts @@ -22,6 +22,7 @@ import { zodToJsonSchema } from 'zod-to-json-schema'; /** List of files that contain types to be exported. */ const EXPORTED_TYPE_MODULES = [ + '../common/src/types/agent.ts', '../common/src/types/document.ts', '../common/src/types/embedder.ts', '../common/src/types/evaluator.ts', diff --git a/go/ai/generate.go b/go/ai/generate.go index 6aeb1e6642..ee91edbb5b 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -1100,6 +1100,27 @@ func (m *Message) Text() string { return sb.String() } +// NewResume constructs a [GenerateActionResume] from Part slices. +// This is useful when building [GenerateActionOptions] directly (e.g., from a +// rendered prompt) and need to set the Resume field from [*Part] values +// produced by [ToolDef.RestartWith] or [ToolDef.RespondWith]. +func NewResume(restarts, responds []*Part) *GenerateActionResume { + resume := &GenerateActionResume{} + for _, p := range restarts { + resume.Restart = append(resume.Restart, &toolRequestPart{ + ToolRequest: p.ToolRequest, + Metadata: p.Metadata, + }) + } + for _, p := range responds { + resume.Respond = append(resume.Respond, &toolResponsePart{ + ToolResponse: p.ToolResponse, + Metadata: p.Metadata, + }) + } + return resume +} + // NewModelRef creates a new ModelRef with the given name and configuration. func NewModelRef(name string, config any) ModelRef { return ModelRef{name: name, config: config} diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent_flow.go index dfe6d4a75a..09ff248460 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent_flow.go @@ -35,75 +35,6 @@ import ( "github.com/google/uuid" ) -// Artifact represents a named collection of parts produced during a session. -// Examples: generated files, images, code snippets, diagrams, etc. -type Artifact struct { - // Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). - Name string `json:"name,omitempty"` - // Parts contains the artifact content (text, media, etc.). - Parts []*ai.Part `json:"parts"` - // Metadata contains additional artifact-specific data. - Metadata map[string]any `json:"metadata,omitempty"` -} - -// AgentFlowInput is the input sent to an agent flow during a conversation turn. -type AgentFlowInput struct { - // Messages contains the user's input for this turn. - Messages []*ai.Message `json:"messages,omitempty"` -} - -// AgentFlowInit is the input for starting an agent flow invocation. -// Provide either SnapshotID (to load from store) or State (direct state). -type AgentFlowInit[State any] struct { - // SnapshotID loads state from a persisted snapshot. - // Mutually exclusive with State. - SnapshotID string `json:"snapshotId,omitempty"` - // State provides direct state for the invocation. - // Mutually exclusive with SnapshotID. - State *SessionState[State] `json:"state,omitempty"` -} - -// AgentFlowResult is the return value from an AgentFlowFunc. -// It contains the user-specified outputs of the agent invocation. -type AgentFlowResult struct { - // Message is the last model response message from the conversation. - Message *ai.Message `json:"message,omitempty"` - // Artifacts contains artifacts produced during the session. - Artifacts []*Artifact `json:"artifacts,omitempty"` -} - -// AgentFlowOutput is the output when an agent flow invocation completes. -// It wraps AgentFlowResult with framework-managed fields. -type AgentFlowOutput[State any] struct { - // SnapshotID is the ID of the snapshot created at the end of this invocation. - // Empty if no snapshot was created (callback returned false or no store configured). - SnapshotID string `json:"snapshotId,omitempty"` - // State contains the final conversation state. - // Only populated when state is client-managed (no store configured). - State *SessionState[State] `json:"state,omitempty"` - // Message is the last model response message from the conversation. - Message *ai.Message `json:"message,omitempty"` - // Artifacts contains artifacts produced during the session. - Artifacts []*Artifact `json:"artifacts,omitempty"` -} - -// AgentFlowStreamChunk represents a single item in the agent flow's output stream. -// Multiple fields can be populated in a single chunk. -type AgentFlowStreamChunk[Stream any] struct { - // ModelChunk contains generation tokens from the model. - ModelChunk *ai.ModelResponseChunk `json:"modelChunk,omitempty"` - // Status contains user-defined structured status information. - // The Stream type parameter defines the shape of this data. - Status Stream `json:"status,omitempty"` - // Artifact contains a newly produced artifact. - Artifact *Artifact `json:"artifact,omitempty"` - // SnapshotID contains the ID of a snapshot that was just persisted. - SnapshotID string `json:"snapshotId,omitempty"` - // EndTurn signals that the agent flow has finished processing the current input. - // When true, the client should stop iterating and may send the next input. - EndTurn bool `json:"endTurn,omitempty"` -} - // --- AgentSession --- // AgentSession extends Session with agent-flow-specific functionality: @@ -424,6 +355,17 @@ func DefinePromptAgent[State, PromptIn any]( // Append conversation history after the prompt-rendered messages. genOpts.Messages = append(genOpts.Messages, sess.Messages()...) + // If tool restarts were provided, set the resume option so + // handleResumeOption re-executes the interrupted tools. + if len(input.ToolRestarts) > 0 { + for _, p := range input.ToolRestarts { + if !p.IsToolRequest() { + return core.NewError(core.INVALID_ARGUMENT, "ToolRestarts: part is not a tool request") + } + } + genOpts.Resume = ai.NewResume(input.ToolRestarts, nil) + } + // Call the model with streaming. modelResp, err := ai.GenerateWithRequest(ctx, r, genOpts, nil, func(ctx context.Context, chunk *ai.ModelResponseChunk) error { @@ -593,6 +535,12 @@ func (c *AgentFlowConnection[Stream, State]) SendText(text string) error { }) } +// SendToolRestarts sends tool restart parts to resume interrupted tool calls. +// Parts should be created via [ai.ToolDef.RestartWith]. +func (c *AgentFlowConnection[Stream, State]) SendToolRestarts(parts ...*ai.Part) error { + return c.conn.Send(&AgentFlowInput{ToolRestarts: parts}) +} + // Close signals that no more inputs will be sent. func (c *AgentFlowConnection[Stream, State]) Close() error { return c.conn.Close() diff --git a/go/ai/x/gen.go b/go/ai/x/gen.go new file mode 100644 index 0000000000..ff5ace986c --- /dev/null +++ b/go/ai/x/gen.go @@ -0,0 +1,122 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// This file was generated by jsonschemagen. DO NOT EDIT. + +package aix + +import ( + "github.com/firebase/genkit/go/ai" +) + +// AgentFlowInit is the input for starting an agent flow invocation. +// Provide either SnapshotID (to load from store) or State (direct state). +type AgentFlowInit[State any] struct { + // SnapshotID loads state from a persisted snapshot. + // Mutually exclusive with State. + SnapshotID string `json:"snapshotId,omitempty"` + // State provides direct state for the invocation. + // Mutually exclusive with SnapshotID. + State *SessionState[State] `json:"state,omitempty"` +} + +// AgentFlowInput is the input sent to an agent flow during a conversation turn. +type AgentFlowInput struct { + // Messages contains the user's input for this turn. + Messages []*ai.Message `json:"messages,omitempty"` + // ToolRestarts contains tool request parts to re-execute interrupted tools. + // Use [ai.ToolDef.RestartWith] to create these parts from an interrupted + // tool request. When set, the generate call resumes with these restarts + // instead of treating Messages as tool responses. + ToolRestarts []*ai.Part `json:"toolRestarts,omitempty"` +} + +// AgentFlowOutput is the output when an agent flow invocation completes. +// It wraps AgentFlowResult with framework-managed fields. +type AgentFlowOutput[State any] struct { + // Artifacts contains artifacts produced during the session. + Artifacts []*Artifact `json:"artifacts,omitempty"` + // Message is the last model response message from the conversation. + Message *ai.Message `json:"message,omitempty"` + // SnapshotID is the ID of the snapshot created at the end of this invocation. + // Empty if no snapshot was created (callback returned false or no store configured). + SnapshotID string `json:"snapshotId,omitempty"` + // State contains the final conversation state. + // Only populated when state is client-managed (no store configured). + State *SessionState[State] `json:"state,omitempty"` +} + +// AgentFlowResult is the return value from an AgentFlowFunc. +// It contains the user-specified outputs of the agent invocation. +type AgentFlowResult struct { + // Artifacts contains artifacts produced during the session. + Artifacts []*Artifact `json:"artifacts,omitempty"` + // Message is the last model response message from the conversation. + Message *ai.Message `json:"message,omitempty"` +} + +// AgentFlowStreamChunk represents a single item in the agent flow's output stream. +// Multiple fields can be populated in a single chunk. +type AgentFlowStreamChunk[Stream any] struct { + // Artifact contains a newly produced artifact. + Artifact *Artifact `json:"artifact,omitempty"` + // EndTurn signals that the agent flow has finished processing the current input. + // When true, the client should stop iterating and may send the next input. + EndTurn bool `json:"endTurn,omitempty"` + // ModelChunk contains generation tokens from the model. + ModelChunk *ai.ModelResponseChunk `json:"modelChunk,omitempty"` + // SnapshotID contains the ID of a snapshot that was just persisted. + SnapshotID string `json:"snapshotId,omitempty"` + // Status contains user-defined structured status information. + // The Stream type parameter defines the shape of this data. + Status Stream `json:"status,omitempty"` +} + +// Artifact represents a named collection of parts produced during a session. +// Examples: generated files, images, code snippets, diagrams, etc. +type Artifact struct { + // Metadata contains additional artifact-specific data. + Metadata map[string]any `json:"metadata,omitempty"` + // Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). + Name string `json:"name,omitempty"` + // Parts contains the artifact content (text, media, etc.). + Parts []*ai.Part `json:"parts"` +} + +// SessionState is the portable conversation state that flows between client +// and server. It contains only the data needed for conversation continuity. +type SessionState[State any] struct { + // Artifacts are named collections of parts produced during the conversation. + Artifacts []*Artifact `json:"artifacts,omitempty"` + // Custom is the user-defined state associated with this conversation. + Custom State `json:"custom,omitempty"` + // InputVariables is the input used for agent flows that require input variables + // (e.g. prompt-backed agent flows). + InputVariables any `json:"inputVariables,omitempty"` + // Messages is the conversation history (user/model exchanges). + // Does NOT include prompt-rendered messages — those are rendered fresh each turn. + Messages []*ai.Message `json:"messages,omitempty"` +} + +// SnapshotEvent identifies what triggered a snapshot. +type SnapshotEvent string + +const ( + // TurnEnd indicates the snapshot was triggered at the end of a turn. + SnapshotEventTurnEnd SnapshotEvent = "turnEnd" + // InvocationEnd indicates the snapshot was triggered at the end of the invocation. + SnapshotEventInvocationEnd SnapshotEvent = "invocationEnd" +) diff --git a/go/ai/x/session.go b/go/ai/x/session.go index 6df97a7c67..d158c90b16 100644 --- a/go/ai/x/session.go +++ b/go/ai/x/session.go @@ -26,31 +26,6 @@ import ( "github.com/firebase/genkit/go/ai" ) -// SessionState is the portable conversation state that flows between client -// and server. It contains only the data needed for conversation continuity. -type SessionState[State any] struct { - // Messages is the conversation history (user/model exchanges). - // Does NOT include prompt-rendered messages — those are rendered fresh each turn. - Messages []*ai.Message `json:"messages,omitempty"` - // Custom is the user-defined state associated with this conversation. - Custom State `json:"custom,omitempty"` - // Artifacts are named collections of parts produced during the conversation. - Artifacts []*Artifact `json:"artifacts,omitempty"` - // InputVariables is the input used for agent flows that require input variables - // (e.g. prompt-backed agent flows). - InputVariables any `json:"inputVariables,omitempty"` -} - -// SnapshotEvent identifies what triggered a snapshot. -type SnapshotEvent string - -const ( - // TurnEnd indicates the snapshot was triggered at the end of a turn. - SnapshotEventTurnEnd SnapshotEvent = "turnEnd" - // InvocationEnd indicates the snapshot was triggered at the end of the invocation. - SnapshotEventInvocationEnd SnapshotEvent = "invocationEnd" -) - // SessionSnapshot is a persisted point-in-time capture of session state. type SessionSnapshot[State any] struct { // SnapshotID is the unique identifier for this snapshot (UUID). diff --git a/go/core/schemas.config b/go/core/schemas.config index 70798f2eb3..b58c36e2a9 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1108,3 +1108,222 @@ Embedding.embedding type []float32 GenkitError omit GenkitErrorData omit GenkitErrorDataGenkitErrorDetails omit + +# ============================================================================ +# AGENT FLOW TYPES (generated into ai/x package) +# ============================================================================ + +# Package configuration: ai/x directory uses "aix" as Go package name. +ai/x name aix +aix import github.com/firebase/genkit/go/ai + +# ---------------------------------------------------------------------------- +# Artifact +# ---------------------------------------------------------------------------- + +Artifact pkg ai/x + +Artifact doc +Artifact represents a named collection of parts produced during a session. +Examples: generated files, images, code snippets, diagrams, etc. +. + +Artifact.name doc +Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). +. + +Artifact.parts type []*ai.Part +Artifact.parts noomitempty +Artifact.parts doc +Parts contains the artifact content (text, media, etc.). +. + +Artifact.metadata type map[string]any +Artifact.metadata doc +Metadata contains additional artifact-specific data. +. + +# ---------------------------------------------------------------------------- +# AgentFlowInput +# ---------------------------------------------------------------------------- + +AgentFlowInput pkg ai/x + +AgentFlowInput doc +AgentFlowInput is the input sent to an agent flow during a conversation turn. +. + +AgentFlowInput.messages type []*ai.Message +AgentFlowInput.messages doc +Messages contains the user's input for this turn. +. + +AgentFlowInput.toolRestarts type []*ai.Part +AgentFlowInput.toolRestarts doc +ToolRestarts contains tool request parts to re-execute interrupted tools. +Use [ai.ToolDef.RestartWith] to create these parts from an interrupted +tool request. When set, the generate call resumes with these restarts +instead of treating Messages as tool responses. +. + +# ---------------------------------------------------------------------------- +# AgentFlowInit +# ---------------------------------------------------------------------------- + +AgentFlowInit pkg ai/x +AgentFlowInit typeparams [State any] + +AgentFlowInit doc +AgentFlowInit is the input for starting an agent flow invocation. +Provide either SnapshotID (to load from store) or State (direct state). +. + +AgentFlowInit.snapshotId doc +SnapshotID loads state from a persisted snapshot. +Mutually exclusive with State. +. + +AgentFlowInit.state type *SessionState[State] +AgentFlowInit.state doc +State provides direct state for the invocation. +Mutually exclusive with SnapshotID. +. + +# ---------------------------------------------------------------------------- +# AgentFlowResult +# ---------------------------------------------------------------------------- + +AgentFlowResult pkg ai/x + +AgentFlowResult doc +AgentFlowResult is the return value from an AgentFlowFunc. +It contains the user-specified outputs of the agent invocation. +. + +AgentFlowResult.message type *ai.Message +AgentFlowResult.message doc +Message is the last model response message from the conversation. +. + +AgentFlowResult.artifacts doc +Artifacts contains artifacts produced during the session. +. + +# ---------------------------------------------------------------------------- +# AgentFlowOutput +# ---------------------------------------------------------------------------- + +AgentFlowOutput pkg ai/x +AgentFlowOutput typeparams [State any] + +AgentFlowOutput doc +AgentFlowOutput is the output when an agent flow invocation completes. +It wraps AgentFlowResult with framework-managed fields. +. + +AgentFlowOutput.snapshotId doc +SnapshotID is the ID of the snapshot created at the end of this invocation. +Empty if no snapshot was created (callback returned false or no store configured). +. + +AgentFlowOutput.state type *SessionState[State] +AgentFlowOutput.state doc +State contains the final conversation state. +Only populated when state is client-managed (no store configured). +. + +AgentFlowOutput.message type *ai.Message +AgentFlowOutput.message doc +Message is the last model response message from the conversation. +. + +AgentFlowOutput.artifacts doc +Artifacts contains artifacts produced during the session. +. + +# ---------------------------------------------------------------------------- +# AgentFlowStreamChunk +# ---------------------------------------------------------------------------- + +AgentFlowStreamChunk pkg ai/x +AgentFlowStreamChunk typeparams [Stream any] + +AgentFlowStreamChunk doc +AgentFlowStreamChunk represents a single item in the agent flow's output stream. +Multiple fields can be populated in a single chunk. +. + +AgentFlowStreamChunk.modelChunk type *ai.ModelResponseChunk +AgentFlowStreamChunk.modelChunk doc +ModelChunk contains generation tokens from the model. +. + +AgentFlowStreamChunk.status type Stream +AgentFlowStreamChunk.status doc +Status contains user-defined structured status information. +The Stream type parameter defines the shape of this data. +. + +AgentFlowStreamChunk.artifact doc +Artifact contains a newly produced artifact. +. + +AgentFlowStreamChunk.snapshotId doc +SnapshotID contains the ID of a snapshot that was just persisted. +. + +AgentFlowStreamChunk.endTurn doc +EndTurn signals that the agent flow has finished processing the current input. +When true, the client should stop iterating and may send the next input. +. + +# ---------------------------------------------------------------------------- +# SessionState +# ---------------------------------------------------------------------------- + +SessionState pkg ai/x +SessionState typeparams [State any] + +SessionState doc +SessionState is the portable conversation state that flows between client +and server. It contains only the data needed for conversation continuity. +. + +SessionState.messages type []*ai.Message +SessionState.messages doc +Messages is the conversation history (user/model exchanges). +Does NOT include prompt-rendered messages — those are rendered fresh each turn. +. + +SessionState.custom type State +SessionState.custom doc +Custom is the user-defined state associated with this conversation. +. + +SessionState.artifacts doc +Artifacts are named collections of parts produced during the conversation. +. + +SessionState.inputVariables doc +InputVariables is the input used for agent flows that require input variables +(e.g. prompt-backed agent flows). +. + +# ---------------------------------------------------------------------------- +# SnapshotEvent +# ---------------------------------------------------------------------------- + +SnapshotEvent pkg ai/x + +SnapshotEvent doc +SnapshotEvent identifies what triggered a snapshot. +. + +SnapshotEventTurnEnd doc +TurnEnd indicates the snapshot was triggered at the end of a turn. +. + +SnapshotEventInvocationEnd doc +InvocationEnd indicates the snapshot was triggered at the end of the invocation. +. + diff --git a/go/internal/cmd/jsonschemagen/jsonschemagen.go b/go/internal/cmd/jsonschemagen/jsonschemagen.go index 6ed2d5a8e7..c1c9a3fb5e 100644 --- a/go/internal/cmd/jsonschemagen/jsonschemagen.go +++ b/go/internal/cmd/jsonschemagen/jsonschemagen.go @@ -131,9 +131,14 @@ func run(infile, defaultPkgPath, configFile, outRoot string) error { // Generate code by package. for pkgPath, schemaMap := range schemasByPackage { + // Derive package name from path, with config override. + pkgName := path.Base(pkgPath) + if pc := cfg.configFor(pkgPath); pc != nil && pc.name != "" { + pkgName = pc.name + } // Generate code for each type in the package. gen := &generator{ - pkgName: path.Base(pkgPath), + pkgName: pkgName, schemas: schemaMap, cfg: cfg, } @@ -292,7 +297,15 @@ func (g *generator) generate() ([]byte, error) { g.pr("package %s\n\n", g.pkgName) if pc := g.cfg.configFor(g.pkgName); pc != nil { - g.pr("import %q\n", pc.pkgPath) + if len(pc.imports) > 0 { + g.pr("import (\n") + for _, imp := range pc.imports { + g.pr(" %q\n", imp) + } + g.pr(")\n\n") + } else if pc.pkgPath != "" { + g.pr("import %q\n", pc.pkgPath) + } } // Sort the names so the output is deterministic. @@ -386,7 +399,7 @@ func (g *generator) generateStruct(name string, s *Schema, tcfg *itemConfig) err if goName == "" { goName = adjustIdentifier(name) } - g.pr("type %s struct {\n", goName) + g.pr("type %s%s struct {\n", goName, tcfg.typeparams) for _, field := range sortedKeys(s.Properties) { fcfg := g.cfg.configFor(name + "." + field) if fcfg == nil { @@ -416,7 +429,7 @@ func (g *generator) generateStruct(name string, s *Schema, tcfg *itemConfig) err g.generateDoc(fs, fcfg) jsonTag := fmt.Sprintf(`json:"%s,omitempty"`, field) - if skipOmitEmpty(goName, field) { + if skipOmitEmpty(goName, field) || fcfg.noOmitEmpty { jsonTag = fmt.Sprintf(`json:"%s"`, field) } g.pr(fmt.Sprintf(" %s %s `%s`\n", adjustIdentifier(field), typeExpr, jsonTag)) @@ -603,12 +616,15 @@ func (c config) configFor(name string) *itemConfig { // itemConfig is configuration for one item: a type, a field or a package. // Not all itemConfig fields apply to both, but using one type simplifies the parser. type itemConfig struct { - omit bool - name string - pkgPath string - typeExpr string - docLines []string - fields []extraField + omit bool + name string + pkgPath string + typeExpr string + docLines []string + fields []extraField + typeparams string // Go type parameters (e.g., "[State any]") + noOmitEmpty bool // omit the omitempty tag for this field + imports []string // import paths for the package } // extraField represents an additional unexported field to add to a struct. @@ -636,7 +652,11 @@ type extraField struct { // pkg // package path, relative to outdir (last component is package name) // import -// path of package to import (for packages only) +// path of package to import (for packages only, may be repeated) +// typeparams PARAMS +// Go type parameters to add to the type declaration (e.g., "[State any]") +// noomitempty +// don't add omitempty to this field's json tag // field NAME TYPE // add an unexported field to the struct (for types only) func parseConfigFile(filename string) (config, error) { @@ -703,7 +723,14 @@ func parseConfigFile(filename string) (config, error) { if len(words) < 3 { return errf("need NAME import PATH") } - ic.pkgPath = words[2] + ic.imports = append(ic.imports, words[2]) + case "typeparams": + if len(words) < 3 { + return errf("need NAME typeparams PARAMS") + } + ic.typeparams = strings.Join(words[2:], " ") + case "noomitempty": + ic.noOmitEmpty = true case "field": if len(words) < 4 { return errf("need NAME field FIELDNAME TYPE") From d9c335a7102dc3a5e0ae5f4ec7a134a2f41b9484 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 11:38:45 -0800 Subject: [PATCH 030/141] Update typing.py --- py/packages/genkit/src/genkit/core/typing.py | 71 ++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index 05d623c07d..37464bce79 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -39,6 +39,13 @@ class Model(RootModel[Any]): root: Any +class SnapshotEvent(StrEnum): + """SnapshotEvent data type class.""" + + TURN_END = 'turnEnd' + INVOCATION_END = 'invocationEnd' + + class Embedding(BaseModel): """Model for embedding data.""" @@ -880,6 +887,15 @@ class Content(RootModel[list[Part]]): root: list[Part] +class Artifact(BaseModel): + """Model for artifact data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + name: str | None = None + parts: list[Part] + metadata: dict[str, Any] | None = None + + class DocumentData(BaseModel): """Model for documentdata data.""" @@ -971,6 +987,43 @@ class Messages(RootModel[list[Message]]): root: list[Message] +class AgentFlowInput(BaseModel): + """Model for agentflowinput data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + messages: list[Message] | None = None + tool_restarts: list[Part] | None = Field(default=None) + + +class AgentFlowResult(BaseModel): + """Model for agentflowresult data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + message: Message | None = None + artifacts: list[Artifact] | None = None + + +class AgentFlowStreamChunk(BaseModel): + """Model for agentflowstreamchunk data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + model_chunk: ModelResponseChunk | None = Field(default=None) + status: Any | None = None + artifact: Artifact | None = None + snapshot_id: str | None = Field(default=None) + end_turn: bool | None = Field(default=None) + + +class SessionState(BaseModel): + """Model for sessionstate data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + messages: list[Message] | None = None + custom: Any | None = None + artifacts: list[Artifact] | None = None + input_variables: Any | None = Field(default=None) + + class Candidate(BaseModel): """Model for candidate data.""" @@ -1048,6 +1101,24 @@ class Request(RootModel[GenerateRequest]): root: GenerateRequest +class AgentFlowInit(BaseModel): + """Model for agentflowinit data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + snapshot_id: str | None = Field(default=None) + state: SessionState | None = None + + +class AgentFlowOutput(BaseModel): + """Model for agentflowoutput data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + snapshot_id: str | None = Field(default=None) + state: SessionState | None = None + message: Message | None = None + artifacts: list[Artifact] | None = None + + class ModelResponse(BaseModel): """Model for modelresponse data.""" From fb7f254cc789c30e7a1602e6e22a4269a975da42 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 11:54:13 -0800 Subject: [PATCH 031/141] various renames --- go/ai/x/{agent_flow.go => agent.go} | 2 +- go/ai/x/{agent_flow_test.go => agent_test.go} | 4 +- go/ai/x/session.go | 43 +++++++------------ 3 files changed, 20 insertions(+), 29 deletions(-) rename go/ai/x/{agent_flow.go => agent.go} (99%) rename go/ai/x/{agent_flow_test.go => agent_test.go} (99%) diff --git a/go/ai/x/agent_flow.go b/go/ai/x/agent.go similarity index 99% rename from go/ai/x/agent_flow.go rename to go/ai/x/agent.go index 09ff248460..7c7548a73f 100644 --- a/go/ai/x/agent_flow.go +++ b/go/ai/x/agent.go @@ -390,7 +390,7 @@ func DefinePromptAgent[State, PromptIn any]( } msgs = append(msgs, m) } - sess.SetMessages(msgs) + sess.UpdateMessages(func(_ []*ai.Message) []*ai.Message { return msgs }) } else if modelResp.Message != nil { sess.AddMessages(modelResp.Message) } diff --git a/go/ai/x/agent_flow_test.go b/go/ai/x/agent_test.go similarity index 99% rename from go/ai/x/agent_flow_test.go rename to go/ai/x/agent_test.go index 5762713254..0a108dc9c2 100644 --- a/go/ai/x/agent_flow_test.go +++ b/go/ai/x/agent_test.go @@ -593,7 +593,9 @@ func TestAgentFlow_SetMessages(t *testing.T) { func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { // Replace all messages with just one. - sess.SetMessages([]*ai.Message{ai.NewModelTextMessage("replaced")}) + sess.UpdateMessages(func(_ []*ai.Message) []*ai.Message { + return []*ai.Message{ai.NewModelTextMessage("replaced")} + }) return nil }) }, diff --git a/go/ai/x/session.go b/go/ai/x/session.go index d158c90b16..a4418532e8 100644 --- a/go/ai/x/session.go +++ b/go/ai/x/session.go @@ -24,8 +24,11 @@ import ( "time" "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/internal/base" ) +// --- Snapshot --- + // SessionSnapshot is a persisted point-in-time capture of session state. type SessionSnapshot[State any] struct { // SnapshotID is the unique identifier for this snapshot (UUID). @@ -56,6 +59,8 @@ type SnapshotContext[State any] struct { // If not provided and a store is configured, snapshots are always created. type SnapshotCallback[State any] = func(ctx context.Context, sc *SnapshotContext[State]) bool +// --- Session store --- + // SessionStore persists and retrieves snapshots. type SessionStore[State any] interface { // GetSnapshot retrieves a snapshot by ID. Returns nil if not found. @@ -157,11 +162,12 @@ func (s *Session[State]) AddMessages(messages ...*ai.Message) { s.state.Messages = append(s.state.Messages, messages...) } -// SetMessages replaces the entire conversation history. -func (s *Session[State]) SetMessages(messages []*ai.Message) { +// UpdateMessages atomically reads the current messages, applies the given +// function, and writes the result back. +func (s *Session[State]) UpdateMessages(fn func([]*ai.Message) []*ai.Message) { s.mu.Lock() defer s.mu.Unlock() - s.state.Messages = messages + s.state.Messages = fn(s.state.Messages) } // Custom returns the current user-defined custom state. @@ -171,13 +177,6 @@ func (s *Session[State]) Custom() State { return s.state.Custom } -// SetCustom updates the user-defined custom state. -func (s *Session[State]) SetCustom(custom State) { - s.mu.Lock() - defer s.mu.Unlock() - s.state.Custom = custom -} - // UpdateCustom atomically reads the current custom state, applies the given // function, and writes the result back. func (s *Session[State]) UpdateCustom(fn func(State) State) { @@ -224,11 +223,12 @@ func (s *Session[State]) AddArtifacts(artifacts ...*Artifact) { } } -// SetArtifacts replaces the entire artifact list. -func (s *Session[State]) SetArtifacts(artifacts []*Artifact) { +// UpdateArtifacts atomically reads the current artifacts, applies the given +// function, and writes the result back. +func (s *Session[State]) UpdateArtifacts(fn func([]*Artifact) []*Artifact) { s.mu.Lock() defer s.mu.Unlock() - s.state.Artifacts = artifacts + s.state.Artifacts = fn(s.state.Artifacts) } // copyStateLocked returns a deep copy of the state. Caller must hold mu (read or write). @@ -246,27 +246,16 @@ func (s *Session[State]) copyStateLocked() SessionState[State] { // --- Session context --- -type sessionContextKey struct{} - -type sessionHolder struct { - session any -} +var sessionCtxKey = base.NewContextKey[any]() // NewSessionContext returns a new context with the session attached. func NewSessionContext[State any](ctx context.Context, s *Session[State]) context.Context { - return context.WithValue(ctx, sessionContextKey{}, &sessionHolder{session: s}) + return sessionCtxKey.NewContext(ctx, s) } // SessionFromContext retrieves the current session from context. // Returns nil if no session is in context or if the type doesn't match. func SessionFromContext[State any](ctx context.Context) *Session[State] { - holder, ok := ctx.Value(sessionContextKey{}).(*sessionHolder) - if !ok || holder == nil { - return nil - } - session, ok := holder.session.(*Session[State]) - if !ok { - return nil - } + session, _ := sessionCtxKey.FromContext(ctx).(*Session[State]) return session } From 971b6688ecbe7451a56fc50f77d152af4c05ba87 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 18 Feb 2026 16:06:21 -0800 Subject: [PATCH 032/141] helper for input variables conversion --- go/ai/tools.go | 17 ++--------------- go/ai/x/agent.go | 5 +++-- go/internal/base/json.go | 18 ++++++++++++++++++ 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/go/ai/tools.go b/go/ai/tools.go index 453ad779a3..6099d1eff5 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -248,8 +248,7 @@ func ResumedValue[T any](tc *ToolContext, key string) (T, bool) { if !ok { return zero, false } - typed, ok := v.(T) - return typed, ok + return base.ConvertTo[T](v) } // OriginalInputAs returns the original input typed appropriately. @@ -259,19 +258,7 @@ func OriginalInputAs[T any](tc *ToolContext) (T, bool) { if tc.OriginalInput == nil { return zero, false } - // Try direct type assertion first (for when input is already typed) - if typed, ok := tc.OriginalInput.(T); ok { - return typed, ok - } - // Otherwise try to convert from map[string]any (common case from JSON) - if m, ok := tc.OriginalInput.(map[string]any); ok { - result, err := base.MapToStruct[T](m) - if err != nil { - return zero, false - } - return result, true - } - return zero, false + return base.ConvertTo[T](tc.OriginalInput) } // DefineTool creates a new [ToolDef] and registers it. diff --git a/go/ai/x/agent.go b/go/ai/x/agent.go index 7c7548a73f..10a43b2f77 100644 --- a/go/ai/x/agent.go +++ b/go/ai/x/agent.go @@ -32,6 +32,7 @@ import ( "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" + "github.com/firebase/genkit/go/internal/base" "github.com/google/uuid" ) @@ -330,9 +331,9 @@ func DefinePromptAgent[State, PromptIn any]( // Resolve prompt input: session state override > default. promptInput := defaultInput if stored := sess.InputVariables(); stored != nil { - typed, ok := stored.(PromptIn) + typed, ok := base.ConvertTo[PromptIn](stored) if !ok { - return core.NewError(core.INVALID_ARGUMENT, "prompt input type mismatch: got %T, want %T", stored, promptInput) + return core.NewError(core.INVALID_ARGUMENT, "input variables type mismatch: got %T, want %T", stored, promptInput) } promptInput = typed } diff --git a/go/internal/base/json.go b/go/internal/base/json.go index dff45c260a..a52ba3491b 100644 --- a/go/internal/base/json.go +++ b/go/internal/base/json.go @@ -118,6 +118,24 @@ func InferJSONSchema(x any) (s *jsonschema.Schema) { return s } +// ConvertTo attempts to convert a value to type T. It tries a direct type +// assertion first, then falls back to a JSON round-trip for values that were +// deserialized from JSON (e.g., map[string]any instead of a concrete struct). +func ConvertTo[T any](v any) (T, bool) { + if typed, ok := v.(T); ok { + return typed, true + } + var result T + data, err := json.Marshal(v) + if err != nil { + return result, false + } + if err := json.Unmarshal(data, &result); err != nil { + return result, false + } + return result, true +} + // MapToStruct converts a map[string]any to a struct of type T via JSON round-trip. func MapToStruct[T any](m map[string]any) (T, error) { var result T From 3491761cfa736e6184f7fcda1a0b99922a4e5c44 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 19 Feb 2026 07:26:36 -0800 Subject: [PATCH 033/141] DefinePromptAgent takes in prompt name instead of resolved prompt --- go/ai/x/agent.go | 22 ++--- go/ai/x/agent_test.go | 26 ++--- go/genkit/genkit.go | 162 ++++++++++++++++++++++++-------- go/samples/prompt-agent/main.go | 4 +- 4 files changed, 149 insertions(+), 65 deletions(-) diff --git a/go/ai/x/agent.go b/go/ai/x/agent.go index 10a43b2f77..5775b48f00 100644 --- a/go/ai/x/agent.go +++ b/go/ai/x/agent.go @@ -300,13 +300,6 @@ func DefineCustomAgent[Stream, State any]( return &AgentFlow[Stream, State]{flow: flow} } -// PromptRenderer renders a prompt with typed input into GenerateActionOptions. -// This interface is satisfied by both ai.Prompt (with In=any) and -// *ai.DataPrompt[In, Out]. -type PromptRenderer[In any] interface { - Render(ctx context.Context, input In) (*ai.GenerateActionOptions, error) -} - // promptMessageKey is the metadata key used to tag prompt-rendered messages // so they can be excluded from session history after generation. const promptMessageKey = "_genkit_prompt" @@ -316,15 +309,20 @@ const promptMessageKey = "_genkit_prompt" // conversation history, calls GenerateWithRequest, streams chunks to the // client, and adds the model response to the session. // -// The defaultInput is used for prompt rendering unless overridden per -// invocation via WithPromptInput. +// The prompt is looked up by name from the registry using +// [ai.LookupDataPrompt]. The defaultInput is used for prompt rendering +// unless overridden per invocation via WithInputVariables. func DefinePromptAgent[State, PromptIn any]( r api.Registry, - name string, - p PromptRenderer[PromptIn], + promptName string, defaultInput PromptIn, opts ...AgentFlowOption[State], ) *AgentFlow[any, State] { + p := ai.LookupDataPrompt[PromptIn, string](r, promptName) + if p == nil { + panic(fmt.Sprintf("DefinePromptAgent: prompt %q not found", promptName)) + } + fn := func(ctx context.Context, resp Responder[any], sess *AgentSession[State]) (*AgentFlowResult, error) { var lastModelMessage *ai.Message err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { @@ -416,7 +414,7 @@ func DefinePromptAgent[State, PromptIn any]( }, nil } - return DefineCustomAgent(r, name, fn, opts...) + return DefineCustomAgent(r, promptName, fn, opts...) } // StreamBidi starts a new agent flow invocation. diff --git a/go/ai/x/agent_test.go b/go/ai/x/agent_test.go index 0a108dc9c2..5edd805f4f 100644 --- a/go/ai/x/agent_test.go +++ b/go/ai/x/agent_test.go @@ -913,13 +913,13 @@ func TestPromptAgent_Basic(t *testing.T) { ctx := context.Background() reg := setupPromptTestRegistry(t) - prompt := ai.DefinePrompt(reg, "testPrompt", + ai.DefinePrompt(reg, "testPrompt", ai.WithModelName("test/echo"), ai.WithSystem("You are a test assistant."), ) - af := DefinePromptAgent[testState]( - reg, "promptFlow", prompt, nil, + af := DefinePromptAgent[testState, any]( + reg, "testPrompt", nil, ) conn, err := af.StreamBidi(ctx) @@ -985,13 +985,13 @@ func TestPromptAgent_PromptInputOverride(t *testing.T) { Name string `json:"name"` } - prompt := ai.DefineDataPrompt[greetInput, string](reg, "greetPrompt", + ai.DefineDataPrompt[greetInput, string](reg, "greetPrompt", ai.WithModelName("test/echo"), ai.WithPrompt("Hello {{name}}!"), ) af := DefinePromptAgent[testState]( - reg, "promptInputFlow", prompt, greetInput{Name: "default"}, + reg, "greetPrompt", greetInput{Name: "default"}, ) // Use WithPromptInput to override. @@ -1063,13 +1063,13 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { }, ) - prompt := ai.DefinePrompt(reg, "historyPrompt", + ai.DefinePrompt(reg, "historyPrompt", ai.WithModelName("test/history"), ai.WithSystem("system prompt"), ) - af := DefinePromptAgent[testState]( - reg, "historyFlow", prompt, nil, + af := DefinePromptAgent[testState, any]( + reg, "historyPrompt", nil, ) conn, err := af.StreamBidi(ctx) @@ -1139,13 +1139,13 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { reg := setupPromptTestRegistry(t) store := NewInMemorySessionStore[testState]() - prompt := ai.DefinePrompt(reg, "snapPrompt", + ai.DefinePrompt(reg, "snapPrompt", ai.WithModelName("test/echo"), ai.WithSystem("You are a test assistant."), ) - af := DefinePromptAgent( - reg, "snapPromptFlow", prompt, nil, + af := DefinePromptAgent[testState, any]( + reg, "snapPrompt", nil, WithSessionStore(store), ) @@ -1270,13 +1270,13 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { ) ai.DefineGenerateAction(ctx, reg) - prompt := ai.DefinePrompt(reg, "toolPrompt", + ai.DefinePrompt(reg, "toolPrompt", ai.WithModelName("test/toolmodel"), ai.WithSystem("You are a test assistant."), ai.WithTools(ai.ToolName("greet")), ) - af := DefinePromptAgent[testState](reg, "toolFlow", prompt, nil) + af := DefinePromptAgent[testState, any](reg, "toolPrompt", nil) conn, err := af.StreamBidi(ctx) if err != nil { diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index c2ad893eef..0455eac406 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -187,7 +187,7 @@ func WithPromptFS(fsys fs.FS) GenkitOption { // // Assumes a prompt file at ./prompts/jokePrompt.prompt // g := genkit.Init(ctx, // genkit.WithPlugins(&googlegenai.GoogleAI{}), -// genkit.WithDefaultModel("googleai/gemini-2.5-flash"), +// genkit.WithDefaultModel("googleai/gemini-3-flash-preview"), // genkit.WithPromptDir("./prompts"), // ) // @@ -408,36 +408,79 @@ func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.B return core.DefineBidiFlow(g.reg, name, fn) } -// DefineCustomAgent creates an AgentFlow with automatic snapshot management -// and registers it as a flow action. +// DefineCustomAgent defines a custom agent flow with full control over the +// conversation loop, registers it as a [core.Action] of type Flow, and +// returns an [aix.AgentFlow]. // -// An AgentFlow is a stateful, multi-turn conversational flow with automatic -// snapshot persistence and turn semantics. It builds on bidirectional streaming -// to enable ongoing conversations with managed state. +// An AgentFlow is a stateful, multi-turn conversational flow. It builds on +// bidirectional streaming to enable ongoing conversations where each turn's +// input and output are streamed between client and server. The framework +// handles session state, conversation history, and optional snapshot +// persistence automatically. +// +// The provided function fn receives a [aix.Responder] for streaming output +// to the client and an [aix.AgentSession] for accessing conversation state. +// Call [aix.AgentSession.Run] to enter the turn loop, which blocks until the +// client sends the next message. +// +// For prompt-backed agents that follow a standard render-generate-stream loop, +// use [DefinePromptAgent] instead. +// +// # Options +// +// - [aix.WithSessionStore]: Enable snapshot persistence with a [aix.SessionStore] +// - [aix.WithSnapshotCallback]: Control when snapshots are created +// - [aix.WithSnapshotOn]: Create snapshots only for specific [aix.SnapshotEvent] types // // Type parameters: -// - Stream: Type for status updates sent via the responder -// - State: Type for user-defined state in snapshots +// - Stream: Type for custom status updates sent via [aix.Responder.SendStatus] +// - State: Type for user-defined state persisted in snapshots // // Example: // -// type ChatState struct { -// TopicHistory []string `json:"topicHistory,omitempty"` -// } +// chatAgent := genkit.DefineCustomAgent(g, "chat", +// func(ctx context.Context, resp aix.Responder[any], sess *aix.AgentSession[any]) (*aix.AgentFlowResult, error) { +// var lastMessage *ai.Message +// err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentFlowInput) error { +// sess.AddMessages(input.Messages...) +// for result, err := range genkit.GenerateStream(ctx, g, +// ai.WithModelName("googleai/gemini-3-flash-preview"), +// ai.WithMessages(sess.Messages()...), +// ) { +// if err != nil { +// return err +// } +// if result.Done { +// lastMessage = result.Response.Message +// sess.AddMessages(lastMessage) +// } else { +// resp.SendModelChunk(result.Chunk) +// } +// } +// return nil +// }) +// if err != nil { +// return nil, err +// } +// return &aix.AgentFlowResult{Message: lastMessage}, nil +// }, +// ) // -// type ChatStatus struct { -// Phase string `json:"phase"` +// // Start a conversation: +// conn, err := chatAgent.StreamBidi(ctx) +// if err != nil { +// // handle error // } // -// chatFlow := genkit.DefineCustomAgent(g, "chatFlow", -// func(ctx context.Context, resp aix.Responder[ChatStatus], sess *aix.AgentSession[ChatState]) error { -// return sess.Run(ctx, func(ctx context.Context, input *aix.AgentFlowInput) error { -// // ... handle each turn ... -// return nil -// }) -// }, -// aix.WithSessionStore(store), -// ) +// // Send a message and stream the response: +// conn.SendText("Hello!") +// for chunk, err := range conn.Receive() { +// if chunk.EndTurn { +// break +// } +// fmt.Print(chunk.ModelChunk.Text()) +// } +// conn.Close() func DefineCustomAgent[Stream, State any]( g *Genkit, name string, @@ -447,24 +490,69 @@ func DefineCustomAgent[Stream, State any]( return aix.DefineCustomAgent(g.reg, name, fn, opts...) } -// DefinePromptAgent creates a prompt-backed AgentFlow with an -// automatic conversation loop. Each turn renders the prompt, appends -// conversation history, calls the model with streaming, and updates session state. +// DefinePromptAgent defines a prompt-backed agent flow, registers it as a +// [core.Action] of type Flow, and returns an [aix.AgentFlow]. +// +// This is a higher-level alternative to [DefineCustomAgent] for agents backed +// by a prompt (defined via [DefinePrompt] or loaded from a .prompt file). The +// conversation loop is handled automatically: each turn renders the prompt, +// appends conversation history, calls the model with streaming, and updates +// session state. // -// The defaultInput is used for prompt rendering unless overridden per -// invocation via [aix.WithInputVariables]. +// The prompt is looked up by promptName from the registry. The defaultInput +// provides template variables for prompt rendering (e.g., personality, tone) +// and can be overridden per invocation via [aix.WithInputVariables]. +// +// DefinePromptAgent accepts the same options as [DefineCustomAgent]. See +// [DefineCustomAgent] for available options. // // Type parameters: -// - State: Type for user-defined state in snapshots -// - PromptIn: The prompt input type (inferred from the PromptRenderer) +// - State: Type for user-defined state persisted in snapshots +// - PromptIn: The prompt input type (inferred from defaultInput) +// +// Example: +// +// // Given a .prompt file "chat.prompt" loaded via WithPromptDir: +// // --- +// // model: googleai/gemini-3-flash-preview +// // input: +// // schema: +// // personality: string +// // --- +// // {{role "system"}} +// // You are {{personality}}. +// +// type ChatInput struct { +// Personality string `json:"personality"` +// } +// +// chatAgent := genkit.DefinePromptAgent(g, "chat", +// ChatInput{Personality: "a helpful assistant"}, +// aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), +// ) +// +// // Start a conversation: +// conn, err := chatAgent.StreamBidi(ctx) +// if err != nil { +// // handle error +// } +// +// // Send a message and stream the response: +// conn.SendText("Hello!") +// for chunk, err := range conn.Receive() { +// if chunk.EndTurn { +// break +// } +// fmt.Print(chunk.ModelChunk.Text()) +// } +// conn.Close() func DefinePromptAgent[State, PromptIn any]( g *Genkit, - name string, - p aix.PromptRenderer[PromptIn], + promptName string, defaultInput PromptIn, opts ...aix.AgentFlowOption[State], ) *aix.AgentFlow[any, State] { - return aix.DefinePromptAgent(g.reg, name, p, defaultInput, opts...) + return aix.DefinePromptAgent(g.reg, promptName, defaultInput, opts...) } // Run executes the given function `fn` within the context of the current flow run, @@ -847,7 +935,7 @@ func LookupTool(g *Genkit, name string) ai.Tool { // // Define the prompt // capitalPrompt := genkit.DefinePrompt(g, "findCapital", // ai.WithDescription("Finds the capital of a country."), -// ai.WithModelName("googleai/gemini-2.5-flash"), +// ai.WithModelName("googleai/gemini-3-flash-preview"), // ai.WithSystem("You are a helpful geography assistant."), // ai.WithPrompt("What is the capital of {{country}}?"), // ai.WithInputType(GeoInput{Country: "USA"}), @@ -956,7 +1044,7 @@ func DefineSchemaFor[T any](g *Genkit) { // } // // capitalPrompt := genkit.DefineDataPrompt[GeoInput, GeoOutput](g, "findCapital", -// ai.WithModelName("googleai/gemini-2.5-flash"), +// ai.WithModelName("googleai/gemini-3-flash-preview"), // ai.WithSystem("You are a helpful geography assistant."), // ai.WithPrompt("What is the capital of {{country}}?"), // ) @@ -1013,7 +1101,7 @@ func GenerateWithRequest(ctx context.Context, g *Genkit, actionOpts *ai.Generate // // Model and Configuration: // - [ai.WithModel]: Specify the model (accepts [ai.Model] or [ai.ModelRef]) -// - [ai.WithModelName]: Specify model by name string (e.g., "googleai/gemini-2.5-flash") +// - [ai.WithModelName]: Specify model by name string (e.g., "googleai/gemini-3-flash-preview") // - [ai.WithConfig]: Set generation parameters (temperature, max tokens, etc.) // // Prompting: @@ -1051,7 +1139,7 @@ func GenerateWithRequest(ctx context.Context, g *Genkit, actionOpts *ai.Generate // Example: // // resp, err := genkit.Generate(ctx, g, -// ai.WithModelName("googleai/gemini-2.5-flash"), +// ai.WithModelName("googleai/gemini-3-flash-preview"), // ai.WithPrompt("Write a short poem about clouds."), // ) // if err != nil { @@ -1538,7 +1626,7 @@ func LoadPrompt(g *Genkit, path, namespace string) ai.Prompt { // Example: // // promptSource := `--- -// model: googleai/gemini-2.5-flash +// model: googleai/gemini-3-flash-preview // input: // schema: // name: string diff --git a/go/samples/prompt-agent/main.go b/go/samples/prompt-agent/main.go index dd45cc2f16..dd88f3ea30 100644 --- a/go/samples/prompt-agent/main.go +++ b/go/samples/prompt-agent/main.go @@ -39,10 +39,8 @@ func main() { ctx := context.Background() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - chatPrompt := genkit.LookupDataPrompt[ChatPromptInput, string](g, "chat") - chatFlow := genkit.DefinePromptAgent( - g, "chat", chatPrompt, ChatPromptInput{Personality: "a sarcastic pirate"}, + g, "chat", ChatPromptInput{Personality: "a sarcastic pirate"}, aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[any]) bool { return sc.Event == aix.SnapshotEventInvocationEnd || sc.TurnIndex%5 == 0 From b46acce20892e3ad707553dcb9ae3883f809843d Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 20 Feb 2026 17:39:22 -0800 Subject: [PATCH 034/141] fixed interrupts streaming --- go/ai/generate.go | 11 +++++++++++ go/ai/x/agent.go | 16 +++++++++------- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index ee91edbb5b..c641012609 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -1017,6 +1017,17 @@ func (c *ModelResponseChunk) Reasoning() string { return sb.String() } +// Interrupts returns the interrupted tool request parts from the chunk. +func (c *ModelResponseChunk) Interrupts() []*Part { + var parts []*Part + for _, p := range c.Content { + if p.IsInterrupt() { + parts = append(parts, p) + } + } + return parts +} + // Output parses the chunk using the format handler and unmarshals the result into v. // Returns an error if the format handler is not set or does not support parsing chunks. func (c *ModelResponseChunk) Output(v any) error { diff --git a/go/ai/x/agent.go b/go/ai/x/agent.go index 5775b48f00..69a6e6c20a 100644 --- a/go/ai/x/agent.go +++ b/go/ai/x/agent.go @@ -394,13 +394,15 @@ func DefinePromptAgent[State, PromptIn any]( sess.AddMessages(modelResp.Message) } - // If generation was interrupted, stream the interrupted message - // so the client can see the tool request parts with interrupt metadata. - if modelResp.FinishReason == ai.FinishReasonInterrupted && modelResp.Message != nil { - resp.SendModelChunk(&ai.ModelResponseChunk{ - Content: modelResp.Message.Content, - Role: modelResp.Message.Role, - }) + // Stream interrupt parts so the client can detect and + // handle them (e.g. prompt the user for confirmation). + if modelResp.FinishReason == ai.FinishReasonInterrupted { + if parts := modelResp.Interrupts(); len(parts) > 0 { + resp.SendModelChunk(&ai.ModelResponseChunk{ + Role: ai.RoleTool, + Content: parts, + }) + } } return nil From 6676882eed11d7ce3d6c84cf843d7d29828ec1da Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 20 Feb 2026 17:55:45 -0800 Subject: [PATCH 035/141] Update genkit.go --- go/genkit/genkit.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 0455eac406..84b581826f 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -412,6 +412,9 @@ func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.B // conversation loop, registers it as a [core.Action] of type Flow, and // returns an [aix.AgentFlow]. // +// Experimental: This API is under active development and may change in any +// minor version release. +// // An AgentFlow is a stateful, multi-turn conversational flow. It builds on // bidirectional streaming to enable ongoing conversations where each turn's // input and output are streamed between client and server. The framework @@ -493,6 +496,9 @@ func DefineCustomAgent[Stream, State any]( // DefinePromptAgent defines a prompt-backed agent flow, registers it as a // [core.Action] of type Flow, and returns an [aix.AgentFlow]. // +// Experimental: This API is under active development and may change in any +// minor version release. +// // This is a higher-level alternative to [DefineCustomAgent] for agents backed // by a prompt (defined via [DefinePrompt] or loaded from a .prompt file). The // conversation loop is handled automatically: each turn renders the prompt, From 8d99639f106381b190c49038f60a0d54662fcd3d Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 20 Feb 2026 18:59:03 -0800 Subject: [PATCH 036/141] added `SetMessages` --- go/ai/x/agent.go | 2 +- go/ai/x/agent_test.go | 4 +--- go/ai/x/session.go | 7 +++++++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/go/ai/x/agent.go b/go/ai/x/agent.go index 69a6e6c20a..1ce1f2ddc1 100644 --- a/go/ai/x/agent.go +++ b/go/ai/x/agent.go @@ -389,7 +389,7 @@ func DefinePromptAgent[State, PromptIn any]( } msgs = append(msgs, m) } - sess.UpdateMessages(func(_ []*ai.Message) []*ai.Message { return msgs }) + sess.SetMessages(msgs) } else if modelResp.Message != nil { sess.AddMessages(modelResp.Message) } diff --git a/go/ai/x/agent_test.go b/go/ai/x/agent_test.go index 5edd805f4f..ba6f3a2edf 100644 --- a/go/ai/x/agent_test.go +++ b/go/ai/x/agent_test.go @@ -593,9 +593,7 @@ func TestAgentFlow_SetMessages(t *testing.T) { func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { // Replace all messages with just one. - sess.UpdateMessages(func(_ []*ai.Message) []*ai.Message { - return []*ai.Message{ai.NewModelTextMessage("replaced")} - }) + sess.SetMessages([]*ai.Message{ai.NewModelTextMessage("replaced")}) return nil }) }, diff --git a/go/ai/x/session.go b/go/ai/x/session.go index a4418532e8..a7c9727ff8 100644 --- a/go/ai/x/session.go +++ b/go/ai/x/session.go @@ -162,6 +162,13 @@ func (s *Session[State]) AddMessages(messages ...*ai.Message) { s.state.Messages = append(s.state.Messages, messages...) } +// SetMessages replaces the conversation history with the given messages. +func (s *Session[State]) SetMessages(messages []*ai.Message) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Messages = messages +} + // UpdateMessages atomically reads the current messages, applies the given // function, and writes the result back. func (s *Session[State]) UpdateMessages(fn func([]*ai.Message) []*ai.Message) { From 3cbec28b30f463716dcdfaa741d42802c05f5c9f Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 24 Feb 2026 18:20:18 -0800 Subject: [PATCH 037/141] added `AgentSession.Result()` --- go/ai/x/agent.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/go/ai/x/agent.go b/go/ai/x/agent.go index 1ce1f2ddc1..70c55a8e71 100644 --- a/go/ai/x/agent.go +++ b/go/ai/x/agent.go @@ -96,6 +96,26 @@ func (a *AgentSession[State]) Run(ctx context.Context, fn func(ctx context.Conte return nil } +// Result returns an [AgentFlowResult] populated from the current session state: +// the last message in the conversation history and all artifacts. +// It is a convenience for custom agent flows that don't need to construct the +// result manually. +func (a *AgentSession[State]) Result() *AgentFlowResult { + a.mu.RLock() + defer a.mu.RUnlock() + + result := &AgentFlowResult{} + if msgs := a.state.Messages; len(msgs) > 0 { + result.Message = msgs[len(msgs)-1] + } + if len(a.state.Artifacts) > 0 { + arts := make([]*Artifact, len(a.state.Artifacts)) + copy(arts, a.state.Artifacts) + result.Artifacts = arts + } + return result +} + // maybeSnapshot creates a snapshot if conditions are met (store configured, // callback approves). Returns the snapshot ID or empty string. func (a *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { From be5564c468b0ae1d68f82a798ff262661574a1d0 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 25 Feb 2026 07:46:38 -0800 Subject: [PATCH 038/141] fixed types --- go/core/action.go | 4 ++-- go/core/flow.go | 20 ++++++++++---------- go/genkit/x/genkit.go | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index b43a60dbc3..0bd867b4f1 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -257,8 +257,8 @@ func (a *Action[In, Out, Stream, Init]) Run(ctx context.Context, input In, cb St return r.Result, nil } -// Run executes the Action's function in a new trace span. -func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream]) (output api.ActionRunResult[Out], err error) { +// runWithTelemetry executes the Action's function in a new trace span and returns telemetry info. +func (a *Action[In, Out, Stream, Init]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream]) (output api.ActionRunResult[Out], err error) { logger.FromContext(ctx).Debug("Action.Run", "name", a.Name()) defer func() { logger.FromContext(ctx).Debug("Action.Run", diff --git a/go/core/flow.go b/go/core/flow.go index 6f0351a497..3ad8bba60d 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -48,8 +48,8 @@ type flowContext struct { } // NewFlow creates a Flow that runs fn without registering it. fn takes an input of type In and returns an output of type Out. -func NewFlow[In, Out any](name string, fn Func[In, Out]) *Flow[In, Out, struct{}] { - return (*Flow[In, Out, struct{}])(NewAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In) (Out, error) { +func NewFlow[In, Out any](name string, fn Func[In, Out]) *Flow[In, Out, struct{}, struct{}] { + return &Flow[In, Out, struct{}, struct{}]{NewAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In) (Out, error) { fc := &flowContext{ flowName: name, } @@ -59,8 +59,8 @@ func NewFlow[In, Out any](name string, fn Func[In, Out]) *Flow[In, Out, struct{} } // NewStreamingFlow creates a streaming Flow that runs fn without registering it. -func NewStreamingFlow[In, Out, Stream any](name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream] { - return (*Flow[In, Out, Stream])(NewStreamingAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) { +func NewStreamingFlow[In, Out, Stream any](name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream, struct{}] { + return &Flow[In, Out, Stream, struct{}]{NewStreamingAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) { fc := &flowContext{ flowName: name, } @@ -69,7 +69,7 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn StreamingFunc[In, Out cb = func(context.Context, Stream) error { return nil } } return fn(ctx, input, cb) - })) + })} } // NewBidiFlow creates a bidirectional streaming Flow without registering it. @@ -83,7 +83,7 @@ func NewBidiFlow[In, Out, Stream, Init any](name string, fn BidiFunc[In, Out, St } // DefineFlow creates a Flow that runs fn, and registers it as an action. fn takes an input of type In and returns an output of type Out. -func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flow[In, Out, struct{}] { +func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flow[In, Out, struct{}, struct{}] { f := NewFlow(name, fn) f.Register(r) return f @@ -99,9 +99,9 @@ func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flo // with a final return value that includes all the streamed data. // Otherwise, it should ignore the callback and just return a result. func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream, struct{}] { - f := NewStreamingFlow(name, fn) - f.Register(r) - return f + f := NewStreamingFlow(name, fn) + f.Register(r) + return f } // DefineBidiFlow creates a bidirectional streaming Flow that runs fn, and registers it as an action. @@ -167,7 +167,7 @@ func (f *Flow[In, Out, Stream, Init]) Stream(ctx context.Context, input In) func } return nil } - output, err := (*ActionDef[In, Out, Stream])(f).Run(ctx, input, cb) + output, err := f.Action.Run(ctx, input, cb) if errors.Is(err, errStop) { // Consumer broke out of the loop; don't yield again. return diff --git a/go/genkit/x/genkit.go b/go/genkit/x/genkit.go index 9480f75149..8574d338ea 100644 --- a/go/genkit/x/genkit.go +++ b/go/genkit/x/genkit.go @@ -82,7 +82,7 @@ type StreamingFunc[In, Out, Stream any] = func(ctx context.Context, input In, st // fmt.Println(val.Stream) // 5, 4, 3, 2, 1 // } // } -func DefineStreamingFlow[In, Out, Stream any](g *genkit.Genkit, name string, fn StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream] { +func DefineStreamingFlow[In, Out, Stream any](g *genkit.Genkit, name string, fn StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream, struct{}] { // Wrap the channel-based function to work with the callback-based API wrappedFn := func(ctx context.Context, input In, sendChunk core.StreamCallback[Stream]) (Out, error) { if sendChunk == nil { From f943c96068b4afd1b7583403df8e762272cc1e0e Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 25 Feb 2026 07:46:38 -0800 Subject: [PATCH 039/141] fixed types --- go/core/action.go | 4 ++-- go/core/flow.go | 20 ++++++++++---------- go/genkit/x/genkit.go | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index b43a60dbc3..0bd867b4f1 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -257,8 +257,8 @@ func (a *Action[In, Out, Stream, Init]) Run(ctx context.Context, input In, cb St return r.Result, nil } -// Run executes the Action's function in a new trace span. -func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream]) (output api.ActionRunResult[Out], err error) { +// runWithTelemetry executes the Action's function in a new trace span and returns telemetry info. +func (a *Action[In, Out, Stream, Init]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream]) (output api.ActionRunResult[Out], err error) { logger.FromContext(ctx).Debug("Action.Run", "name", a.Name()) defer func() { logger.FromContext(ctx).Debug("Action.Run", diff --git a/go/core/flow.go b/go/core/flow.go index 6f0351a497..3ad8bba60d 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -48,8 +48,8 @@ type flowContext struct { } // NewFlow creates a Flow that runs fn without registering it. fn takes an input of type In and returns an output of type Out. -func NewFlow[In, Out any](name string, fn Func[In, Out]) *Flow[In, Out, struct{}] { - return (*Flow[In, Out, struct{}])(NewAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In) (Out, error) { +func NewFlow[In, Out any](name string, fn Func[In, Out]) *Flow[In, Out, struct{}, struct{}] { + return &Flow[In, Out, struct{}, struct{}]{NewAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In) (Out, error) { fc := &flowContext{ flowName: name, } @@ -59,8 +59,8 @@ func NewFlow[In, Out any](name string, fn Func[In, Out]) *Flow[In, Out, struct{} } // NewStreamingFlow creates a streaming Flow that runs fn without registering it. -func NewStreamingFlow[In, Out, Stream any](name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream] { - return (*Flow[In, Out, Stream])(NewStreamingAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) { +func NewStreamingFlow[In, Out, Stream any](name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream, struct{}] { + return &Flow[In, Out, Stream, struct{}]{NewStreamingAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) { fc := &flowContext{ flowName: name, } @@ -69,7 +69,7 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn StreamingFunc[In, Out cb = func(context.Context, Stream) error { return nil } } return fn(ctx, input, cb) - })) + })} } // NewBidiFlow creates a bidirectional streaming Flow without registering it. @@ -83,7 +83,7 @@ func NewBidiFlow[In, Out, Stream, Init any](name string, fn BidiFunc[In, Out, St } // DefineFlow creates a Flow that runs fn, and registers it as an action. fn takes an input of type In and returns an output of type Out. -func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flow[In, Out, struct{}] { +func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flow[In, Out, struct{}, struct{}] { f := NewFlow(name, fn) f.Register(r) return f @@ -99,9 +99,9 @@ func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flo // with a final return value that includes all the streamed data. // Otherwise, it should ignore the callback and just return a result. func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream, struct{}] { - f := NewStreamingFlow(name, fn) - f.Register(r) - return f + f := NewStreamingFlow(name, fn) + f.Register(r) + return f } // DefineBidiFlow creates a bidirectional streaming Flow that runs fn, and registers it as an action. @@ -167,7 +167,7 @@ func (f *Flow[In, Out, Stream, Init]) Stream(ctx context.Context, input In) func } return nil } - output, err := (*ActionDef[In, Out, Stream])(f).Run(ctx, input, cb) + output, err := f.Action.Run(ctx, input, cb) if errors.Is(err, errStop) { // Consumer broke out of the loop; don't yield again. return diff --git a/go/genkit/x/genkit.go b/go/genkit/x/genkit.go index 9480f75149..8574d338ea 100644 --- a/go/genkit/x/genkit.go +++ b/go/genkit/x/genkit.go @@ -82,7 +82,7 @@ type StreamingFunc[In, Out, Stream any] = func(ctx context.Context, input In, st // fmt.Println(val.Stream) // 5, 4, 3, 2, 1 // } // } -func DefineStreamingFlow[In, Out, Stream any](g *genkit.Genkit, name string, fn StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream] { +func DefineStreamingFlow[In, Out, Stream any](g *genkit.Genkit, name string, fn StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream, struct{}] { // Wrap the channel-based function to work with the callback-based API wrappedFn := func(ctx context.Context, input In, sendChunk core.StreamCallback[Stream]) (Out, error) { if sendChunk == nil { From fa0430425c6f857424a1ce6f8addc92610d15af8 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 25 Feb 2026 08:26:30 -0800 Subject: [PATCH 040/141] moved from `ai/x` to `ai/exp` --- go/ai/{x => exp}/agent.go | 4 ++-- go/ai/{x => exp}/agent_test.go | 2 +- go/ai/{x => exp}/gen.go | 2 +- go/ai/{x => exp}/option.go | 2 +- go/ai/{x => exp}/session.go | 2 +- go/genkit/genkit.go | 2 +- go/samples/custom-agent/main.go | 13 +++++-------- go/samples/prompt-agent/main.go | 2 +- 8 files changed, 13 insertions(+), 16 deletions(-) rename go/ai/{x => exp}/agent.go (99%) rename go/ai/{x => exp}/agent_test.go (99%) rename go/ai/{x => exp}/gen.go (99%) rename go/ai/{x => exp}/option.go (99%) rename go/ai/{x => exp}/session.go (99%) diff --git a/go/ai/x/agent.go b/go/ai/exp/agent.go similarity index 99% rename from go/ai/x/agent.go rename to go/ai/exp/agent.go index 70c55a8e71..9736036968 100644 --- a/go/ai/x/agent.go +++ b/go/ai/exp/agent.go @@ -14,11 +14,11 @@ // // SPDX-License-Identifier: Apache-2.0 -// Package aix provides experimental AI primitives for Genkit. +// Package exp provides experimental AI primitives for Genkit. // // APIs in this package are under active development and may change in any // minor version release. -package aix +package exp import ( "context" diff --git a/go/ai/x/agent_test.go b/go/ai/exp/agent_test.go similarity index 99% rename from go/ai/x/agent_test.go rename to go/ai/exp/agent_test.go index ba6f3a2edf..972ac5fccf 100644 --- a/go/ai/x/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -14,7 +14,7 @@ // // SPDX-License-Identifier: Apache-2.0 -package aix +package exp import ( "context" diff --git a/go/ai/x/gen.go b/go/ai/exp/gen.go similarity index 99% rename from go/ai/x/gen.go rename to go/ai/exp/gen.go index ff5ace986c..b13c62aa86 100644 --- a/go/ai/x/gen.go +++ b/go/ai/exp/gen.go @@ -16,7 +16,7 @@ // This file was generated by jsonschemagen. DO NOT EDIT. -package aix +package exp import ( "github.com/firebase/genkit/go/ai" diff --git a/go/ai/x/option.go b/go/ai/exp/option.go similarity index 99% rename from go/ai/x/option.go rename to go/ai/exp/option.go index 9ef6152566..a80a1813eb 100644 --- a/go/ai/x/option.go +++ b/go/ai/exp/option.go @@ -14,7 +14,7 @@ // // SPDX-License-Identifier: Apache-2.0 -package aix +package exp import ( "context" diff --git a/go/ai/x/session.go b/go/ai/exp/session.go similarity index 99% rename from go/ai/x/session.go rename to go/ai/exp/session.go index a7c9727ff8..3d26b12a3c 100644 --- a/go/ai/x/session.go +++ b/go/ai/exp/session.go @@ -14,7 +14,7 @@ // // SPDX-License-Identifier: Apache-2.0 -package aix +package exp import ( "context" diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 84b581826f..21f312de6a 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -30,7 +30,7 @@ import ( "syscall" "github.com/firebase/genkit/go/ai" - aix "github.com/firebase/genkit/go/ai/x" + aix "github.com/firebase/genkit/go/ai/exp" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/internal/registry" diff --git a/go/samples/custom-agent/main.go b/go/samples/custom-agent/main.go index 0fdde46871..d2cf7a63bc 100644 --- a/go/samples/custom-agent/main.go +++ b/go/samples/custom-agent/main.go @@ -25,7 +25,7 @@ import ( "strings" "github.com/firebase/genkit/go/ai" - aix "github.com/firebase/genkit/go/ai/x" + aix "github.com/firebase/genkit/go/ai/exp" "github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/plugins/googlegenai" "google.golang.org/genai" @@ -37,8 +37,7 @@ func main() { chatFlow := genkit.DefineCustomAgent(g, "chat", func(ctx context.Context, resp aix.Responder[any], sess *aix.AgentSession[any]) (*aix.AgentFlowResult, error) { - var lastMessage *ai.Message - err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentFlowInput) error { + if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentFlowInput) error { for chunk, err := range genkit.GenerateStream(ctx, g, ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ @@ -52,19 +51,17 @@ func main() { return err } if chunk.Done { - lastMessage = chunk.Response.Message - sess.AddMessages(lastMessage) + sess.AddMessages(chunk.Response.Message) break } resp.SendModelChunk(chunk.Chunk) } return nil - }) - if err != nil { + }); err != nil { return nil, err } - return &aix.AgentFlowResult{Message: lastMessage}, nil + return sess.Result(), nil }, aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), diff --git a/go/samples/prompt-agent/main.go b/go/samples/prompt-agent/main.go index dd88f3ea30..0f73c52c82 100644 --- a/go/samples/prompt-agent/main.go +++ b/go/samples/prompt-agent/main.go @@ -26,7 +26,7 @@ import ( "os" "strings" - aix "github.com/firebase/genkit/go/ai/x" + aix "github.com/firebase/genkit/go/ai/exp" "github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/plugins/googlegenai" ) From ee4e68a7a35e45b28443ae4bacbc033d1ece47c1 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 25 Feb 2026 12:51:53 -0800 Subject: [PATCH 041/141] Update agent.go --- go/ai/exp/agent.go | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 9736036968..a5dd5000a5 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -344,8 +344,7 @@ func DefinePromptAgent[State, PromptIn any]( } fn := func(ctx context.Context, resp Responder[any], sess *AgentSession[State]) (*AgentFlowResult, error) { - var lastModelMessage *ai.Message - err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { // Resolve prompt input: session state override > default. promptInput := defaultInput if stored := sess.InputVariables(); stored != nil { @@ -396,8 +395,6 @@ func DefinePromptAgent[State, PromptIn any]( return fmt.Errorf("generate: %w", err) } - lastModelMessage = modelResp.Message - // Replace session messages with the full history minus prompt // messages. This captures intermediate tool call/response // messages from the tool loop, not just the final response. @@ -426,14 +423,10 @@ func DefinePromptAgent[State, PromptIn any]( } return nil - }) - if err != nil { + }); err != nil { return nil, err } - return &AgentFlowResult{ - Message: lastModelMessage, - Artifacts: sess.Artifacts(), - }, nil + return sess.Result(), nil } return DefineCustomAgent(r, promptName, fn, opts...) @@ -482,7 +475,7 @@ func newSessionFromInit[State any]( var snapshot *SessionSnapshot[State] if init != nil { if init.SnapshotID != "" && init.State != nil { - return nil, nil, core.NewError(core.INVALID_ARGUMENT, "snapshotId and state are mutually exclusive") + return nil, nil, core.NewError(core.INVALID_ARGUMENT, "snapshot ID and state are mutually exclusive") } if init.SnapshotID != "" && store == nil { return nil, nil, core.NewError(core.FAILED_PRECONDITION, "snapshot ID %q provided but no session store configured", init.SnapshotID) @@ -512,7 +505,6 @@ func newSessionFromInit[State any]( // of the iterator between turns does not cancel the underlying connection. type AgentFlowConnection[Stream, State any] struct { conn *core.BidiConnection[*AgentFlowInput, *AgentFlowOutput[State], *AgentFlowStreamChunk[Stream]] - // chunks buffers stream chunks from the underlying connection so that // breaking from Receive() between turns doesn't cancel the context. chunks chan *AgentFlowStreamChunk[Stream] From 975accd2eb74826f7124d18858ce286eaf935865 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 25 Feb 2026 13:00:21 -0800 Subject: [PATCH 042/141] added `AgentFlow.Run()` and `AgentFlow.RunText()` --- go/ai/exp/agent.go | 85 ++++++++++++--- go/ai/exp/agent_test.go | 178 ++++++++++++++++++++++++++++++++ go/ai/exp/option.go | 24 ++--- go/samples/custom-agent/main.go | 1 + 4 files changed, 261 insertions(+), 27 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index a5dd5000a5..2d9b721e6a 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -432,35 +432,90 @@ func DefinePromptAgent[State, PromptIn any]( return DefineCustomAgent(r, promptName, fn, opts...) } -// StreamBidi starts a new agent flow invocation. +// StreamBidi starts a new agent flow invocation with bidirectional streaming. +// Use this for multi-turn interactions where you need to send multiple inputs +// and receive streaming chunks. For single-turn usage, see Run and RunText. func (af *AgentFlow[Stream, State]) StreamBidi( ctx context.Context, - opts ...StreamBidiOption[State], + opts ...InvocationOption[State], ) (*AgentFlowConnection[Stream, State], error) { - sbOpts := &streamBidiOptions[State]{} + invOpts, err := af.resolveOptions(opts) + if err != nil { + return nil, err + } + + conn, err := af.flow.StreamBidi(ctx, invOpts) + if err != nil { + return nil, err + } + + return &AgentFlowConnection[Stream, State]{conn: conn}, nil +} + +// Run starts a single-turn agent flow invocation with the given input. +// It sends the input, waits for the flow to complete, and returns the output. +// For multi-turn interactions or streaming, use StreamBidi instead. +func (af *AgentFlow[Stream, State]) Run( + ctx context.Context, + input *AgentFlowInput, + opts ...InvocationOption[State], +) (*AgentFlowOutput[State], error) { + conn, err := af.StreamBidi(ctx, opts...) + if err != nil { + return nil, err + } + + if err := conn.Send(input); err != nil { + return nil, err + } + if err := conn.Close(); err != nil { + return nil, err + } + + // Drain stream chunks. + for _, err := range conn.Receive() { + if err != nil { + return nil, err + } + } + + return conn.Output() +} + +// RunText is a convenience method that starts a single-turn agent flow +// invocation with a user text message. It is equivalent to calling Run with +// an AgentFlowInput containing a single user text message. +func (af *AgentFlow[Stream, State]) RunText( + ctx context.Context, + text string, + opts ...InvocationOption[State], +) (*AgentFlowOutput[State], error) { + return af.Run(ctx, &AgentFlowInput{ + Messages: []*ai.Message{ai.NewUserTextMessage(text)}, + }, opts...) +} + +// resolveOptions applies invocation options and returns the init struct. +func (af *AgentFlow[Stream, State]) resolveOptions(opts []InvocationOption[State]) (*AgentFlowInit[State], error) { + invOpts := &invocationOptions[State]{} for _, opt := range opts { - if err := opt.applyStreamBidi(sbOpts); err != nil { - return nil, fmt.Errorf("AgentFlow.StreamBidi %q: %w", af.flow.Name(), err) + if err := opt.applyInvocation(invOpts); err != nil { + return nil, fmt.Errorf("AgentFlow %q: %w", af.flow.Name(), err) } } init := &AgentFlowInit[State]{ - SnapshotID: sbOpts.snapshotID, - State: sbOpts.state, + SnapshotID: invOpts.snapshotID, + State: invOpts.state, } - if sbOpts.promptInput != nil { + if invOpts.promptInput != nil { if init.State == nil { init.State = &SessionState[State]{} } - init.State.InputVariables = sbOpts.promptInput + init.State.InputVariables = invOpts.promptInput } - conn, err := af.flow.StreamBidi(ctx, init) - if err != nil { - return nil, err - } - - return &AgentFlowConnection[Stream, State]{conn: conn}, nil + return init, nil } // newSessionFromInit creates a Session from initialization data. diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 972ac5fccf..152d2807cc 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -1331,3 +1331,181 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { t.Errorf("msg[3] should be final model response, got role=%s text=%s", msgs[3].Role, msgs[3].Text()) } } + +func TestAgentFlow_RunText(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "runTextFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + if len(input.Messages) > 0 { + sess.AddMessages(ai.NewModelTextMessage("echo: " + input.Messages[0].Content[0].Text)) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + ) + + response, err := af.RunText(ctx, "hello") + if err != nil { + t.Fatalf("RunText failed: %v", err) + } + + // 1 user message + 1 echo reply = 2. + if got := len(response.State.Messages); got != 2 { + t.Errorf("expected 2 messages, got %d", got) + } + if got := response.State.Custom.Counter; got != 1 { + t.Errorf("expected counter=1, got %d", got) + } +} + +func TestAgentFlow_Run(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "runFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + if len(input.Messages) > 0 { + sess.AddMessages(ai.NewModelTextMessage("reply")) + } + return nil + }) + }, + ) + + input := &AgentFlowInput{ + Messages: []*ai.Message{ + ai.NewUserTextMessage("msg1"), + ai.NewUserTextMessage("msg2"), + }, + } + + response, err := af.Run(ctx, input) + if err != nil { + t.Fatalf("Run failed: %v", err) + } + + // 2 user messages + 1 reply = 3. + if got := len(response.State.Messages); got != 3 { + t.Errorf("expected 3 messages, got %d", got) + } +} + +func TestAgentFlow_RunText_WithState(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "runStateFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + ) + + clientState := &SessionState[testState]{ + Messages: []*ai.Message{ + ai.NewUserTextMessage("previous"), + ai.NewModelTextMessage("previous reply"), + }, + Custom: testState{Counter: 10}, + } + + response, err := af.RunText(ctx, "new message", WithState(clientState)) + if err != nil { + t.Fatalf("RunText with state failed: %v", err) + } + + // 2 previous + 1 new user + 1 reply = 4. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + } + // Counter should be 11 (started at 10, incremented once). + if got := response.State.Custom.Counter; got != 11 { + t.Errorf("expected counter=11, got %d", got) + } +} + +func TestAgentFlow_RunText_WithSnapshot(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineCustomAgent(reg, "runSnapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSessionStore(store), + ) + + // First invocation via RunText. + resp1, err := af.RunText(ctx, "first") + if err != nil { + t.Fatalf("first RunText failed: %v", err) + } + if resp1.SnapshotID == "" { + t.Fatal("expected snapshot ID from first invocation") + } + + // Resume from snapshot via RunText. + resp2, err := af.RunText(ctx, "second", WithSnapshotID[testState](resp1.SnapshotID)) + if err != nil { + t.Fatalf("second RunText failed: %v", err) + } + + snap, err := store.GetSnapshot(ctx, resp2.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + // 4 messages: first user + reply + second user + reply. + if got := len(snap.State.Messages); got != 4 { + t.Errorf("expected 4 messages after resume, got %d", got) + } + if got := snap.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2, got %d", got) + } +} + +func TestPromptAgent_RunText(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + ai.DefinePrompt(reg, "runTextPrompt", + ai.WithModelName("test/echo"), + ai.WithSystem("You are a test assistant."), + ) + + af := DefinePromptAgent[testState, any](reg, "runTextPrompt", nil) + + response, err := af.RunText(ctx, "hello") + if err != nil { + t.Fatalf("RunText failed: %v", err) + } + + // 1 user message + 1 model reply = 2. + if got := len(response.State.Messages); got != 2 { + t.Errorf("expected 2 messages, got %d", got) + for i, m := range response.State.Messages { + t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) + } + } +} diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index a80a1813eb..f7e80639ca 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -74,20 +74,20 @@ func WithSnapshotOn[State any](events ...SnapshotEvent) AgentFlowOption[State] { }) } -// --- StreamBidiOption --- +// --- InvocationOption --- -// StreamBidiOption configures a StreamBidi call. -type StreamBidiOption[State any] interface { - applyStreamBidi(*streamBidiOptions[State]) error +// InvocationOption configures an agent flow invocation (StreamBidi, Run, or RunText). +type InvocationOption[State any] interface { + applyInvocation(*invocationOptions[State]) error } -type streamBidiOptions[State any] struct { +type invocationOptions[State any] struct { state *SessionState[State] snapshotID string promptInput any } -func (o *streamBidiOptions[State]) applyStreamBidi(opts *streamBidiOptions[State]) error { +func (o *invocationOptions[State]) applyInvocation(opts *invocationOptions[State]) error { if o.state != nil { if opts.state != nil { return errors.New("cannot set state more than once (WithState)") @@ -117,18 +117,18 @@ func (o *streamBidiOptions[State]) applyStreamBidi(opts *streamBidiOptions[State // WithState sets the initial state for the invocation. // Use this for client-managed state where the client sends state directly. -func WithState[State any](state *SessionState[State]) StreamBidiOption[State] { - return &streamBidiOptions[State]{state: state} +func WithState[State any](state *SessionState[State]) InvocationOption[State] { + return &invocationOptions[State]{state: state} } // WithSnapshotID loads state from a persisted snapshot by ID. // Use this for server-managed state where snapshots are stored. -func WithSnapshotID[State any](id string) StreamBidiOption[State] { - return &streamBidiOptions[State]{snapshotID: id} +func WithSnapshotID[State any](id string) InvocationOption[State] { + return &invocationOptions[State]{snapshotID: id} } // WithInputVariables overrides the default input variables for a prompt-backed agent flow. // Used with DefinePromptAgent to customize the input variables per invocation. -func WithInputVariables[State any](input any) StreamBidiOption[State] { - return &streamBidiOptions[State]{promptInput: input} +func WithInputVariables[State any](input any) InvocationOption[State] { + return &invocationOptions[State]{promptInput: input} } diff --git a/go/samples/custom-agent/main.go b/go/samples/custom-agent/main.go index d2cf7a63bc..a092e88ad2 100644 --- a/go/samples/custom-agent/main.go +++ b/go/samples/custom-agent/main.go @@ -116,4 +116,5 @@ func main() { } conn.Close() + fmt.Println(conn.Output()) } From a6f044b9162723cf6a2c9cbbfa249b4f5bd3f6bf Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 25 Feb 2026 13:18:32 -0800 Subject: [PATCH 043/141] dedupe consecutive identical snapshots --- go/ai/exp/agent.go | 28 +++++-- go/ai/exp/agent_test.go | 157 ++++++++++++++++++++++++++++++++++++++++ go/ai/exp/session.go | 13 +++- 3 files changed, 189 insertions(+), 9 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 2d9b721e6a..c25d46e5aa 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -54,10 +54,11 @@ type AgentSession[State any] struct { // directly. TurnIndex int - snapshotCallback SnapshotCallback[State] - onEndTurn func(ctx context.Context) - lastSnapshot *SessionSnapshot[State] - collectTurnOutput func() any + snapshotCallback SnapshotCallback[State] + onEndTurn func(ctx context.Context) + lastSnapshot *SessionSnapshot[State] + lastSnapshotVersion uint64 + collectTurnOutput func() any } // Run loops over the input channel, calling fn for each turn. Each turn is @@ -117,16 +118,25 @@ func (a *AgentSession[State]) Result() *AgentFlowResult { } // maybeSnapshot creates a snapshot if conditions are met (store configured, -// callback approves). Returns the snapshot ID or empty string. +// callback approves, state changed). Returns the snapshot ID or empty string. func (a *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { if a.store == nil { return "" } a.mu.RLock() + currentVersion := a.version currentState := a.copyStateLocked() a.mu.RUnlock() + // Skip if state hasn't changed since the last snapshot. This avoids + // redundant snapshots, e.g. the invocation-end snapshot after a + // single-turn Run where the turn-end snapshot already captured the + // same state. + if a.lastSnapshot != nil && currentVersion == a.lastSnapshotVersion { + return "" + } + if a.snapshotCallback != nil { var prevState *SessionState[State] if a.lastSnapshot != nil { @@ -169,6 +179,7 @@ func (a *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotE a.mu.Unlock() a.lastSnapshot = snapshot + a.lastSnapshotVersion = currentVersion return snapshot.SnapshotID } @@ -298,8 +309,13 @@ func DefineCustomAgent[Stream, State any]( return nil, fnErr } - // Final snapshot at invocation end. + // Final snapshot at invocation end. If skipped (state unchanged + // since last turn-end snapshot), use the last snapshot's ID so + // the output always reflects the latest snapshot. snapshotID := agentSess.maybeSnapshot(ctx, SnapshotEventInvocationEnd) + if snapshotID == "" && agentSess.lastSnapshot != nil { + snapshotID = agentSess.lastSnapshot.SnapshotID + } out := &AgentFlowOutput[State]{ SnapshotID: snapshotID, diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 152d2807cc..d2d75c232f 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -1509,3 +1509,160 @@ func TestPromptAgent_RunText(t *testing.T) { } } } + +func TestAgentFlow_SingleTurnSnapshotDedup(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineCustomAgent(reg, "dedupFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSessionStore(store), + ) + + // Single-turn invocation: should produce exactly 1 snapshot (turn-end), + // not 2 (turn-end + invocation-end with identical state). + response, err := af.RunText(ctx, "hello") + if err != nil { + t.Fatalf("RunText failed: %v", err) + } + + if response.SnapshotID == "" { + t.Fatal("expected snapshot ID in response") + } + + // Count total snapshots in the store. + snap, err := store.GetSnapshot(ctx, response.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap.Event != SnapshotEventTurnEnd { + t.Errorf("expected turn-end snapshot, got %s", snap.Event) + } + // The turn-end snapshot should have no parent (first and only snapshot). + if snap.ParentID != "" { + t.Errorf("expected no parent (single snapshot), got parent %q", snap.ParentID) + } +} + +func TestAgentFlow_MultiTurnSnapshotDedup(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineCustomAgent(reg, "multiDedupFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSessionStore(store), + ) + + // Multi-turn: last turn-end snapshot should dedup with invocation-end. + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + var snapshotIDs []string + for i := 0; i < 3; i++ { + conn.SendText(fmt.Sprintf("turn %d", i)) + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error on turn %d: %v", i, err) + } + if chunk.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, chunk.SnapshotID) + } + if chunk.EndTurn { + break + } + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Should have 3 turn-end snapshots (one per turn), no extra invocation-end. + if got := len(snapshotIDs); got != 3 { + t.Errorf("expected 3 turn-end snapshots, got %d", got) + } + + // The output snapshot ID should reuse the last turn-end snapshot. + if response.SnapshotID == "" { + t.Fatal("expected snapshot ID in response") + } + if response.SnapshotID != snapshotIDs[len(snapshotIDs)-1] { + t.Errorf("expected output snapshot to reuse last turn-end snapshot %q, got %q", + snapshotIDs[len(snapshotIDs)-1], response.SnapshotID) + } +} + +func TestAgentFlow_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineCustomAgent(reg, "postRunMutateFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + return nil + }); err != nil { + return nil, err + } + // Mutate state AFTER sess.Run returns -- this should trigger + // a separate invocation-end snapshot. + sess.UpdateCustom(func(s testState) testState { + s.Counter = 99 + return s + }) + return sess.Result(), nil + }, + WithSessionStore(store), + ) + + response, err := af.RunText(ctx, "hello") + if err != nil { + t.Fatalf("RunText failed: %v", err) + } + + if response.SnapshotID == "" { + t.Fatal("expected snapshot ID in response") + } + + // The final snapshot should be an invocation-end snapshot that captured + // the post-Run mutation. + snap, err := store.GetSnapshot(ctx, response.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap.Event != SnapshotEventInvocationEnd { + t.Errorf("expected invocation-end snapshot, got %s", snap.Event) + } + if snap.State.Custom.Counter != 99 { + t.Errorf("expected counter=99 in final snapshot, got %d", snap.State.Custom.Counter) + } + // Should have a parent (the turn-end snapshot). + if snap.ParentID == "" { + t.Error("expected parent ID (turn-end snapshot)") + } +} diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 3d26b12a3c..40f0d6dfc8 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -133,9 +133,10 @@ func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[Sta // Session holds conversation state and provides thread-safe read/write access to messages, // input variables, custom state, and artifacts. type Session[State any] struct { - mu sync.RWMutex - state SessionState[State] - store SessionStore[State] + mu sync.RWMutex + state SessionState[State] + store SessionStore[State] + version uint64 // incremented on every mutation; used to skip redundant snapshots } // State returns a copy of the current state. @@ -160,6 +161,7 @@ func (s *Session[State]) AddMessages(messages ...*ai.Message) { s.mu.Lock() defer s.mu.Unlock() s.state.Messages = append(s.state.Messages, messages...) + s.version++ } // SetMessages replaces the conversation history with the given messages. @@ -167,6 +169,7 @@ func (s *Session[State]) SetMessages(messages []*ai.Message) { s.mu.Lock() defer s.mu.Unlock() s.state.Messages = messages + s.version++ } // UpdateMessages atomically reads the current messages, applies the given @@ -175,6 +178,7 @@ func (s *Session[State]) UpdateMessages(fn func([]*ai.Message) []*ai.Message) { s.mu.Lock() defer s.mu.Unlock() s.state.Messages = fn(s.state.Messages) + s.version++ } // Custom returns the current user-defined custom state. @@ -190,6 +194,7 @@ func (s *Session[State]) UpdateCustom(fn func(State) State) { s.mu.Lock() defer s.mu.Unlock() s.state.Custom = fn(s.state.Custom) + s.version++ } // InputVariables returns the prompt input stored in the session state. @@ -228,6 +233,7 @@ func (s *Session[State]) AddArtifacts(artifacts ...*Artifact) { s.state.Artifacts = append(s.state.Artifacts, a) } } + s.version++ } // UpdateArtifacts atomically reads the current artifacts, applies the given @@ -236,6 +242,7 @@ func (s *Session[State]) UpdateArtifacts(fn func([]*Artifact) []*Artifact) { s.mu.Lock() defer s.mu.Unlock() s.state.Artifacts = fn(s.state.Artifacts) + s.version++ } // copyStateLocked returns a deep copy of the state. Caller must hold mu (read or write). From f00786a3c8be3ebb72a356069d407a94a1732ff9 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Mar 2026 09:57:01 -0800 Subject: [PATCH 044/141] renamed agent flow et al to session flow --- genkit-tools/common/src/types/agent.ts | 34 ++-- genkit-tools/genkit-schema.json | 10 +- go/ai/exp/agent.go | 190 +++++++++---------- go/ai/exp/agent_test.go | 184 +++++++++--------- go/ai/exp/gen.go | 28 +-- go/ai/exp/option.go | 28 +-- go/ai/exp/session.go | 4 +- go/core/api/action.go | 2 +- go/core/schemas.config | 100 +++++----- go/genkit/genkit.go | 50 ++--- go/samples/custom-agent/main.go | 10 +- go/samples/prompt-agent/main.go | 8 +- py/packages/genkit/src/genkit/core/typing.py | 20 +- 13 files changed, 334 insertions(+), 334 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index fb8c2ce9ec..e2f98aeb16 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -47,48 +47,48 @@ export const SessionStateSchema = z.object({ custom: z.any().optional(), /** Named collections of parts produced during the conversation. */ artifacts: z.array(ArtifactSchema).optional(), - /** Input used for agent flows that require input variables. */ + /** Input used for session flows that require input variables. */ inputVariables: z.any().optional(), }); export type SessionState = z.infer; /** - * Zod schema for agent flow input (per-turn). + * Zod schema for session flow input (per-turn). */ -export const AgentFlowInputSchema = z.object({ +export const SessionFlowInputSchema = z.object({ /** User's input messages for this turn. */ messages: z.array(MessageSchema).optional(), /** Tool request parts to re-execute interrupted tools. */ toolRestarts: z.array(PartSchema).optional(), }); -export type AgentFlowInput = z.infer; +export type SessionFlowInput = z.infer; /** - * Zod schema for agent flow initialization. + * Zod schema for session flow initialization. */ -export const AgentFlowInitSchema = z.object({ +export const SessionFlowInitSchema = z.object({ /** Loads state from a persisted snapshot. Mutually exclusive with state. */ snapshotId: z.string().optional(), /** Direct state for the invocation. Mutually exclusive with snapshotId. */ state: SessionStateSchema.optional(), }); -export type AgentFlowInit = z.infer; +export type SessionFlowInit = z.infer; /** - * Zod schema for agent flow result. + * Zod schema for session flow result. */ -export const AgentFlowResultSchema = z.object({ +export const SessionFlowResultSchema = z.object({ /** Last model response message from the conversation. */ message: MessageSchema.optional(), /** Artifacts produced during the session. */ artifacts: z.array(ArtifactSchema).optional(), }); -export type AgentFlowResult = z.infer; +export type SessionFlowResult = z.infer; /** - * Zod schema for agent flow output. + * Zod schema for session flow output. */ -export const AgentFlowOutputSchema = z.object({ +export const SessionFlowOutputSchema = z.object({ /** ID of the snapshot created at the end of this invocation. */ snapshotId: z.string().optional(), /** Final conversation state (only when client-managed). */ @@ -98,12 +98,12 @@ export const AgentFlowOutputSchema = z.object({ /** Artifacts produced during the session. */ artifacts: z.array(ArtifactSchema).optional(), }); -export type AgentFlowOutput = z.infer; +export type SessionFlowOutput = z.infer; /** - * Zod schema for agent flow stream chunk. + * Zod schema for session flow stream chunk. */ -export const AgentFlowStreamChunkSchema = z.object({ +export const SessionFlowStreamChunkSchema = z.object({ /** Generation tokens from the model. */ modelChunk: ModelResponseChunkSchema.optional(), /** User-defined structured status information. */ @@ -112,7 +112,7 @@ export const AgentFlowStreamChunkSchema = z.object({ artifact: ArtifactSchema.optional(), /** ID of a snapshot that was just persisted. */ snapshotId: z.string().optional(), - /** Signals that the agent flow has finished processing the current input. */ + /** Signals that the session flow has finished processing the current input. */ endTurn: z.boolean().optional(), }); -export type AgentFlowStreamChunk = z.infer; +export type SessionFlowStreamChunk = z.infer; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index bae17a88ae..5dbbab6de0 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -1,7 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "$defs": { - "AgentFlowInit": { + "SessionFlowInit": { "type": "object", "properties": { "snapshotId": { @@ -13,7 +13,7 @@ }, "additionalProperties": false }, - "AgentFlowInput": { + "SessionFlowInput": { "type": "object", "properties": { "messages": { @@ -31,7 +31,7 @@ }, "additionalProperties": false }, - "AgentFlowOutput": { + "SessionFlowOutput": { "type": "object", "properties": { "snapshotId": { @@ -52,7 +52,7 @@ }, "additionalProperties": false }, - "AgentFlowResult": { + "SessionFlowResult": { "type": "object", "properties": { "message": { @@ -67,7 +67,7 @@ }, "additionalProperties": false }, - "AgentFlowStreamChunk": { + "SessionFlowStreamChunk": { "type": "object", "properties": { "modelChunk": { diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index c25d46e5aa..3b097c311c 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -36,20 +36,20 @@ import ( "github.com/google/uuid" ) -// --- AgentSession --- +// --- SessionRunner --- -// AgentSession extends Session with agent-flow-specific functionality: +// SessionRunner extends Session with session-flow-specific functionality: // turn management, snapshot persistence, and input channel handling. -type AgentSession[State any] struct { +type SessionRunner[State any] struct { *Session[State] // InputCh is the channel that delivers per-turn inputs from the client. - // It is consumed automatically by [AgentSession.Run], but is exposed + // It is consumed automatically by [SessionRunner.Run], but is exposed // for advanced use cases that need direct access to the input stream // (e.g., custom turn loops or fan-out patterns). - InputCh <-chan *AgentFlowInput + InputCh <-chan *SessionFlowInput // TurnIndex is the zero-based index of the current conversation turn. - // It is incremented automatically by [AgentSession.Run], but is exposed + // It is incremented automatically by [SessionRunner.Run], but is exposed // for advanced use cases that need to track or manipulate turn ordering // directly. TurnIndex int @@ -65,16 +65,16 @@ type AgentSession[State any] struct { // wrapped in a trace span for observability. Input messages are automatically // added to the session before fn is called. After fn returns successfully, an // EndTurn chunk is sent and a snapshot check is triggered. -func (a *AgentSession[State]) Run(ctx context.Context, fn func(ctx context.Context, input *AgentFlowInput) error) error { +func (a *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Context, input *SessionFlowInput) error) error { for input := range a.InputCh { spanMeta := &tracing.SpanMetadata{ - Name: fmt.Sprintf("agentFlow/turn/%d", a.TurnIndex), - Type: "agentFlowTurn", - Subtype: "agentFlowTurn", + Name: fmt.Sprintf("sessionFlow/turn/%d", a.TurnIndex), + Type: "flowStep", + Subtype: "flowStep", } _, err := tracing.RunInNewSpan(ctx, spanMeta, input, - func(ctx context.Context, input *AgentFlowInput) (any, error) { + func(ctx context.Context, input *SessionFlowInput) (any, error) { a.AddMessages(input.Messages...) if err := fn(ctx, input); err != nil { @@ -97,15 +97,15 @@ func (a *AgentSession[State]) Run(ctx context.Context, fn func(ctx context.Conte return nil } -// Result returns an [AgentFlowResult] populated from the current session state: +// Result returns an [SessionFlowResult] populated from the current session state: // the last message in the conversation history and all artifacts. -// It is a convenience for custom agent flows that don't need to construct the +// It is a convenience for custom session flows that don't need to construct the // result manually. -func (a *AgentSession[State]) Result() *AgentFlowResult { +func (a *SessionRunner[State]) Result() *SessionFlowResult { a.mu.RLock() defer a.mu.RUnlock() - result := &AgentFlowResult{} + result := &SessionFlowResult{} if msgs := a.state.Messages; len(msgs) > 0 { result.Message = msgs[len(msgs)-1] } @@ -119,7 +119,7 @@ func (a *AgentSession[State]) Result() *AgentFlowResult { // maybeSnapshot creates a snapshot if conditions are met (store configured, // callback approves, state changed). Returns the snapshot ID or empty string. -func (a *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { +func (a *SessionRunner[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { if a.store == nil { return "" } @@ -163,7 +163,7 @@ func (a *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotE } if err := a.store.SaveSnapshot(ctx, snapshot); err != nil { - logger.FromContext(ctx).Error("agent flow: failed to save snapshot", "err", err) + logger.FromContext(ctx).Error("session flow: failed to save snapshot", "err", err) return "" } @@ -186,51 +186,51 @@ func (a *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotE // --- Responder --- -// Responder is the output channel for an agent flow. Artifacts sent through +// Responder is the output channel for an session flow. Artifacts sent through // it are automatically added to the session before being forwarded to the // client. -type Responder[Stream any] chan<- *AgentFlowStreamChunk[Stream] +type Responder[Stream any] chan<- *SessionFlowStreamChunk[Stream] // SendModelChunk sends a generation chunk (token-level streaming). func (r Responder[Stream]) SendModelChunk(chunk *ai.ModelResponseChunk) { - r <- &AgentFlowStreamChunk[Stream]{ModelChunk: chunk} + r <- &SessionFlowStreamChunk[Stream]{ModelChunk: chunk} } // SendStatus sends a user-defined status update. func (r Responder[Stream]) SendStatus(status Stream) { - r <- &AgentFlowStreamChunk[Stream]{Status: status} + r <- &SessionFlowStreamChunk[Stream]{Status: status} } // SendArtifact sends an artifact to the stream and adds it to the session. // If an artifact with the same name already exists in the session, it is replaced. func (r Responder[Stream]) SendArtifact(artifact *Artifact) { - r <- &AgentFlowStreamChunk[Stream]{Artifact: artifact} + r <- &SessionFlowStreamChunk[Stream]{Artifact: artifact} } -// --- AgentFlow --- +// --- SessionFlow --- -// AgentFlowFunc is the function signature for agent flows. +// SessionFlowFunc is the function signature for session flows. // Type parameters: // - Stream: Type for status updates sent via the responder // - State: Type for user-defined state in snapshots -type AgentFlowFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *AgentSession[State]) (*AgentFlowResult, error) +type SessionFlowFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *SessionRunner[State]) (*SessionFlowResult, error) -// AgentFlow is a bidirectional streaming flow with automatic snapshot management. -type AgentFlow[Stream, State any] struct { - flow *core.Flow[*AgentFlowInput, *AgentFlowOutput[State], *AgentFlowStreamChunk[Stream], *AgentFlowInit[State]] +// SessionFlow is a bidirectional streaming flow with automatic snapshot management. +type SessionFlow[Stream, State any] struct { + flow *core.Flow[*SessionFlowInput, *SessionFlowOutput[State], *SessionFlowStreamChunk[Stream], *SessionFlowInit[State]] } -// DefineCustomAgent creates an AgentFlow with automatic snapshot management and registers it. -func DefineCustomAgent[Stream, State any]( +// DefineSessionFlow creates an SessionFlow with automatic snapshot management and registers it. +func DefineSessionFlow[Stream, State any]( r api.Registry, name string, - fn AgentFlowFunc[Stream, State], - opts ...AgentFlowOption[State], -) *AgentFlow[Stream, State] { - afOpts := &agentFlowOptions[State]{} + fn SessionFlowFunc[Stream, State], + opts ...SessionFlowOption[State], +) *SessionFlow[Stream, State] { + afOpts := &sessionFlowOptions[State]{} for _, opt := range opts { - if err := opt.applyAgentFlow(afOpts); err != nil { - panic(fmt.Errorf("DefineCustomAgent %q: %w", name, err)) + if err := opt.applySessionFlow(afOpts); err != nil { + panic(fmt.Errorf("DefineSessionFlow %q: %w", name, err)) } } @@ -239,17 +239,17 @@ func DefineCustomAgent[Stream, State any]( flow := core.DefineBidiFlow(r, name, func( ctx context.Context, - init *AgentFlowInit[State], - inCh <-chan *AgentFlowInput, - outCh chan<- *AgentFlowStreamChunk[Stream], - ) (*AgentFlowOutput[State], error) { + init *SessionFlowInit[State], + inCh <-chan *SessionFlowInput, + outCh chan<- *SessionFlowStreamChunk[Stream], + ) (*SessionFlowOutput[State], error) { session, snapshot, err := newSessionFromInit(ctx, init, store) if err != nil { return nil, err } ctx = NewSessionContext(ctx, session) - agentSess := &AgentSession[State]{ + agentSess := &SessionRunner[State]{ Session: session, snapshotCallback: snapshotCallback, InputCh: inCh, @@ -259,7 +259,7 @@ func DefineCustomAgent[Stream, State any]( // Turn output accumulator: collects content chunks per turn for span output. var ( turnMu sync.Mutex - turnChunks []*AgentFlowStreamChunk[Stream] + turnChunks []*SessionFlowStreamChunk[Stream] ) agentSess.collectTurnOutput = func() any { @@ -272,7 +272,7 @@ func DefineCustomAgent[Stream, State any]( // Intermediary channel: intercepts artifacts, accumulates turn output, // and forwards to outCh. - respCh := make(chan *AgentFlowStreamChunk[Stream]) + respCh := make(chan *SessionFlowStreamChunk[Stream]) var wg sync.WaitGroup wg.Add(1) go func() { @@ -296,9 +296,9 @@ func DefineCustomAgent[Stream, State any]( agentSess.onEndTurn = func(turnCtx context.Context) { snapshotID := agentSess.maybeSnapshot(turnCtx, SnapshotEventTurnEnd) if snapshotID != "" { - respCh <- &AgentFlowStreamChunk[Stream]{SnapshotID: snapshotID} + respCh <- &SessionFlowStreamChunk[Stream]{SnapshotID: snapshotID} } - respCh <- &AgentFlowStreamChunk[Stream]{EndTurn: true} + respCh <- &SessionFlowStreamChunk[Stream]{EndTurn: true} } result, fnErr := fn(ctx, Responder[Stream](respCh), agentSess) @@ -317,7 +317,7 @@ func DefineCustomAgent[Stream, State any]( snapshotID = agentSess.lastSnapshot.SnapshotID } - out := &AgentFlowOutput[State]{ + out := &SessionFlowOutput[State]{ SnapshotID: snapshotID, } if result != nil { @@ -333,14 +333,14 @@ func DefineCustomAgent[Stream, State any]( return out, nil }) - return &AgentFlow[Stream, State]{flow: flow} + return &SessionFlow[Stream, State]{flow: flow} } // promptMessageKey is the metadata key used to tag prompt-rendered messages // so they can be excluded from session history after generation. const promptMessageKey = "_genkit_prompt" -// DefinePromptAgent creates a prompt-backed AgentFlow with an +// DefineSessionFlowFromPrompt creates a prompt-backed SessionFlow with an // automatic conversation loop. Each turn renders the prompt, appends // conversation history, calls GenerateWithRequest, streams chunks to the // client, and adds the model response to the session. @@ -348,19 +348,19 @@ const promptMessageKey = "_genkit_prompt" // The prompt is looked up by name from the registry using // [ai.LookupDataPrompt]. The defaultInput is used for prompt rendering // unless overridden per invocation via WithInputVariables. -func DefinePromptAgent[State, PromptIn any]( +func DefineSessionFlowFromPrompt[State, PromptIn any]( r api.Registry, promptName string, defaultInput PromptIn, - opts ...AgentFlowOption[State], -) *AgentFlow[any, State] { + opts ...SessionFlowOption[State], +) *SessionFlow[any, State] { p := ai.LookupDataPrompt[PromptIn, string](r, promptName) if p == nil { - panic(fmt.Sprintf("DefinePromptAgent: prompt %q not found", promptName)) + panic(fmt.Sprintf("DefineSessionFlowFromPrompt: prompt %q not found", promptName)) } - fn := func(ctx context.Context, resp Responder[any], sess *AgentSession[State]) (*AgentFlowResult, error) { - if err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + fn := func(ctx context.Context, resp Responder[any], sess *SessionRunner[State]) (*SessionFlowResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { // Resolve prompt input: session state override > default. promptInput := defaultInput if stored := sess.InputVariables(); stored != nil { @@ -445,16 +445,16 @@ func DefinePromptAgent[State, PromptIn any]( return sess.Result(), nil } - return DefineCustomAgent(r, promptName, fn, opts...) + return DefineSessionFlow(r, promptName, fn, opts...) } -// StreamBidi starts a new agent flow invocation with bidirectional streaming. +// StreamBidi starts a new session flow invocation with bidirectional streaming. // Use this for multi-turn interactions where you need to send multiple inputs // and receive streaming chunks. For single-turn usage, see Run and RunText. -func (af *AgentFlow[Stream, State]) StreamBidi( +func (af *SessionFlow[Stream, State]) StreamBidi( ctx context.Context, opts ...InvocationOption[State], -) (*AgentFlowConnection[Stream, State], error) { +) (*SessionFlowConnection[Stream, State], error) { invOpts, err := af.resolveOptions(opts) if err != nil { return nil, err @@ -465,17 +465,17 @@ func (af *AgentFlow[Stream, State]) StreamBidi( return nil, err } - return &AgentFlowConnection[Stream, State]{conn: conn}, nil + return &SessionFlowConnection[Stream, State]{conn: conn}, nil } -// Run starts a single-turn agent flow invocation with the given input. +// Run starts a single-turn session flow invocation with the given input. // It sends the input, waits for the flow to complete, and returns the output. // For multi-turn interactions or streaming, use StreamBidi instead. -func (af *AgentFlow[Stream, State]) Run( +func (af *SessionFlow[Stream, State]) Run( ctx context.Context, - input *AgentFlowInput, + input *SessionFlowInput, opts ...InvocationOption[State], -) (*AgentFlowOutput[State], error) { +) (*SessionFlowOutput[State], error) { conn, err := af.StreamBidi(ctx, opts...) if err != nil { return nil, err @@ -498,29 +498,29 @@ func (af *AgentFlow[Stream, State]) Run( return conn.Output() } -// RunText is a convenience method that starts a single-turn agent flow +// RunText is a convenience method that starts a single-turn session flow // invocation with a user text message. It is equivalent to calling Run with -// an AgentFlowInput containing a single user text message. -func (af *AgentFlow[Stream, State]) RunText( +// an SessionFlowInput containing a single user text message. +func (af *SessionFlow[Stream, State]) RunText( ctx context.Context, text string, opts ...InvocationOption[State], -) (*AgentFlowOutput[State], error) { - return af.Run(ctx, &AgentFlowInput{ +) (*SessionFlowOutput[State], error) { + return af.Run(ctx, &SessionFlowInput{ Messages: []*ai.Message{ai.NewUserTextMessage(text)}, }, opts...) } // resolveOptions applies invocation options and returns the init struct. -func (af *AgentFlow[Stream, State]) resolveOptions(opts []InvocationOption[State]) (*AgentFlowInit[State], error) { +func (af *SessionFlow[Stream, State]) resolveOptions(opts []InvocationOption[State]) (*SessionFlowInit[State], error) { invOpts := &invocationOptions[State]{} for _, opt := range opts { if err := opt.applyInvocation(invOpts); err != nil { - return nil, fmt.Errorf("AgentFlow %q: %w", af.flow.Name(), err) + return nil, fmt.Errorf("SessionFlow %q: %w", af.flow.Name(), err) } } - init := &AgentFlowInit[State]{ + init := &SessionFlowInit[State]{ SnapshotID: invOpts.snapshotID, State: invOpts.state, } @@ -538,7 +538,7 @@ func (af *AgentFlow[Stream, State]) resolveOptions(opts []InvocationOption[State // If resuming from a snapshot, the loaded snapshot is also returned. func newSessionFromInit[State any]( ctx context.Context, - init *AgentFlowInit[State], + init *SessionFlowInit[State], store SessionStore[State], ) (*Session[State], *SessionSnapshot[State], error) { s := &Session[State]{store: store} @@ -569,16 +569,16 @@ func newSessionFromInit[State any]( return s, snapshot, nil } -// --- AgentFlowConnection --- +// --- SessionFlowConnection --- -// AgentFlowConnection wraps BidiConnection with agent flow-specific functionality. +// SessionFlowConnection wraps BidiConnection with session flow-specific functionality. // It provides a Receive() iterator that supports multi-turn patterns: breaking out // of the iterator between turns does not cancel the underlying connection. -type AgentFlowConnection[Stream, State any] struct { - conn *core.BidiConnection[*AgentFlowInput, *AgentFlowOutput[State], *AgentFlowStreamChunk[Stream]] +type SessionFlowConnection[Stream, State any] struct { + conn *core.BidiConnection[*SessionFlowInput, *SessionFlowOutput[State], *SessionFlowStreamChunk[Stream]] // chunks buffers stream chunks from the underlying connection so that // breaking from Receive() between turns doesn't cancel the context. - chunks chan *AgentFlowStreamChunk[Stream] + chunks chan *SessionFlowStreamChunk[Stream] chunkErr error initOnce sync.Once } @@ -586,9 +586,9 @@ type AgentFlowConnection[Stream, State any] struct { // initReceiver starts a goroutine that drains the underlying BidiConnection's // Receive into a channel. This goroutine never breaks from the underlying // iterator, preventing context cancellation. -func (c *AgentFlowConnection[Stream, State]) initReceiver() { +func (c *SessionFlowConnection[Stream, State]) initReceiver() { c.initOnce.Do(func() { - c.chunks = make(chan *AgentFlowStreamChunk[Stream], 1) + c.chunks = make(chan *SessionFlowStreamChunk[Stream], 1) go func() { defer close(c.chunks) for chunk, err := range c.conn.Receive() { @@ -602,31 +602,31 @@ func (c *AgentFlowConnection[Stream, State]) initReceiver() { }) } -// Send sends an AgentFlowInput to the agent flow. -func (c *AgentFlowConnection[Stream, State]) Send(input *AgentFlowInput) error { +// Send sends an SessionFlowInput to the session flow. +func (c *SessionFlowConnection[Stream, State]) Send(input *SessionFlowInput) error { return c.conn.Send(input) } -// SendMessages sends messages to the agent flow. -func (c *AgentFlowConnection[Stream, State]) SendMessages(messages ...*ai.Message) error { - return c.conn.Send(&AgentFlowInput{Messages: messages}) +// SendMessages sends messages to the session flow. +func (c *SessionFlowConnection[Stream, State]) SendMessages(messages ...*ai.Message) error { + return c.conn.Send(&SessionFlowInput{Messages: messages}) } -// SendText sends a single user text message to the agent flow. -func (c *AgentFlowConnection[Stream, State]) SendText(text string) error { - return c.conn.Send(&AgentFlowInput{ +// SendText sends a single user text message to the session flow. +func (c *SessionFlowConnection[Stream, State]) SendText(text string) error { + return c.conn.Send(&SessionFlowInput{ Messages: []*ai.Message{ai.NewUserTextMessage(text)}, }) } // SendToolRestarts sends tool restart parts to resume interrupted tool calls. // Parts should be created via [ai.ToolDef.RestartWith]. -func (c *AgentFlowConnection[Stream, State]) SendToolRestarts(parts ...*ai.Part) error { - return c.conn.Send(&AgentFlowInput{ToolRestarts: parts}) +func (c *SessionFlowConnection[Stream, State]) SendToolRestarts(parts ...*ai.Part) error { + return c.conn.Send(&SessionFlowInput{ToolRestarts: parts}) } // Close signals that no more inputs will be sent. -func (c *AgentFlowConnection[Stream, State]) Close() error { +func (c *SessionFlowConnection[Stream, State]) Close() error { return c.conn.Close() } @@ -634,9 +634,9 @@ func (c *AgentFlowConnection[Stream, State]) Close() error { // Unlike the underlying BidiConnection.Receive, breaking out of this iterator // does not cancel the connection. This enables multi-turn patterns where the // caller breaks on EndTurn, sends the next input, then calls Receive again. -func (c *AgentFlowConnection[Stream, State]) Receive() iter.Seq2[*AgentFlowStreamChunk[Stream], error] { +func (c *SessionFlowConnection[Stream, State]) Receive() iter.Seq2[*SessionFlowStreamChunk[Stream], error] { c.initReceiver() - return func(yield func(*AgentFlowStreamChunk[Stream], error) bool) { + return func(yield func(*SessionFlowStreamChunk[Stream], error) bool) { for { chunk, ok := <-c.chunks if !ok { @@ -652,12 +652,12 @@ func (c *AgentFlowConnection[Stream, State]) Receive() iter.Seq2[*AgentFlowStrea } } -// Output returns the final response after the agent flow completes. -func (c *AgentFlowConnection[Stream, State]) Output() (*AgentFlowOutput[State], error) { +// Output returns the final response after the session flow completes. +func (c *SessionFlowConnection[Stream, State]) Output() (*SessionFlowOutput[State], error) { return c.conn.Output() } // Done returns a channel closed when the connection completes. -func (c *AgentFlowConnection[Stream, State]) Done() <-chan struct{} { +func (c *SessionFlowConnection[Stream, State]) Done() <-chan struct{} { return c.conn.Done() } diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index d2d75c232f..1d8d0ddcb5 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -40,13 +40,13 @@ func newTestRegistry(t *testing.T) *registry.Registry { return registry.New() } -func TestAgentFlow_BasicMultiTurn(t *testing.T) { +func TestSessionFlow_BasicMultiTurn(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineCustomAgent(reg, "basicFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "basicFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { resp.SendStatus(testStatus{Phase: "generating"}) // Echo back the user's message. if len(input.Messages) > 0 { @@ -115,14 +115,14 @@ func TestAgentFlow_BasicMultiTurn(t *testing.T) { } } -func TestAgentFlow_WithSessionStore(t *testing.T) { +func TestSessionFlow_WithSessionStore(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineCustomAgent(reg, "snapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "snapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -185,14 +185,14 @@ func TestAgentFlow_WithSessionStore(t *testing.T) { } } -func TestAgentFlow_ResumeFromSnapshot(t *testing.T) { +func TestSessionFlow_ResumeFromSnapshot(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineCustomAgent(reg, "resumeFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "resumeFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -275,13 +275,13 @@ func TestAgentFlow_ResumeFromSnapshot(t *testing.T) { } } -func TestAgentFlow_ClientManagedState(t *testing.T) { +func TestSessionFlow_ClientManagedState(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineCustomAgent(reg, "clientStateFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "clientStateFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -338,13 +338,13 @@ func TestAgentFlow_ClientManagedState(t *testing.T) { } } -func TestAgentFlow_Artifacts(t *testing.T) { +func TestSessionFlow_Artifacts(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineCustomAgent(reg, "artifactFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "artifactFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { resp.SendArtifact(&Artifact{ Name: "code.go", @@ -369,7 +369,7 @@ func TestAgentFlow_Artifacts(t *testing.T) { if err != nil { return nil, err } - return &AgentFlowResult{Artifacts: sess.Artifacts()}, nil + return &SessionFlowResult{Artifacts: sess.Artifacts()}, nil }, ) @@ -408,16 +408,16 @@ func TestAgentFlow_Artifacts(t *testing.T) { } } -func TestAgentFlow_SnapshotCallback(t *testing.T) { +func TestSessionFlow_SnapshotCallback(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() // Only snapshot on even turns. callbackCalls := 0 - af := DefineCustomAgent(reg, "callbackFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "callbackFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -467,13 +467,13 @@ func TestAgentFlow_SnapshotCallback(t *testing.T) { } } -func TestAgentFlow_SendMessages(t *testing.T) { +func TestSessionFlow_SendMessages(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineCustomAgent(reg, "sendMsgsFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "sendMsgsFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { return nil }) }, @@ -513,14 +513,14 @@ func TestAgentFlow_SendMessages(t *testing.T) { } } -func TestAgentFlow_SessionContext(t *testing.T) { +func TestSessionFlow_SessionContext(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) var retrievedCounter int - af := DefineCustomAgent(reg, "contextFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "contextFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { // Session should be retrievable from context. ctxSess := SessionFromContext[testState](ctx) if ctxSess == nil { @@ -559,13 +559,13 @@ func TestAgentFlow_SessionContext(t *testing.T) { } } -func TestAgentFlow_ErrorInTurn(t *testing.T) { +func TestSessionFlow_ErrorInTurn(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineCustomAgent(reg, "errorFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "errorFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { return fmt.Errorf("turn failed") }) }, @@ -585,13 +585,13 @@ func TestAgentFlow_ErrorInTurn(t *testing.T) { } } -func TestAgentFlow_SetMessages(t *testing.T) { +func TestSessionFlow_SetMessages(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineCustomAgent(reg, "setMsgsFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "setMsgsFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { // Replace all messages with just one. sess.SetMessages([]*ai.Message{ai.NewModelTextMessage("replaced")}) return nil @@ -626,14 +626,14 @@ func TestAgentFlow_SetMessages(t *testing.T) { } } -func TestAgentFlow_SnapshotIDInMessageMetadata(t *testing.T) { +func TestSessionFlow_SnapshotIDInMessageMetadata(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineCustomAgent(reg, "metadataFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "metadataFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) return nil }) @@ -641,7 +641,7 @@ func TestAgentFlow_SnapshotIDInMessageMetadata(t *testing.T) { return nil, err } msgs := sess.Messages() - return &AgentFlowResult{Message: msgs[len(msgs)-1]}, nil + return &SessionFlowResult{Message: msgs[len(msgs)-1]}, nil }, WithSessionStore(store), ) @@ -722,14 +722,14 @@ func TestInMemorySessionStore(t *testing.T) { } } -func TestAgentFlow_TurnSpanOutput(t *testing.T) { +func TestSessionFlow_TurnSpanOutput(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) var capturedOutputs []any - af := DefineCustomAgent(reg, "turnOutputFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + af := DefineSessionFlow(reg, "turnOutputFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { // Wrap collectTurnOutput to capture what each turn produces. originalCollect := sess.collectTurnOutput sess.collectTurnOutput = func() any { @@ -738,7 +738,7 @@ func TestAgentFlow_TurnSpanOutput(t *testing.T) { return output } - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { resp.SendStatus(testStatus{Phase: "thinking"}) resp.SendModelChunk(&ai.ModelResponseChunk{ Content: []*ai.Part{ai.NewTextPart("reply")}, @@ -784,9 +784,9 @@ func TestAgentFlow_TurnSpanOutput(t *testing.T) { } for i, output := range capturedOutputs { - chunks, ok := output.([]*AgentFlowStreamChunk[testStatus]) + chunks, ok := output.([]*SessionFlowStreamChunk[testStatus]) if !ok { - t.Fatalf("turn %d: expected []*AgentFlowStreamChunk[testStatus], got %T", i, output) + t.Fatalf("turn %d: expected []*SessionFlowStreamChunk[testStatus], got %T", i, output) } // 3 content chunks per turn: status + model chunk + artifact. if len(chunks) != 3 { @@ -803,15 +803,15 @@ func TestAgentFlow_TurnSpanOutput(t *testing.T) { } } -func TestAgentFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { +func TestSessionFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() var capturedOutputs []any - af := DefineCustomAgent(reg, "turnOutputSnapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + af := DefineSessionFlow(reg, "turnOutputSnapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { originalCollect := sess.collectTurnOutput sess.collectTurnOutput = func() any { output := originalCollect() @@ -819,7 +819,7 @@ func TestAgentFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { return output } - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { resp.SendStatus(testStatus{Phase: "working"}) sess.AddMessages(ai.NewModelTextMessage("reply")) return nil @@ -857,7 +857,7 @@ func TestAgentFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { if len(capturedOutputs) != 1 { t.Fatalf("expected 1 captured output, got %d", len(capturedOutputs)) } - chunks := capturedOutputs[0].([]*AgentFlowStreamChunk[testStatus]) + chunks := capturedOutputs[0].([]*SessionFlowStreamChunk[testStatus]) if len(chunks) != 1 { t.Errorf("expected 1 content chunk, got %d", len(chunks)) } @@ -916,7 +916,7 @@ func TestPromptAgent_Basic(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefinePromptAgent[testState, any]( + af := DefineSessionFlowFromPrompt[testState, any]( reg, "testPrompt", nil, ) @@ -988,7 +988,7 @@ func TestPromptAgent_PromptInputOverride(t *testing.T) { ai.WithPrompt("Hello {{name}}!"), ) - af := DefinePromptAgent[testState]( + af := DefineSessionFlowFromPrompt[testState]( reg, "greetPrompt", greetInput{Name: "default"}, ) @@ -1066,7 +1066,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { ai.WithSystem("system prompt"), ) - af := DefinePromptAgent[testState, any]( + af := DefineSessionFlowFromPrompt[testState, any]( reg, "historyPrompt", nil, ) @@ -1142,7 +1142,7 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefinePromptAgent[testState, any]( + af := DefineSessionFlowFromPrompt[testState, any]( reg, "snapPrompt", nil, WithSessionStore(store), ) @@ -1274,7 +1274,7 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { ai.WithTools(ai.ToolName("greet")), ) - af := DefinePromptAgent[testState, any](reg, "toolPrompt", nil) + af := DefineSessionFlowFromPrompt[testState, any](reg, "toolPrompt", nil) conn, err := af.StreamBidi(ctx) if err != nil { @@ -1332,13 +1332,13 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { } } -func TestAgentFlow_RunText(t *testing.T) { +func TestSessionFlow_RunText(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineCustomAgent(reg, "runTextFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "runTextFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("echo: " + input.Messages[0].Content[0].Text)) } @@ -1365,13 +1365,13 @@ func TestAgentFlow_RunText(t *testing.T) { } } -func TestAgentFlow_Run(t *testing.T) { +func TestSessionFlow_Run(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineCustomAgent(reg, "runFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "runFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -1380,7 +1380,7 @@ func TestAgentFlow_Run(t *testing.T) { }, ) - input := &AgentFlowInput{ + input := &SessionFlowInput{ Messages: []*ai.Message{ ai.NewUserTextMessage("msg1"), ai.NewUserTextMessage("msg2"), @@ -1398,13 +1398,13 @@ func TestAgentFlow_Run(t *testing.T) { } } -func TestAgentFlow_RunText_WithState(t *testing.T) { +func TestSessionFlow_RunText_WithState(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineCustomAgent(reg, "runStateFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "runStateFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -1438,14 +1438,14 @@ func TestAgentFlow_RunText_WithState(t *testing.T) { } } -func TestAgentFlow_RunText_WithSnapshot(t *testing.T) { +func TestSessionFlow_RunText_WithSnapshot(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineCustomAgent(reg, "runSnapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "runSnapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -1494,7 +1494,7 @@ func TestPromptAgent_RunText(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefinePromptAgent[testState, any](reg, "runTextPrompt", nil) + af := DefineSessionFlowFromPrompt[testState, any](reg, "runTextPrompt", nil) response, err := af.RunText(ctx, "hello") if err != nil { @@ -1510,14 +1510,14 @@ func TestPromptAgent_RunText(t *testing.T) { } } -func TestAgentFlow_SingleTurnSnapshotDedup(t *testing.T) { +func TestSessionFlow_SingleTurnSnapshotDedup(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineCustomAgent(reg, "dedupFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "dedupFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -1554,14 +1554,14 @@ func TestAgentFlow_SingleTurnSnapshotDedup(t *testing.T) { } } -func TestAgentFlow_MultiTurnSnapshotDedup(t *testing.T) { +func TestSessionFlow_MultiTurnSnapshotDedup(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineCustomAgent(reg, "multiDedupFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "multiDedupFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -1616,14 +1616,14 @@ func TestAgentFlow_MultiTurnSnapshotDedup(t *testing.T) { } } -func TestAgentFlow_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { +func TestSessionFlow_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineCustomAgent(reg, "postRunMutateFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { - if err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + af := DefineSessionFlow(reg, "postRunMutateFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) return nil }); err != nil { diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index b13c62aa86..b27ace232f 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -22,9 +22,9 @@ import ( "github.com/firebase/genkit/go/ai" ) -// AgentFlowInit is the input for starting an agent flow invocation. +// SessionFlowInit is the input for starting an session flow invocation. // Provide either SnapshotID (to load from store) or State (direct state). -type AgentFlowInit[State any] struct { +type SessionFlowInit[State any] struct { // SnapshotID loads state from a persisted snapshot. // Mutually exclusive with State. SnapshotID string `json:"snapshotId,omitempty"` @@ -33,8 +33,8 @@ type AgentFlowInit[State any] struct { State *SessionState[State] `json:"state,omitempty"` } -// AgentFlowInput is the input sent to an agent flow during a conversation turn. -type AgentFlowInput struct { +// SessionFlowInput is the input sent to an session flow during a conversation turn. +type SessionFlowInput struct { // Messages contains the user's input for this turn. Messages []*ai.Message `json:"messages,omitempty"` // ToolRestarts contains tool request parts to re-execute interrupted tools. @@ -44,9 +44,9 @@ type AgentFlowInput struct { ToolRestarts []*ai.Part `json:"toolRestarts,omitempty"` } -// AgentFlowOutput is the output when an agent flow invocation completes. -// It wraps AgentFlowResult with framework-managed fields. -type AgentFlowOutput[State any] struct { +// SessionFlowOutput is the output when an session flow invocation completes. +// It wraps SessionFlowResult with framework-managed fields. +type SessionFlowOutput[State any] struct { // Artifacts contains artifacts produced during the session. Artifacts []*Artifact `json:"artifacts,omitempty"` // Message is the last model response message from the conversation. @@ -59,21 +59,21 @@ type AgentFlowOutput[State any] struct { State *SessionState[State] `json:"state,omitempty"` } -// AgentFlowResult is the return value from an AgentFlowFunc. +// SessionFlowResult is the return value from an SessionFlowFunc. // It contains the user-specified outputs of the agent invocation. -type AgentFlowResult struct { +type SessionFlowResult struct { // Artifacts contains artifacts produced during the session. Artifacts []*Artifact `json:"artifacts,omitempty"` // Message is the last model response message from the conversation. Message *ai.Message `json:"message,omitempty"` } -// AgentFlowStreamChunk represents a single item in the agent flow's output stream. +// SessionFlowStreamChunk represents a single item in the session flow's output stream. // Multiple fields can be populated in a single chunk. -type AgentFlowStreamChunk[Stream any] struct { +type SessionFlowStreamChunk[Stream any] struct { // Artifact contains a newly produced artifact. Artifact *Artifact `json:"artifact,omitempty"` - // EndTurn signals that the agent flow has finished processing the current input. + // EndTurn signals that the session flow has finished processing the current input. // When true, the client should stop iterating and may send the next input. EndTurn bool `json:"endTurn,omitempty"` // ModelChunk contains generation tokens from the model. @@ -103,8 +103,8 @@ type SessionState[State any] struct { Artifacts []*Artifact `json:"artifacts,omitempty"` // Custom is the user-defined state associated with this conversation. Custom State `json:"custom,omitempty"` - // InputVariables is the input used for agent flows that require input variables - // (e.g. prompt-backed agent flows). + // InputVariables is the input used for session flows that require input variables + // (e.g. prompt-backed session flows). InputVariables any `json:"inputVariables,omitempty"` // Messages is the conversation history (user/model exchanges). // Does NOT include prompt-rendered messages — those are rendered fresh each turn. diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index f7e80639ca..c22e63956f 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -21,19 +21,19 @@ import ( "errors" ) -// --- AgentFlowOption --- +// --- SessionFlowOption --- -// AgentFlowOption configures an AgentFlow. -type AgentFlowOption[State any] interface { - applyAgentFlow(*agentFlowOptions[State]) error +// SessionFlowOption configures an SessionFlow. +type SessionFlowOption[State any] interface { + applySessionFlow(*sessionFlowOptions[State]) error } -type agentFlowOptions[State any] struct { +type sessionFlowOptions[State any] struct { store SessionStore[State] callback SnapshotCallback[State] } -func (o *agentFlowOptions[State]) applyAgentFlow(opts *agentFlowOptions[State]) error { +func (o *sessionFlowOptions[State]) applySessionFlow(opts *sessionFlowOptions[State]) error { if o.store != nil { if opts.store != nil { return errors.New("cannot set session store more than once (WithSessionStore)") @@ -50,20 +50,20 @@ func (o *agentFlowOptions[State]) applyAgentFlow(opts *agentFlowOptions[State]) } // WithSessionStore sets the store for persisting snapshots. -func WithSessionStore[State any](store SessionStore[State]) AgentFlowOption[State] { - return &agentFlowOptions[State]{store: store} +func WithSessionStore[State any](store SessionStore[State]) SessionFlowOption[State] { + return &sessionFlowOptions[State]{store: store} } // WithSnapshotCallback configures when snapshots are created. // If not provided and a store is configured, snapshots are always created. -func WithSnapshotCallback[State any](cb SnapshotCallback[State]) AgentFlowOption[State] { - return &agentFlowOptions[State]{callback: cb} +func WithSnapshotCallback[State any](cb SnapshotCallback[State]) SessionFlowOption[State] { + return &sessionFlowOptions[State]{callback: cb} } // WithSnapshotOn configures snapshots to be created only for the specified events. // For example, WithSnapshotOn[MyState](SnapshotEventTurnEnd) skips the // invocation-end snapshot. -func WithSnapshotOn[State any](events ...SnapshotEvent) AgentFlowOption[State] { +func WithSnapshotOn[State any](events ...SnapshotEvent) SessionFlowOption[State] { set := make(map[SnapshotEvent]struct{}, len(events)) for _, e := range events { set[e] = struct{}{} @@ -76,7 +76,7 @@ func WithSnapshotOn[State any](events ...SnapshotEvent) AgentFlowOption[State] { // --- InvocationOption --- -// InvocationOption configures an agent flow invocation (StreamBidi, Run, or RunText). +// InvocationOption configures an session flow invocation (StreamBidi, Run, or RunText). type InvocationOption[State any] interface { applyInvocation(*invocationOptions[State]) error } @@ -127,8 +127,8 @@ func WithSnapshotID[State any](id string) InvocationOption[State] { return &invocationOptions[State]{snapshotID: id} } -// WithInputVariables overrides the default input variables for a prompt-backed agent flow. -// Used with DefinePromptAgent to customize the input variables per invocation. +// WithInputVariables overrides the default input variables for a prompt-backed session flow. +// Used with DefineSessionFlowFromPrompt to customize the input variables per invocation. func WithInputVariables[State any](input any) InvocationOption[State] { return &invocationOptions[State]{promptInput: input} } diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 40f0d6dfc8..e65d33d426 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -249,11 +249,11 @@ func (s *Session[State]) UpdateArtifacts(fn func([]*Artifact) []*Artifact) { func (s *Session[State]) copyStateLocked() SessionState[State] { bytes, err := json.Marshal(s.state) if err != nil { - panic(fmt.Sprintf("agent flow: failed to marshal state: %v", err)) + panic(fmt.Sprintf("session flow: failed to marshal state: %v", err)) } var copied SessionState[State] if err := json.Unmarshal(bytes, &copied); err != nil { - panic(fmt.Sprintf("agent flow: failed to unmarshal state: %v", err)) + panic(fmt.Sprintf("session flow: failed to unmarshal state: %v", err)) } return copied } diff --git a/go/core/api/action.go b/go/core/api/action.go index 704fb1b9f0..901844bd9c 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -64,7 +64,7 @@ const ( ActionTypeCustom ActionType = "custom" ActionTypeCheckOperation ActionType = "check-operation" ActionTypeCancelOperation ActionType = "cancel-operation" - ActionTypeAgentFlow ActionType = "agent-flow" + ActionTypeSessionFlow ActionType = "session-flow" ) // ActionDesc is a descriptor of an action. diff --git a/go/core/schemas.config b/go/core/schemas.config index b58c36e2a9..9cafa3b092 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1144,22 +1144,22 @@ Metadata contains additional artifact-specific data. . # ---------------------------------------------------------------------------- -# AgentFlowInput +# SessionFlowInput # ---------------------------------------------------------------------------- -AgentFlowInput pkg ai/x +SessionFlowInput pkg ai/x -AgentFlowInput doc -AgentFlowInput is the input sent to an agent flow during a conversation turn. +SessionFlowInput doc +SessionFlowInput is the input sent to an session flow during a conversation turn. . -AgentFlowInput.messages type []*ai.Message -AgentFlowInput.messages doc +SessionFlowInput.messages type []*ai.Message +SessionFlowInput.messages doc Messages contains the user's input for this turn. . -AgentFlowInput.toolRestarts type []*ai.Part -AgentFlowInput.toolRestarts doc +SessionFlowInput.toolRestarts type []*ai.Part +SessionFlowInput.toolRestarts doc ToolRestarts contains tool request parts to re-execute interrupted tools. Use [ai.ToolDef.RestartWith] to create these parts from an interrupted tool request. When set, the generate call resumes with these restarts @@ -1167,113 +1167,113 @@ instead of treating Messages as tool responses. . # ---------------------------------------------------------------------------- -# AgentFlowInit +# SessionFlowInit # ---------------------------------------------------------------------------- -AgentFlowInit pkg ai/x -AgentFlowInit typeparams [State any] +SessionFlowInit pkg ai/x +SessionFlowInit typeparams [State any] -AgentFlowInit doc -AgentFlowInit is the input for starting an agent flow invocation. +SessionFlowInit doc +SessionFlowInit is the input for starting an session flow invocation. Provide either SnapshotID (to load from store) or State (direct state). . -AgentFlowInit.snapshotId doc +SessionFlowInit.snapshotId doc SnapshotID loads state from a persisted snapshot. Mutually exclusive with State. . -AgentFlowInit.state type *SessionState[State] -AgentFlowInit.state doc +SessionFlowInit.state type *SessionState[State] +SessionFlowInit.state doc State provides direct state for the invocation. Mutually exclusive with SnapshotID. . # ---------------------------------------------------------------------------- -# AgentFlowResult +# SessionFlowResult # ---------------------------------------------------------------------------- -AgentFlowResult pkg ai/x +SessionFlowResult pkg ai/x -AgentFlowResult doc -AgentFlowResult is the return value from an AgentFlowFunc. +SessionFlowResult doc +SessionFlowResult is the return value from an SessionFlowFunc. It contains the user-specified outputs of the agent invocation. . -AgentFlowResult.message type *ai.Message -AgentFlowResult.message doc +SessionFlowResult.message type *ai.Message +SessionFlowResult.message doc Message is the last model response message from the conversation. . -AgentFlowResult.artifacts doc +SessionFlowResult.artifacts doc Artifacts contains artifacts produced during the session. . # ---------------------------------------------------------------------------- -# AgentFlowOutput +# SessionFlowOutput # ---------------------------------------------------------------------------- -AgentFlowOutput pkg ai/x -AgentFlowOutput typeparams [State any] +SessionFlowOutput pkg ai/x +SessionFlowOutput typeparams [State any] -AgentFlowOutput doc -AgentFlowOutput is the output when an agent flow invocation completes. -It wraps AgentFlowResult with framework-managed fields. +SessionFlowOutput doc +SessionFlowOutput is the output when an session flow invocation completes. +It wraps SessionFlowResult with framework-managed fields. . -AgentFlowOutput.snapshotId doc +SessionFlowOutput.snapshotId doc SnapshotID is the ID of the snapshot created at the end of this invocation. Empty if no snapshot was created (callback returned false or no store configured). . -AgentFlowOutput.state type *SessionState[State] -AgentFlowOutput.state doc +SessionFlowOutput.state type *SessionState[State] +SessionFlowOutput.state doc State contains the final conversation state. Only populated when state is client-managed (no store configured). . -AgentFlowOutput.message type *ai.Message -AgentFlowOutput.message doc +SessionFlowOutput.message type *ai.Message +SessionFlowOutput.message doc Message is the last model response message from the conversation. . -AgentFlowOutput.artifacts doc +SessionFlowOutput.artifacts doc Artifacts contains artifacts produced during the session. . # ---------------------------------------------------------------------------- -# AgentFlowStreamChunk +# SessionFlowStreamChunk # ---------------------------------------------------------------------------- -AgentFlowStreamChunk pkg ai/x -AgentFlowStreamChunk typeparams [Stream any] +SessionFlowStreamChunk pkg ai/x +SessionFlowStreamChunk typeparams [Stream any] -AgentFlowStreamChunk doc -AgentFlowStreamChunk represents a single item in the agent flow's output stream. +SessionFlowStreamChunk doc +SessionFlowStreamChunk represents a single item in the session flow's output stream. Multiple fields can be populated in a single chunk. . -AgentFlowStreamChunk.modelChunk type *ai.ModelResponseChunk -AgentFlowStreamChunk.modelChunk doc +SessionFlowStreamChunk.modelChunk type *ai.ModelResponseChunk +SessionFlowStreamChunk.modelChunk doc ModelChunk contains generation tokens from the model. . -AgentFlowStreamChunk.status type Stream -AgentFlowStreamChunk.status doc +SessionFlowStreamChunk.status type Stream +SessionFlowStreamChunk.status doc Status contains user-defined structured status information. The Stream type parameter defines the shape of this data. . -AgentFlowStreamChunk.artifact doc +SessionFlowStreamChunk.artifact doc Artifact contains a newly produced artifact. . -AgentFlowStreamChunk.snapshotId doc +SessionFlowStreamChunk.snapshotId doc SnapshotID contains the ID of a snapshot that was just persisted. . -AgentFlowStreamChunk.endTurn doc -EndTurn signals that the agent flow has finished processing the current input. +SessionFlowStreamChunk.endTurn doc +EndTurn signals that the session flow has finished processing the current input. When true, the client should stop iterating and may send the next input. . @@ -1305,8 +1305,8 @@ Artifacts are named collections of parts produced during the conversation. . SessionState.inputVariables doc -InputVariables is the input used for agent flows that require input variables -(e.g. prompt-backed agent flows). +InputVariables is the input used for session flows that require input variables +(e.g. prompt-backed session flows). . # ---------------------------------------------------------------------------- diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 21f312de6a..284d22166f 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -408,26 +408,26 @@ func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.B return core.DefineBidiFlow(g.reg, name, fn) } -// DefineCustomAgent defines a custom agent flow with full control over the +// DefineSessionFlow defines a custom session flow with full control over the // conversation loop, registers it as a [core.Action] of type Flow, and -// returns an [aix.AgentFlow]. +// returns an [aix.SessionFlow]. // // Experimental: This API is under active development and may change in any // minor version release. // -// An AgentFlow is a stateful, multi-turn conversational flow. It builds on +// An SessionFlow is a stateful, multi-turn conversational flow. It builds on // bidirectional streaming to enable ongoing conversations where each turn's // input and output are streamed between client and server. The framework // handles session state, conversation history, and optional snapshot // persistence automatically. // // The provided function fn receives a [aix.Responder] for streaming output -// to the client and an [aix.AgentSession] for accessing conversation state. -// Call [aix.AgentSession.Run] to enter the turn loop, which blocks until the +// to the client and an [aix.SessionRunner] for accessing conversation state. +// Call [aix.SessionRunner.Run] to enter the turn loop, which blocks until the // client sends the next message. // // For prompt-backed agents that follow a standard render-generate-stream loop, -// use [DefinePromptAgent] instead. +// use [DefineSessionFlowFromPrompt] instead. // // # Options // @@ -441,10 +441,10 @@ func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.B // // Example: // -// chatAgent := genkit.DefineCustomAgent(g, "chat", -// func(ctx context.Context, resp aix.Responder[any], sess *aix.AgentSession[any]) (*aix.AgentFlowResult, error) { +// chatAgent := genkit.DefineSessionFlow(g, "chat", +// func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.SessionFlowResult, error) { // var lastMessage *ai.Message -// err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentFlowInput) error { +// err := sess.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { // sess.AddMessages(input.Messages...) // for result, err := range genkit.GenerateStream(ctx, g, // ai.WithModelName("googleai/gemini-3-flash-preview"), @@ -465,7 +465,7 @@ func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.B // if err != nil { // return nil, err // } -// return &aix.AgentFlowResult{Message: lastMessage}, nil +// return &aix.SessionFlowResult{Message: lastMessage}, nil // }, // ) // @@ -484,22 +484,22 @@ func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.B // fmt.Print(chunk.ModelChunk.Text()) // } // conn.Close() -func DefineCustomAgent[Stream, State any]( +func DefineSessionFlow[Stream, State any]( g *Genkit, name string, - fn aix.AgentFlowFunc[Stream, State], - opts ...aix.AgentFlowOption[State], -) *aix.AgentFlow[Stream, State] { - return aix.DefineCustomAgent(g.reg, name, fn, opts...) + fn aix.SessionFlowFunc[Stream, State], + opts ...aix.SessionFlowOption[State], +) *aix.SessionFlow[Stream, State] { + return aix.DefineSessionFlow(g.reg, name, fn, opts...) } -// DefinePromptAgent defines a prompt-backed agent flow, registers it as a -// [core.Action] of type Flow, and returns an [aix.AgentFlow]. +// DefineSessionFlowFromPrompt defines a prompt-backed session flow, registers it as a +// [core.Action] of type Flow, and returns an [aix.SessionFlow]. // // Experimental: This API is under active development and may change in any // minor version release. // -// This is a higher-level alternative to [DefineCustomAgent] for agents backed +// This is a higher-level alternative to [DefineSessionFlow] for agents backed // by a prompt (defined via [DefinePrompt] or loaded from a .prompt file). The // conversation loop is handled automatically: each turn renders the prompt, // appends conversation history, calls the model with streaming, and updates @@ -509,8 +509,8 @@ func DefineCustomAgent[Stream, State any]( // provides template variables for prompt rendering (e.g., personality, tone) // and can be overridden per invocation via [aix.WithInputVariables]. // -// DefinePromptAgent accepts the same options as [DefineCustomAgent]. See -// [DefineCustomAgent] for available options. +// DefineSessionFlowFromPrompt accepts the same options as [DefineSessionFlow]. See +// [DefineSessionFlow] for available options. // // Type parameters: // - State: Type for user-defined state persisted in snapshots @@ -532,7 +532,7 @@ func DefineCustomAgent[Stream, State any]( // Personality string `json:"personality"` // } // -// chatAgent := genkit.DefinePromptAgent(g, "chat", +// chatAgent := genkit.DefineSessionFlowFromPrompt(g, "chat", // ChatInput{Personality: "a helpful assistant"}, // aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), // ) @@ -552,13 +552,13 @@ func DefineCustomAgent[Stream, State any]( // fmt.Print(chunk.ModelChunk.Text()) // } // conn.Close() -func DefinePromptAgent[State, PromptIn any]( +func DefineSessionFlowFromPrompt[State, PromptIn any]( g *Genkit, promptName string, defaultInput PromptIn, - opts ...aix.AgentFlowOption[State], -) *aix.AgentFlow[any, State] { - return aix.DefinePromptAgent(g.reg, promptName, defaultInput, opts...) + opts ...aix.SessionFlowOption[State], +) *aix.SessionFlow[any, State] { + return aix.DefineSessionFlowFromPrompt(g.reg, promptName, defaultInput, opts...) } // Run executes the given function `fn` within the context of the current flow run, diff --git a/go/samples/custom-agent/main.go b/go/samples/custom-agent/main.go index a092e88ad2..b8e2274e0c 100644 --- a/go/samples/custom-agent/main.go +++ b/go/samples/custom-agent/main.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This sample demonstrates the AgentFlow API for multi-turn conversation +// This sample demonstrates the SessionFlow API for multi-turn conversation // with token-level streaming. It runs a CLI REPL where conversation history // is managed automatically by the session. package main @@ -35,9 +35,9 @@ func main() { ctx := context.Background() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - chatFlow := genkit.DefineCustomAgent(g, "chat", - func(ctx context.Context, resp aix.Responder[any], sess *aix.AgentSession[any]) (*aix.AgentFlowResult, error) { - if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentFlowInput) error { + chatFlow := genkit.DefineSessionFlow(g, "chat", + func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.SessionFlowResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { for chunk, err := range genkit.GenerateStream(ctx, g, ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ @@ -67,7 +67,7 @@ func main() { aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), ) - fmt.Println("Agent Flow Chat (type 'quit' to exit)") + fmt.Println("Session Flow Chat (type 'quit' to exit)") fmt.Println() conn, err := chatFlow.StreamBidi(ctx) diff --git a/go/samples/prompt-agent/main.go b/go/samples/prompt-agent/main.go index 0f73c52c82..e46ff2067e 100644 --- a/go/samples/prompt-agent/main.go +++ b/go/samples/prompt-agent/main.go @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This sample demonstrates DefinePromptAgent, which creates a -// multi-turn conversational agent flow backed by a .prompt file. The +// This sample demonstrates DefineSessionFlowFromPrompt, which creates a +// multi-turn conversational session flow backed by a .prompt file. The // conversation loop (render prompt, call model, stream chunks, update history) // is handled automatically. Compare with custom-agent which wires // the same loop manually. @@ -39,7 +39,7 @@ func main() { ctx := context.Background() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - chatFlow := genkit.DefinePromptAgent( + chatFlow := genkit.DefineSessionFlowFromPrompt( g, "chat", ChatPromptInput{Personality: "a sarcastic pirate"}, aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[any]) bool { @@ -47,7 +47,7 @@ func main() { }), ) - fmt.Println("Prompt Agent Chat (type 'quit' to exit)") + fmt.Println("Session Flow Chat (type 'quit' to exit)") fmt.Println() conn, err := chatFlow.StreamBidi(ctx) diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index 37464bce79..de05e2f978 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -987,24 +987,24 @@ class Messages(RootModel[list[Message]]): root: list[Message] -class AgentFlowInput(BaseModel): - """Model for agentflowinput data.""" +class SessionFlowInput(BaseModel): + """Model for sessionflowinput data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) messages: list[Message] | None = None tool_restarts: list[Part] | None = Field(default=None) -class AgentFlowResult(BaseModel): - """Model for agentflowresult data.""" +class SessionFlowResult(BaseModel): + """Model for sessionflowresult data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) message: Message | None = None artifacts: list[Artifact] | None = None -class AgentFlowStreamChunk(BaseModel): - """Model for agentflowstreamchunk data.""" +class SessionFlowStreamChunk(BaseModel): + """Model for sessionflowstreamchunk data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) model_chunk: ModelResponseChunk | None = Field(default=None) @@ -1101,16 +1101,16 @@ class Request(RootModel[GenerateRequest]): root: GenerateRequest -class AgentFlowInit(BaseModel): - """Model for agentflowinit data.""" +class SessionFlowInit(BaseModel): + """Model for sessionflowinit data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) snapshot_id: str | None = Field(default=None) state: SessionState | None = None -class AgentFlowOutput(BaseModel): - """Model for agentflowoutput data.""" +class SessionFlowOutput(BaseModel): + """Model for sessionflowoutput data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) snapshot_id: str | None = Field(default=None) From 3e707fcaaf72c7813299a25b2d2af25ee4eaec18 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Mar 2026 10:16:17 -0800 Subject: [PATCH 045/141] Update genkit-schema.json --- genkit-tools/genkit-schema.json | 44 ++++++++++++++++----------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 5dbbab6de0..c5d0f7377c 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -1,6 +1,28 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "$defs": { + "Artifact": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "parts": { + "type": "array", + "items": { + "$ref": "#/$defs/Part" + } + }, + "metadata": { + "type": "object", + "additionalProperties": {} + } + }, + "required": [ + "parts" + ], + "additionalProperties": false + }, "SessionFlowInit": { "type": "object", "properties": { @@ -86,28 +108,6 @@ }, "additionalProperties": false }, - "Artifact": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "parts": { - "type": "array", - "items": { - "$ref": "#/$defs/Part" - } - }, - "metadata": { - "type": "object", - "additionalProperties": {} - } - }, - "required": [ - "parts" - ], - "additionalProperties": false - }, "SessionState": { "type": "object", "properties": { From 5ff552025e9e96dccae67cc93a08c37fb0b4b274 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Mar 2026 12:21:38 -0800 Subject: [PATCH 046/141] renamed files --- go/ai/exp/{agent.go => session_flow.go} | 0 go/ai/exp/{agent_test.go => session_flow_test.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename go/ai/exp/{agent.go => session_flow.go} (100%) rename go/ai/exp/{agent_test.go => session_flow_test.go} (100%) diff --git a/go/ai/exp/agent.go b/go/ai/exp/session_flow.go similarity index 100% rename from go/ai/exp/agent.go rename to go/ai/exp/session_flow.go diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/session_flow_test.go similarity index 100% rename from go/ai/exp/agent_test.go rename to go/ai/exp/session_flow_test.go From 66e5969ca35b6a440a2e23d496e25270065bf035 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Mar 2026 12:27:42 -0800 Subject: [PATCH 047/141] Update agent.ts --- genkit-tools/common/src/types/agent.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index e2f98aeb16..8046fa468c 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -115,4 +115,6 @@ export const SessionFlowStreamChunkSchema = z.object({ /** Signals that the session flow has finished processing the current input. */ endTurn: z.boolean().optional(), }); -export type SessionFlowStreamChunk = z.infer; +export type SessionFlowStreamChunk = z.infer< + typeof SessionFlowStreamChunkSchema +>; From 9babe2046c75322e181e7193f27c6efbcb1877d9 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Mar 2026 14:16:32 -0800 Subject: [PATCH 048/141] Update schemas.config --- go/core/schemas.config | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/go/core/schemas.config b/go/core/schemas.config index 9cafa3b092..e6be6e0f99 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1110,18 +1110,18 @@ GenkitErrorData omit GenkitErrorDataGenkitErrorDetails omit # ============================================================================ -# AGENT FLOW TYPES (generated into ai/x package) +# AGENT FLOW TYPES (generated into ai/exp package) # ============================================================================ -# Package configuration: ai/x directory uses "aix" as Go package name. -ai/x name aix -aix import github.com/firebase/genkit/go/ai +# Package configuration: ai/exp directory uses "exp" as Go package name. +ai/exp name exp +exp import github.com/firebase/genkit/go/ai # ---------------------------------------------------------------------------- # Artifact # ---------------------------------------------------------------------------- -Artifact pkg ai/x +Artifact pkg ai/exp Artifact doc Artifact represents a named collection of parts produced during a session. @@ -1147,7 +1147,7 @@ Metadata contains additional artifact-specific data. # SessionFlowInput # ---------------------------------------------------------------------------- -SessionFlowInput pkg ai/x +SessionFlowInput pkg ai/exp SessionFlowInput doc SessionFlowInput is the input sent to an session flow during a conversation turn. @@ -1170,7 +1170,7 @@ instead of treating Messages as tool responses. # SessionFlowInit # ---------------------------------------------------------------------------- -SessionFlowInit pkg ai/x +SessionFlowInit pkg ai/exp SessionFlowInit typeparams [State any] SessionFlowInit doc @@ -1193,7 +1193,7 @@ Mutually exclusive with SnapshotID. # SessionFlowResult # ---------------------------------------------------------------------------- -SessionFlowResult pkg ai/x +SessionFlowResult pkg ai/exp SessionFlowResult doc SessionFlowResult is the return value from an SessionFlowFunc. @@ -1213,7 +1213,7 @@ Artifacts contains artifacts produced during the session. # SessionFlowOutput # ---------------------------------------------------------------------------- -SessionFlowOutput pkg ai/x +SessionFlowOutput pkg ai/exp SessionFlowOutput typeparams [State any] SessionFlowOutput doc @@ -1245,7 +1245,7 @@ Artifacts contains artifacts produced during the session. # SessionFlowStreamChunk # ---------------------------------------------------------------------------- -SessionFlowStreamChunk pkg ai/x +SessionFlowStreamChunk pkg ai/exp SessionFlowStreamChunk typeparams [Stream any] SessionFlowStreamChunk doc @@ -1281,7 +1281,7 @@ When true, the client should stop iterating and may send the next input. # SessionState # ---------------------------------------------------------------------------- -SessionState pkg ai/x +SessionState pkg ai/exp SessionState typeparams [State any] SessionState doc @@ -1313,7 +1313,7 @@ InputVariables is the input used for session flows that require input variables # SnapshotEvent # ---------------------------------------------------------------------------- -SnapshotEvent pkg ai/x +SnapshotEvent pkg ai/exp SnapshotEvent doc SnapshotEvent identifies what triggered a snapshot. From 06671f26e12cb42851ce9d01a7d795a092b45557 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 13 Mar 2026 10:15:00 -0700 Subject: [PATCH 049/141] refactored type params --- go/core/action.go | 203 +++++++++++++++++++++--------------------- go/core/api/action.go | 4 +- go/core/flow.go | 46 +++++----- go/core/flow_test.go | 8 +- go/genkit/genkit.go | 11 +-- 5 files changed, 139 insertions(+), 133 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index 0bd867b4f1..4f8de572ca 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -34,31 +34,34 @@ import ( // Func is an alias for non-streaming functions with input of type In and output of type Out. type Func[In, Out any] = func(context.Context, In) (Out, error) -// StreamingFunc is an alias for streaming functions with input of type In, output of type Out, and streaming chunk of type Stream. -type StreamingFunc[In, Out, Stream any] = func(context.Context, In, StreamCallback[Stream]) (Out, error) +// StreamingFunc is an alias for streaming functions with input of type In, output of type Out, and outgoing stream chunk of type StreamOut. +type StreamingFunc[In, Out, StreamOut any] = func(context.Context, In, StreamCallback[StreamOut]) (Out, error) -// StreamCallback is a function that is called during streaming to return the next chunk of the stream. -type StreamCallback[Stream any] = func(context.Context, Stream) error +// StreamCallback is a function that is called during streaming to return the next chunk of the outgoing stream. +type StreamCallback[StreamOut any] = func(context.Context, StreamOut) error // BidiFunc is the function signature for bidirectional streaming actions. -// It receives initialization data, reads inputs from inCh, and writes -// streamed outputs to outCh. It returns a final output when complete. -type BidiFunc[In, Out, Stream, Init any] = func(ctx context.Context, init Init, inCh <-chan In, outCh chan<- Stream) (Out, error) +// It receives an initial input, reads incoming stream messages from inCh, +// and writes outgoing stream messages to outCh. It returns a final output when complete. +type BidiFunc[In, Out, StreamOut, StreamIn any] = func(ctx context.Context, in In, inCh <-chan StreamIn, outCh chan<- StreamOut) (Out, error) // An Action is a named, observable operation that underlies all Genkit primitives. -// It consists of a function that takes an input of type I and returns an output -// of type O, optionally streaming values of type S incrementally by invoking a callback. +// It consists of a function that takes an input of type In and returns an output +// of type Out, optionally streaming values of type StreamOut incrementally by +// invoking a callback. For bidirectional actions, StreamIn is the type of +// incoming stream messages. +// // It optionally has other metadata, like a description and JSON Schemas for its input and // output which it validates against. // // Each time an Action is run, it results in a new trace span. // // For internal use only. -type Action[In, Out, Stream, Init any] struct { - fn StreamingFunc[In, Out, Stream] // Function that is called during runtime. May not actually support streaming. - bidiFn BidiFunc[In, Out, Stream, Init] // Non-nil for bidi actions only. - desc *api.ActionDesc // Descriptor of the action. - registry api.Registry // Registry for schema resolution. Set when registered. +type Action[In, Out, StreamOut, StreamIn any] struct { + fn StreamingFunc[In, Out, StreamOut] // Function that is called during runtime. May not actually support streaming. + bidiFn BidiFunc[In, Out, StreamOut, StreamIn] // Non-nil for bidi actions only. + desc *api.ActionDesc // Descriptor of the action. + registry api.Registry // Registry for schema resolution. Set when registered. } type noStream = func(context.Context, struct{}) error @@ -80,32 +83,32 @@ func NewAction[In, Out any]( // NewStreamingAction creates a new streaming [Action] without registering it. // If inputSchema is nil, it is inferred from the function's input api. -func NewStreamingAction[In, Out, Stream any]( +func NewStreamingAction[In, Out, StreamOut any]( name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, - fn StreamingFunc[In, Out, Stream], -) *Action[In, Out, Stream, struct{}] { - return newAction[In, Out, Stream, struct{}](name, atype, metadata, inputSchema, fn) + fn StreamingFunc[In, Out, StreamOut], +) *Action[In, Out, StreamOut, struct{}] { + return newAction[In, Out, StreamOut, struct{}](name, atype, metadata, inputSchema, fn) } // ActionOptions configures a bidi action. Nil schema fields are inferred from type parameters. type ActionOptions struct { - Metadata map[string]any // Arbitrary key-value data attached to the action descriptor. - InputSchema map[string]any // JSON schema for the action's input. Inferred from In if nil. - OutputSchema map[string]any // JSON schema for the action's output. Inferred from Out if nil. - StreamSchema map[string]any // JSON schema for streamed chunks. Inferred from Stream if nil. Not used for non-streaming actions. - InitSchema map[string]any // JSON schema for bidi initialization data. Inferred from Init if nil. Not used for non-bidi actions. + Metadata map[string]any // Arbitrary key-value data attached to the action descriptor. + InputSchema map[string]any // JSON schema for the action's input. Inferred from In if nil. + OutputSchema map[string]any // JSON schema for the action's output. Inferred from Out if nil. + StreamOutSchema map[string]any // JSON schema for outgoing streamed chunks. Inferred from StreamOut if nil. Not used for non-streaming actions. + StreamInSchema map[string]any // JSON schema for incoming stream messages. Inferred from StreamIn if nil. Not used for non-bidi actions. } // NewBidiAction creates a new bidirectional streaming [Action] without registering it. -func NewBidiAction[In, Out, Stream, Init any]( +func NewBidiAction[In, Out, StreamOut, StreamIn any]( name string, atype api.ActionType, opts *ActionOptions, - fn BidiFunc[In, Out, Stream, Init], -) *Action[In, Out, Stream, Init] { + fn BidiFunc[In, Out, StreamOut, StreamIn], +) *Action[In, Out, StreamOut, StreamIn] { if opts == nil { opts = &ActionOptions{} } @@ -116,22 +119,22 @@ func NewBidiAction[In, Out, Stream, Init any]( } metadata["bidi"] = true - a := newAction[In, Out, Stream, Init](name, atype, metadata, opts.InputSchema, wrapBidiAsStreaming(fn)) + a := newAction[In, Out, StreamOut, StreamIn](name, atype, metadata, opts.InputSchema, wrapBidiAsStreaming(fn)) a.bidiFn = fn if opts.OutputSchema != nil { a.desc.OutputSchema = opts.OutputSchema } - if opts.StreamSchema != nil { - a.desc.StreamSchema = opts.StreamSchema + if opts.StreamOutSchema != nil { + a.desc.StreamOutSchema = opts.StreamOutSchema } - if opts.InitSchema != nil { - a.desc.InitSchema = opts.InitSchema + if opts.StreamInSchema != nil { + a.desc.StreamInSchema = opts.StreamInSchema } else { - var init Init - if reflect.ValueOf(init).Kind() != reflect.Invalid { - a.desc.InitSchema = InferSchemaMap(init) + var inStream StreamIn + if reflect.ValueOf(inStream).Kind() != reflect.Invalid { + a.desc.StreamInSchema = InferSchemaMap(inStream) } } @@ -139,13 +142,13 @@ func NewBidiAction[In, Out, Stream, Init any]( } // DefineBidiAction creates and registers a bidirectional streaming [Action]. -func DefineBidiAction[In, Out, Stream, Init any]( +func DefineBidiAction[In, Out, StreamOut, StreamIn any]( r api.Registry, name string, atype api.ActionType, opts *ActionOptions, - fn BidiFunc[In, Out, Stream, Init], -) *Action[In, Out, Stream, Init] { + fn BidiFunc[In, Out, StreamOut, StreamIn], +) *Action[In, Out, StreamOut, StreamIn] { a := NewBidiAction(name, atype, opts, fn) a.Register(r) return a @@ -169,27 +172,27 @@ func DefineAction[In, Out any]( // DefineStreamingAction creates a new streaming action and registers it. // If inputSchema is nil, it is inferred from the function's input api. -func DefineStreamingAction[In, Out, Stream any]( +func DefineStreamingAction[In, Out, StreamOut any]( r api.Registry, name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, - fn StreamingFunc[In, Out, Stream], -) *Action[In, Out, Stream, struct{}] { - return defineAction[In, Out, Stream, struct{}](r, name, atype, metadata, inputSchema, fn) + fn StreamingFunc[In, Out, StreamOut], +) *Action[In, Out, StreamOut, struct{}] { + return defineAction[In, Out, StreamOut, struct{}](r, name, atype, metadata, inputSchema, fn) } // defineAction creates an action and registers it with the given Registry. -func defineAction[In, Out, Stream, Init any]( +func defineAction[In, Out, StreamOut, StreamIn any]( r api.Registry, name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, - fn StreamingFunc[In, Out, Stream], -) *Action[In, Out, Stream, Init] { - a := newAction[In, Out, Stream, Init](name, atype, metadata, inputSchema, fn) + fn StreamingFunc[In, Out, StreamOut], +) *Action[In, Out, StreamOut, StreamIn] { + a := newAction[In, Out, StreamOut, StreamIn](name, atype, metadata, inputSchema, fn) a.Register(r) return a } @@ -197,13 +200,13 @@ func defineAction[In, Out, Stream, Init any]( // newAction creates a new Action with the given name and arguments. // If registry is nil, tracing state is left nil to be set later. // If inputSchema is nil, it is inferred from In. -func newAction[In, Out, Stream, Init any]( +func newAction[In, Out, StreamOut, StreamIn any]( name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, - fn StreamingFunc[In, Out, Stream], -) *Action[In, Out, Stream, Init] { + fn StreamingFunc[In, Out, StreamOut], +) *Action[In, Out, StreamOut, StreamIn] { if inputSchema == nil { var i In if reflect.ValueOf(i).Kind() != reflect.Invalid { @@ -217,10 +220,10 @@ func newAction[In, Out, Stream, Init any]( outputSchema = InferSchemaMap(o) } - var s Stream - var streamSchema map[string]any + var s StreamOut + var outStreamSchema map[string]any if reflect.ValueOf(s).Kind() != reflect.Invalid { - streamSchema = InferSchemaMap(s) + outStreamSchema = InferSchemaMap(s) } var description string @@ -228,28 +231,28 @@ func newAction[In, Out, Stream, Init any]( description = desc } - return &Action[In, Out, Stream, Init]{ - fn: func(ctx context.Context, input In, cb StreamCallback[Stream]) (Out, error) { + return &Action[In, Out, StreamOut, StreamIn]{ + fn: func(ctx context.Context, input In, cb StreamCallback[StreamOut]) (Out, error) { return fn(ctx, input, cb) }, desc: &api.ActionDesc{ - Type: atype, - Key: api.KeyFromName(atype, name), - Name: name, - Description: description, - InputSchema: inputSchema, - OutputSchema: outputSchema, - StreamSchema: streamSchema, - Metadata: metadata, + Type: atype, + Key: api.KeyFromName(atype, name), + Name: name, + Description: description, + InputSchema: inputSchema, + OutputSchema: outputSchema, + StreamOutSchema: outStreamSchema, + Metadata: metadata, }, } } // Name returns the Action's Name. -func (a *Action[In, Out, Stream, Init]) Name() string { return a.desc.Name } +func (a *Action[In, Out, StreamOut, StreamIn]) Name() string { return a.desc.Name } // Run executes the Action's function in a new trace span. -func (a *Action[In, Out, Stream, Init]) Run(ctx context.Context, input In, cb StreamCallback[Stream]) (output Out, err error) { +func (a *Action[In, Out, StreamOut, StreamIn]) Run(ctx context.Context, input In, cb StreamCallback[StreamOut]) (output Out, err error) { r, err := a.runWithTelemetry(ctx, input, cb) if err != nil { return base.Zero[Out](), err @@ -258,7 +261,7 @@ func (a *Action[In, Out, Stream, Init]) Run(ctx context.Context, input In, cb St } // runWithTelemetry executes the Action's function in a new trace span and returns telemetry info. -func (a *Action[In, Out, Stream, Init]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream]) (output api.ActionRunResult[Out], err error) { +func (a *Action[In, Out, StreamOut, StreamIn]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[StreamOut]) (output api.ActionRunResult[Out], err error) { logger.FromContext(ctx).Debug("Action.Run", "name", a.Name()) defer func() { logger.FromContext(ctx).Debug("Action.Run", @@ -334,7 +337,7 @@ func (a *Action[In, Out, Stream, Init]) runWithTelemetry(ctx context.Context, in } // RunJSON runs the action with a JSON input, and returns a JSON result. -func (a *Action[In, Out, Stream, Init]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { +func (a *Action[In, Out, StreamOut, StreamIn]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { r, err := a.RunJSONWithTelemetry(ctx, input, cb) if err != nil { return nil, err @@ -343,15 +346,15 @@ func (a *Action[In, Out, Stream, Init]) RunJSON(ctx context.Context, input json. } // RunJSONWithTelemetry runs the action with a JSON input, and returns a JSON result along with telemetry info. -func (a *Action[In, Out, Stream, Init]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { +func (a *Action[In, Out, StreamOut, StreamIn]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { i, err := base.UnmarshalAndNormalize[In](input, a.desc.InputSchema) if err != nil { return nil, NewError(INVALID_ARGUMENT, err.Error()) } - var scb StreamCallback[Stream] + var scb StreamCallback[StreamOut] if cb != nil { - scb = func(ctx context.Context, s Stream) error { + scb = func(ctx context.Context, s StreamOut) error { bytes, err := json.Marshal(s) if err != nil { return err @@ -381,12 +384,12 @@ func (a *Action[In, Out, Stream, Init]) RunJSONWithTelemetry(ctx context.Context } // Desc returns a descriptor of the action with resolved schema references. -func (a *Action[In, Out, Stream, Init]) Desc() api.ActionDesc { +func (a *Action[In, Out, StreamOut, StreamIn]) Desc() api.ActionDesc { return *a.desc } // Register registers the action with the given registry. -func (a *Action[In, Out, Stream, Init]) Register(r api.Registry) { +func (a *Action[In, Out, StreamOut, StreamIn]) Register(r api.Registry) { a.registry = r r.RegisterAction(a.desc.Key, a) } @@ -394,15 +397,15 @@ func (a *Action[In, Out, Stream, Init]) Register(r api.Registry) { // StreamBidi starts a bidirectional streaming connection. // Returns an error if the action is not a bidi action. // A trace span is created that remains open for the lifetime of the connection. -func (a *Action[In, Out, Stream, Init]) StreamBidi(ctx context.Context, init Init) (*BidiConnection[In, Out, Stream], error) { +func (a *Action[In, Out, StreamOut, StreamIn]) StreamBidi(ctx context.Context, in In) (*BidiConnection[StreamIn, Out, StreamOut], error) { if a.bidiFn == nil { return nil, NewError(FAILED_PRECONDITION, "StreamBidi called on non-bidi action %q", a.desc.Name) } ctx, cancel := context.WithCancel(ctx) - conn := &BidiConnection[In, Out, Stream]{ - inputCh: make(chan In, 1), - streamCh: make(chan Stream, 1), + conn := &BidiConnection[StreamIn, Out, StreamOut]{ + inputCh: make(chan StreamIn, 1), + streamCh: make(chan StreamOut, 1), doneCh: make(chan struct{}), ctx: ctx, cancel: cancel, @@ -421,10 +424,10 @@ func (a *Action[In, Out, Stream, Init]) StreamBidi(ctx context.Context, init Ini go func() { defer close(conn.doneCh) defer close(conn.streamCh) - output, err := tracing.RunInNewSpan(conn.ctx, spanMetadata, init, - func(ctx context.Context, init Init) (Out, error) { + output, err := tracing.RunInNewSpan(conn.ctx, spanMetadata, in, + func(ctx context.Context, in In) (Out, error) { start := time.Now() - output, err := a.bidiFn(ctx, init, conn.inputCh, conn.streamCh) + output, err := a.bidiFn(ctx, in, conn.inputCh, conn.streamCh) latency := time.Since(start) if err != nil { metrics.WriteActionFailure(ctx, a.desc.Name, latency, err) @@ -446,14 +449,14 @@ func (a *Action[In, Out, Stream, Init]) StreamBidi(ctx context.Context, init Ini // ResolveActionFor returns the action for the given key in the global registry, // or nil if there is none. // It panics if the action is of the wrong api. -func ResolveActionFor[In, Out, Stream, Init any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, Stream, Init] { +func ResolveActionFor[In, Out, StreamOut, StreamIn any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, StreamOut, StreamIn] { provider, id := api.ParseName(name) key := api.NewKey(atype, provider, id) a := r.ResolveAction(key) if a == nil { return nil } - return a.(*Action[In, Out, Stream, Init]) + return a.(*Action[In, Out, StreamOut, StreamIn]) } // LookupActionFor returns the action for the given key in the global registry, @@ -461,22 +464,24 @@ func ResolveActionFor[In, Out, Stream, Init any](r api.Registry, atype api.Actio // It panics if the action is of the wrong api. // // Deprecated: Use ResolveActionFor. -func LookupActionFor[In, Out, Stream, Init any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, Stream, Init] { +func LookupActionFor[In, Out, StreamOut, StreamIn any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, StreamOut, StreamIn] { provider, id := api.ParseName(name) key := api.NewKey(atype, provider, id) a := r.LookupAction(key) if a == nil { return nil } - return a.(*Action[In, Out, Stream, Init]) + return a.(*Action[In, Out, StreamOut, StreamIn]) } // wrapBidiAsStreaming wraps a BidiFunc into a StreamingFunc for use with Run/RunJSON. -// The input is sent as a single message, and stream chunks are forwarded to the callback. -func wrapBidiAsStreaming[In, Out, Stream, Init any](fn BidiFunc[In, Out, Stream, Init]) StreamingFunc[In, Out, Stream] { - return func(ctx context.Context, input In, cb StreamCallback[Stream]) (Out, error) { - inCh := make(chan In, 1) - outCh := make(chan Stream, 1) +// The input is passed as the initial input to the bidi func, and the input stream +// channel is closed immediately (no streaming inputs). Outgoing stream chunks are +// forwarded to the callback. +func wrapBidiAsStreaming[In, Out, StreamOut, StreamIn any](fn BidiFunc[In, Out, StreamOut, StreamIn]) StreamingFunc[In, Out, StreamOut] { + return func(ctx context.Context, input In, cb StreamCallback[StreamOut]) (Out, error) { + inCh := make(chan StreamIn, 1) + outCh := make(chan StreamOut, 1) doneCh := make(chan struct{}) var output Out @@ -485,12 +490,10 @@ func wrapBidiAsStreaming[In, Out, Stream, Init any](fn BidiFunc[In, Out, Stream, go func() { defer close(doneCh) defer close(outCh) - var init Init - output, fnErr = fn(ctx, init, inCh, outCh) + output, fnErr = fn(ctx, input, inCh, outCh) }() - // Send the single input and close. - inCh <- input + // No streaming inputs when used as a non-bidi streaming action. close(inCh) // Forward streamed chunks to the callback. @@ -512,9 +515,9 @@ func wrapBidiAsStreaming[In, Out, Stream, Init any](fn BidiFunc[In, Out, Stream, } // BidiConnection represents an active bidirectional streaming session. -type BidiConnection[In, Out, Stream any] struct { - inputCh chan In - streamCh chan Stream +type BidiConnection[StreamIn, Out, StreamOut any] struct { + inputCh chan StreamIn + streamCh chan StreamOut doneCh chan struct{} output Out err error @@ -526,7 +529,7 @@ type BidiConnection[In, Out, Stream any] struct { // Send sends an input message to the bidi action. // Returns an error if the connection is closed or the context is cancelled. -func (c *BidiConnection[In, Out, Stream]) Send(input In) (err error) { +func (c *BidiConnection[StreamIn, Out, StreamOut]) Send(input StreamIn) (err error) { defer func() { if r := recover(); r != nil { err = NewError(FAILED_PRECONDITION, "connection is closed") @@ -544,7 +547,7 @@ func (c *BidiConnection[In, Out, Stream]) Send(input In) (err error) { } // Close signals that no more inputs will be sent. -func (c *BidiConnection[In, Out, Stream]) Close() error { +func (c *BidiConnection[StreamIn, Out, StreamOut]) Close() error { c.mu.Lock() defer c.mu.Unlock() if c.closed { @@ -557,8 +560,8 @@ func (c *BidiConnection[In, Out, Stream]) Close() error { // Receive returns an iterator for receiving streamed response chunks. // The iterator completes when the action finishes. -func (c *BidiConnection[In, Out, Stream]) Receive() iter.Seq2[Stream, error] { - return func(yield func(Stream, error) bool) { +func (c *BidiConnection[StreamIn, Out, StreamOut]) Receive() iter.Seq2[StreamOut, error] { + return func(yield func(StreamOut, error) bool) { for { select { case chunk, ok := <-c.streamCh: @@ -570,7 +573,7 @@ func (c *BidiConnection[In, Out, Stream]) Receive() iter.Seq2[Stream, error] { return } case <-c.ctx.Done(): - var zero Stream + var zero StreamOut yield(zero, c.ctx.Err()) return } @@ -580,7 +583,7 @@ func (c *BidiConnection[In, Out, Stream]) Receive() iter.Seq2[Stream, error] { // Output returns the final output after the action completes. // Blocks until done or context cancelled. -func (c *BidiConnection[In, Out, Stream]) Output() (Out, error) { +func (c *BidiConnection[StreamIn, Out, StreamOut]) Output() (Out, error) { select { case <-c.doneCh: c.mu.Lock() @@ -593,6 +596,6 @@ func (c *BidiConnection[In, Out, Stream]) Output() (Out, error) { } // Done returns a channel that is closed when the connection completes. -func (c *BidiConnection[In, Out, Stream]) Done() <-chan struct{} { +func (c *BidiConnection[StreamIn, Out, StreamOut]) Done() <-chan struct{} { return c.doneCh } diff --git a/go/core/api/action.go b/go/core/api/action.go index a38958af51..336e6145ce 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -74,7 +74,7 @@ type ActionDesc struct { Description string `json:"description"` // Description of the action. InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. - StreamSchema map[string]any `json:"streamSchema,omitempty"` // JSON schema to validate against the action's streamed chunks. - InitSchema map[string]any `json:"initSchema,omitempty"` // JSON schema to validate against the action's initialization data. + StreamOutSchema map[string]any `json:"streamOutSchema,omitempty"` // JSON schema to validate against the action's outgoing streamed chunks. + StreamInSchema map[string]any `json:"streamInSchema,omitempty"` // JSON schema to validate against the action's incoming stream messages (bidi only). Metadata map[string]any `json:"metadata"` // Metadata for the action. } diff --git a/go/core/flow.go b/go/core/flow.go index 3ad8bba60d..e10fb34344 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -26,17 +26,17 @@ import ( "github.com/firebase/genkit/go/internal/base" ) -// A Flow is a user-defined Action. A Flow[In, Out, Stream, Init] represents a function from In to Out. -// The Stream parameter is for flows that support streaming: providing their results incrementally. The Init parameter is for bidi flows. -type Flow[In, Out, Stream, Init any] struct { - *Action[In, Out, Stream, Init] +// A Flow is a user-defined Action. A Flow[In, Out, StreamOut, StreamIn] represents a function from In to Out. +// The StreamOut parameter is for flows that support streaming: providing their results incrementally. The StreamIn parameter is for bidi flows. +type Flow[In, Out, StreamOut, StreamIn any] struct { + *Action[In, Out, StreamOut, StreamIn] } // StreamingFlowValue is either a streamed value or a final output of a flow. -type StreamingFlowValue[Out, Stream any] struct { +type StreamingFlowValue[Out, StreamOut any] struct { Done bool - Output Out // valid if Done is true - Stream Stream // valid if Done is false + Output Out // valid if Done is true + Stream StreamOut // valid if Done is false } // flowContextKey is a context key that indicates whether the current context is a flow context. @@ -59,14 +59,14 @@ func NewFlow[In, Out any](name string, fn Func[In, Out]) *Flow[In, Out, struct{} } // NewStreamingFlow creates a streaming Flow that runs fn without registering it. -func NewStreamingFlow[In, Out, Stream any](name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream, struct{}] { - return &Flow[In, Out, Stream, struct{}]{NewStreamingAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) { +func NewStreamingFlow[In, Out, StreamOut any](name string, fn StreamingFunc[In, Out, StreamOut]) *Flow[In, Out, StreamOut, struct{}] { + return &Flow[In, Out, StreamOut, struct{}]{NewStreamingAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, StreamOut) error) (Out, error) { fc := &flowContext{ flowName: name, } ctx = flowContextKey.NewContext(ctx, fc) if cb == nil { - cb = func(context.Context, Stream) error { return nil } + cb = func(context.Context, StreamOut) error { return nil } } return fn(ctx, input, cb) })} @@ -74,12 +74,12 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn StreamingFunc[In, Out // NewBidiFlow creates a bidirectional streaming Flow without registering it. // Flow context is injected so that [Run] works inside the bidi function. -func NewBidiFlow[In, Out, Stream, Init any](name string, fn BidiFunc[In, Out, Stream, Init]) *Flow[In, Out, Stream, Init] { - wrapped := func(ctx context.Context, init Init, inCh <-chan In, outCh chan<- Stream) (Out, error) { +func NewBidiFlow[In, Out, StreamOut, StreamIn any](name string, fn BidiFunc[In, Out, StreamOut, StreamIn]) *Flow[In, Out, StreamOut, StreamIn] { + wrapped := func(ctx context.Context, in In, inCh <-chan StreamIn, outCh chan<- StreamOut) (Out, error) { ctx = flowContextKey.NewContext(ctx, &flowContext{flowName: name}) - return fn(ctx, init, inCh, outCh) + return fn(ctx, in, inCh, outCh) } - return &Flow[In, Out, Stream, Init]{NewBidiAction(name, api.ActionTypeFlow, nil, wrapped)} + return &Flow[In, Out, StreamOut, StreamIn]{NewBidiAction(name, api.ActionTypeFlow, nil, wrapped)} } // DefineFlow creates a Flow that runs fn, and registers it as an action. fn takes an input of type In and returns an output of type Out. @@ -92,13 +92,13 @@ func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flo // DefineStreamingFlow creates a streaming Flow that runs fn, and registers it as an action. // // fn takes an input of type In and returns an output of type Out, optionally -// streaming values of type Stream incrementally by invoking a callback. +// streaming values of type StreamOut incrementally by invoking a callback. // // If the function supports streaming and the callback is non-nil, it should // stream the results by invoking the callback periodically, ultimately returning // with a final return value that includes all the streamed data. // Otherwise, it should ignore the callback and just return a result. -func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream, struct{}] { +func DefineStreamingFlow[In, Out, StreamOut any](r api.Registry, name string, fn StreamingFunc[In, Out, StreamOut]) *Flow[In, Out, StreamOut, struct{}] { f := NewStreamingFlow(name, fn) f.Register(r) return f @@ -106,7 +106,7 @@ func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn St // DefineBidiFlow creates a bidirectional streaming Flow that runs fn, and registers it as an action. // Flow context is injected so that [Run] works inside the bidi function. -func DefineBidiFlow[In, Out, Stream, Init any](r api.Registry, name string, fn BidiFunc[In, Out, Stream, Init]) *Flow[In, Out, Stream, Init] { +func DefineBidiFlow[In, Out, StreamOut, StreamIn any](r api.Registry, name string, fn BidiFunc[In, Out, StreamOut, StreamIn]) *Flow[In, Out, StreamOut, StreamIn] { f := NewBidiFlow(name, fn) f.Register(r) return f @@ -140,7 +140,7 @@ func Run[Out any](ctx context.Context, name string, fn func() (Out, error)) (Out } // Run runs the flow in the context of another flow. -func (f *Flow[In, Out, Stream, Init]) Run(ctx context.Context, input In) (Out, error) { +func (f *Flow[In, Out, StreamOut, StreamIn]) Run(ctx context.Context, input In) (Out, error) { return f.Action.Run(ctx, input, nil) } @@ -156,13 +156,13 @@ func (f *Flow[In, Out, Stream, Init]) Run(ctx context.Context, input In) (Out, e // again. // // Otherwise the Stream field of the passed [StreamingFlowValue] holds a streamed result. -func (f *Flow[In, Out, Stream, Init]) Stream(ctx context.Context, input In) func(func(*StreamingFlowValue[Out, Stream], error) bool) { - return func(yield func(*StreamingFlowValue[Out, Stream], error) bool) { - cb := func(ctx context.Context, s Stream) error { +func (f *Flow[In, Out, StreamOut, StreamIn]) Stream(ctx context.Context, input In) func(func(*StreamingFlowValue[Out, StreamOut], error) bool) { + return func(yield func(*StreamingFlowValue[Out, StreamOut], error) bool) { + cb := func(ctx context.Context, s StreamOut) error { if ctx.Err() != nil { return ctx.Err() } - if !yield(&StreamingFlowValue[Out, Stream]{Stream: s}, nil) { + if !yield(&StreamingFlowValue[Out, StreamOut]{Stream: s}, nil) { return errStop } return nil @@ -175,7 +175,7 @@ func (f *Flow[In, Out, Stream, Init]) Stream(ctx context.Context, input In) func if err != nil { yield(nil, err) } else { - yield(&StreamingFlowValue[Out, Stream]{Done: true, Output: output}, nil) + yield(&StreamingFlowValue[Out, StreamOut]{Done: true, Output: output}, nil) } } } diff --git a/go/core/flow_test.go b/go/core/flow_test.go index 7da8d31778..5810ac4625 100644 --- a/go/core/flow_test.go +++ b/go/core/flow_test.go @@ -264,6 +264,7 @@ func TestFlowNameFromContextOutsideFlow(t *testing.T) { func TestBidiActionEcho(t *testing.T) { ctx := context.Background() + // In=struct{} (no initial data), Out=string, OutStream=string, InStream=string action := NewBidiAction( "echo", api.ActionTypeCustom, nil, func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { @@ -315,18 +316,19 @@ func TestBidiActionEcho(t *testing.T) { } } -func TestBidiActionWithInit(t *testing.T) { +func TestBidiActionWithConfig(t *testing.T) { ctx := context.Background() type Config struct { Prefix string } + // In=Config (initial config), Out=string, OutStream=string, InStream=string action := NewBidiAction( "prefixed", api.ActionTypeCustom, nil, - func(ctx context.Context, init Config, inCh <-chan string, outCh chan<- string) (string, error) { + func(ctx context.Context, cfg Config, inCh <-chan string, outCh chan<- string) (string, error) { for input := range inCh { - outCh <- fmt.Sprintf("%s: %s", init.Prefix, input) + outCh <- fmt.Sprintf("%s: %s", cfg.Prefix, input) } return "done", nil }, diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 8adc19f69e..ed8be96776 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -355,14 +355,15 @@ func DefineStreamingFlow[In, Out, Stream any](g *Genkit, name string, fn core.St // DefineBidiFlow defines a bidirectional streaming flow, registers it as a [core.Action] of type Flow, // and returns a [core.Flow] capable of bidirectional streaming. // -// The provided function `fn` receives initialization data of type `Init`, reads -// inputs of type `In` from an input channel, and writes streamed outputs of type -// `Stream` to an output channel. It returns a final output of type `Out` when complete. +// The provided function `fn` receives an initial input of type `In`, reads +// incoming stream messages of type `StreamIn` from an input channel, and writes +// outgoing stream messages of type `StreamOut` to an output channel. It returns +// a final output of type `Out` when complete. // // Example: // // chatFlow := genkit.DefineBidiFlow(g, "chat", -// func(ctx context.Context, init struct{}, inCh <-chan string, outCh chan<- string) (string, error) { +// func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { // var count int // for input := range inCh { // count++ @@ -396,7 +397,7 @@ func DefineStreamingFlow[In, Out, Stream any](g *Genkit, name string, fn core.St // // Get the final output: // output, err := conn.Output() // fmt.Println(output) // Output: "processed 2 messages" -func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.BidiFunc[In, Out, Stream, Init]) *core.Flow[In, Out, Stream, Init] { +func DefineBidiFlow[In, Out, StreamOut, StreamIn any](g *Genkit, name string, fn core.BidiFunc[In, Out, StreamOut, StreamIn]) *core.Flow[In, Out, StreamOut, StreamIn] { return core.DefineBidiFlow(g.reg, name, fn) } From 2b132574a659d8bb6f8aae20b06bd9f52788ae48 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 13 Mar 2026 10:15:00 -0700 Subject: [PATCH 050/141] refactored type params --- go/core/action.go | 203 +++++++++++++++++++++--------------------- go/core/api/action.go | 4 +- go/core/flow.go | 46 +++++----- go/core/flow_test.go | 8 +- go/genkit/genkit.go | 11 +-- 5 files changed, 139 insertions(+), 133 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index 0bd867b4f1..4f8de572ca 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -34,31 +34,34 @@ import ( // Func is an alias for non-streaming functions with input of type In and output of type Out. type Func[In, Out any] = func(context.Context, In) (Out, error) -// StreamingFunc is an alias for streaming functions with input of type In, output of type Out, and streaming chunk of type Stream. -type StreamingFunc[In, Out, Stream any] = func(context.Context, In, StreamCallback[Stream]) (Out, error) +// StreamingFunc is an alias for streaming functions with input of type In, output of type Out, and outgoing stream chunk of type StreamOut. +type StreamingFunc[In, Out, StreamOut any] = func(context.Context, In, StreamCallback[StreamOut]) (Out, error) -// StreamCallback is a function that is called during streaming to return the next chunk of the stream. -type StreamCallback[Stream any] = func(context.Context, Stream) error +// StreamCallback is a function that is called during streaming to return the next chunk of the outgoing stream. +type StreamCallback[StreamOut any] = func(context.Context, StreamOut) error // BidiFunc is the function signature for bidirectional streaming actions. -// It receives initialization data, reads inputs from inCh, and writes -// streamed outputs to outCh. It returns a final output when complete. -type BidiFunc[In, Out, Stream, Init any] = func(ctx context.Context, init Init, inCh <-chan In, outCh chan<- Stream) (Out, error) +// It receives an initial input, reads incoming stream messages from inCh, +// and writes outgoing stream messages to outCh. It returns a final output when complete. +type BidiFunc[In, Out, StreamOut, StreamIn any] = func(ctx context.Context, in In, inCh <-chan StreamIn, outCh chan<- StreamOut) (Out, error) // An Action is a named, observable operation that underlies all Genkit primitives. -// It consists of a function that takes an input of type I and returns an output -// of type O, optionally streaming values of type S incrementally by invoking a callback. +// It consists of a function that takes an input of type In and returns an output +// of type Out, optionally streaming values of type StreamOut incrementally by +// invoking a callback. For bidirectional actions, StreamIn is the type of +// incoming stream messages. +// // It optionally has other metadata, like a description and JSON Schemas for its input and // output which it validates against. // // Each time an Action is run, it results in a new trace span. // // For internal use only. -type Action[In, Out, Stream, Init any] struct { - fn StreamingFunc[In, Out, Stream] // Function that is called during runtime. May not actually support streaming. - bidiFn BidiFunc[In, Out, Stream, Init] // Non-nil for bidi actions only. - desc *api.ActionDesc // Descriptor of the action. - registry api.Registry // Registry for schema resolution. Set when registered. +type Action[In, Out, StreamOut, StreamIn any] struct { + fn StreamingFunc[In, Out, StreamOut] // Function that is called during runtime. May not actually support streaming. + bidiFn BidiFunc[In, Out, StreamOut, StreamIn] // Non-nil for bidi actions only. + desc *api.ActionDesc // Descriptor of the action. + registry api.Registry // Registry for schema resolution. Set when registered. } type noStream = func(context.Context, struct{}) error @@ -80,32 +83,32 @@ func NewAction[In, Out any]( // NewStreamingAction creates a new streaming [Action] without registering it. // If inputSchema is nil, it is inferred from the function's input api. -func NewStreamingAction[In, Out, Stream any]( +func NewStreamingAction[In, Out, StreamOut any]( name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, - fn StreamingFunc[In, Out, Stream], -) *Action[In, Out, Stream, struct{}] { - return newAction[In, Out, Stream, struct{}](name, atype, metadata, inputSchema, fn) + fn StreamingFunc[In, Out, StreamOut], +) *Action[In, Out, StreamOut, struct{}] { + return newAction[In, Out, StreamOut, struct{}](name, atype, metadata, inputSchema, fn) } // ActionOptions configures a bidi action. Nil schema fields are inferred from type parameters. type ActionOptions struct { - Metadata map[string]any // Arbitrary key-value data attached to the action descriptor. - InputSchema map[string]any // JSON schema for the action's input. Inferred from In if nil. - OutputSchema map[string]any // JSON schema for the action's output. Inferred from Out if nil. - StreamSchema map[string]any // JSON schema for streamed chunks. Inferred from Stream if nil. Not used for non-streaming actions. - InitSchema map[string]any // JSON schema for bidi initialization data. Inferred from Init if nil. Not used for non-bidi actions. + Metadata map[string]any // Arbitrary key-value data attached to the action descriptor. + InputSchema map[string]any // JSON schema for the action's input. Inferred from In if nil. + OutputSchema map[string]any // JSON schema for the action's output. Inferred from Out if nil. + StreamOutSchema map[string]any // JSON schema for outgoing streamed chunks. Inferred from StreamOut if nil. Not used for non-streaming actions. + StreamInSchema map[string]any // JSON schema for incoming stream messages. Inferred from StreamIn if nil. Not used for non-bidi actions. } // NewBidiAction creates a new bidirectional streaming [Action] without registering it. -func NewBidiAction[In, Out, Stream, Init any]( +func NewBidiAction[In, Out, StreamOut, StreamIn any]( name string, atype api.ActionType, opts *ActionOptions, - fn BidiFunc[In, Out, Stream, Init], -) *Action[In, Out, Stream, Init] { + fn BidiFunc[In, Out, StreamOut, StreamIn], +) *Action[In, Out, StreamOut, StreamIn] { if opts == nil { opts = &ActionOptions{} } @@ -116,22 +119,22 @@ func NewBidiAction[In, Out, Stream, Init any]( } metadata["bidi"] = true - a := newAction[In, Out, Stream, Init](name, atype, metadata, opts.InputSchema, wrapBidiAsStreaming(fn)) + a := newAction[In, Out, StreamOut, StreamIn](name, atype, metadata, opts.InputSchema, wrapBidiAsStreaming(fn)) a.bidiFn = fn if opts.OutputSchema != nil { a.desc.OutputSchema = opts.OutputSchema } - if opts.StreamSchema != nil { - a.desc.StreamSchema = opts.StreamSchema + if opts.StreamOutSchema != nil { + a.desc.StreamOutSchema = opts.StreamOutSchema } - if opts.InitSchema != nil { - a.desc.InitSchema = opts.InitSchema + if opts.StreamInSchema != nil { + a.desc.StreamInSchema = opts.StreamInSchema } else { - var init Init - if reflect.ValueOf(init).Kind() != reflect.Invalid { - a.desc.InitSchema = InferSchemaMap(init) + var inStream StreamIn + if reflect.ValueOf(inStream).Kind() != reflect.Invalid { + a.desc.StreamInSchema = InferSchemaMap(inStream) } } @@ -139,13 +142,13 @@ func NewBidiAction[In, Out, Stream, Init any]( } // DefineBidiAction creates and registers a bidirectional streaming [Action]. -func DefineBidiAction[In, Out, Stream, Init any]( +func DefineBidiAction[In, Out, StreamOut, StreamIn any]( r api.Registry, name string, atype api.ActionType, opts *ActionOptions, - fn BidiFunc[In, Out, Stream, Init], -) *Action[In, Out, Stream, Init] { + fn BidiFunc[In, Out, StreamOut, StreamIn], +) *Action[In, Out, StreamOut, StreamIn] { a := NewBidiAction(name, atype, opts, fn) a.Register(r) return a @@ -169,27 +172,27 @@ func DefineAction[In, Out any]( // DefineStreamingAction creates a new streaming action and registers it. // If inputSchema is nil, it is inferred from the function's input api. -func DefineStreamingAction[In, Out, Stream any]( +func DefineStreamingAction[In, Out, StreamOut any]( r api.Registry, name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, - fn StreamingFunc[In, Out, Stream], -) *Action[In, Out, Stream, struct{}] { - return defineAction[In, Out, Stream, struct{}](r, name, atype, metadata, inputSchema, fn) + fn StreamingFunc[In, Out, StreamOut], +) *Action[In, Out, StreamOut, struct{}] { + return defineAction[In, Out, StreamOut, struct{}](r, name, atype, metadata, inputSchema, fn) } // defineAction creates an action and registers it with the given Registry. -func defineAction[In, Out, Stream, Init any]( +func defineAction[In, Out, StreamOut, StreamIn any]( r api.Registry, name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, - fn StreamingFunc[In, Out, Stream], -) *Action[In, Out, Stream, Init] { - a := newAction[In, Out, Stream, Init](name, atype, metadata, inputSchema, fn) + fn StreamingFunc[In, Out, StreamOut], +) *Action[In, Out, StreamOut, StreamIn] { + a := newAction[In, Out, StreamOut, StreamIn](name, atype, metadata, inputSchema, fn) a.Register(r) return a } @@ -197,13 +200,13 @@ func defineAction[In, Out, Stream, Init any]( // newAction creates a new Action with the given name and arguments. // If registry is nil, tracing state is left nil to be set later. // If inputSchema is nil, it is inferred from In. -func newAction[In, Out, Stream, Init any]( +func newAction[In, Out, StreamOut, StreamIn any]( name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, - fn StreamingFunc[In, Out, Stream], -) *Action[In, Out, Stream, Init] { + fn StreamingFunc[In, Out, StreamOut], +) *Action[In, Out, StreamOut, StreamIn] { if inputSchema == nil { var i In if reflect.ValueOf(i).Kind() != reflect.Invalid { @@ -217,10 +220,10 @@ func newAction[In, Out, Stream, Init any]( outputSchema = InferSchemaMap(o) } - var s Stream - var streamSchema map[string]any + var s StreamOut + var outStreamSchema map[string]any if reflect.ValueOf(s).Kind() != reflect.Invalid { - streamSchema = InferSchemaMap(s) + outStreamSchema = InferSchemaMap(s) } var description string @@ -228,28 +231,28 @@ func newAction[In, Out, Stream, Init any]( description = desc } - return &Action[In, Out, Stream, Init]{ - fn: func(ctx context.Context, input In, cb StreamCallback[Stream]) (Out, error) { + return &Action[In, Out, StreamOut, StreamIn]{ + fn: func(ctx context.Context, input In, cb StreamCallback[StreamOut]) (Out, error) { return fn(ctx, input, cb) }, desc: &api.ActionDesc{ - Type: atype, - Key: api.KeyFromName(atype, name), - Name: name, - Description: description, - InputSchema: inputSchema, - OutputSchema: outputSchema, - StreamSchema: streamSchema, - Metadata: metadata, + Type: atype, + Key: api.KeyFromName(atype, name), + Name: name, + Description: description, + InputSchema: inputSchema, + OutputSchema: outputSchema, + StreamOutSchema: outStreamSchema, + Metadata: metadata, }, } } // Name returns the Action's Name. -func (a *Action[In, Out, Stream, Init]) Name() string { return a.desc.Name } +func (a *Action[In, Out, StreamOut, StreamIn]) Name() string { return a.desc.Name } // Run executes the Action's function in a new trace span. -func (a *Action[In, Out, Stream, Init]) Run(ctx context.Context, input In, cb StreamCallback[Stream]) (output Out, err error) { +func (a *Action[In, Out, StreamOut, StreamIn]) Run(ctx context.Context, input In, cb StreamCallback[StreamOut]) (output Out, err error) { r, err := a.runWithTelemetry(ctx, input, cb) if err != nil { return base.Zero[Out](), err @@ -258,7 +261,7 @@ func (a *Action[In, Out, Stream, Init]) Run(ctx context.Context, input In, cb St } // runWithTelemetry executes the Action's function in a new trace span and returns telemetry info. -func (a *Action[In, Out, Stream, Init]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream]) (output api.ActionRunResult[Out], err error) { +func (a *Action[In, Out, StreamOut, StreamIn]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[StreamOut]) (output api.ActionRunResult[Out], err error) { logger.FromContext(ctx).Debug("Action.Run", "name", a.Name()) defer func() { logger.FromContext(ctx).Debug("Action.Run", @@ -334,7 +337,7 @@ func (a *Action[In, Out, Stream, Init]) runWithTelemetry(ctx context.Context, in } // RunJSON runs the action with a JSON input, and returns a JSON result. -func (a *Action[In, Out, Stream, Init]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { +func (a *Action[In, Out, StreamOut, StreamIn]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { r, err := a.RunJSONWithTelemetry(ctx, input, cb) if err != nil { return nil, err @@ -343,15 +346,15 @@ func (a *Action[In, Out, Stream, Init]) RunJSON(ctx context.Context, input json. } // RunJSONWithTelemetry runs the action with a JSON input, and returns a JSON result along with telemetry info. -func (a *Action[In, Out, Stream, Init]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { +func (a *Action[In, Out, StreamOut, StreamIn]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { i, err := base.UnmarshalAndNormalize[In](input, a.desc.InputSchema) if err != nil { return nil, NewError(INVALID_ARGUMENT, err.Error()) } - var scb StreamCallback[Stream] + var scb StreamCallback[StreamOut] if cb != nil { - scb = func(ctx context.Context, s Stream) error { + scb = func(ctx context.Context, s StreamOut) error { bytes, err := json.Marshal(s) if err != nil { return err @@ -381,12 +384,12 @@ func (a *Action[In, Out, Stream, Init]) RunJSONWithTelemetry(ctx context.Context } // Desc returns a descriptor of the action with resolved schema references. -func (a *Action[In, Out, Stream, Init]) Desc() api.ActionDesc { +func (a *Action[In, Out, StreamOut, StreamIn]) Desc() api.ActionDesc { return *a.desc } // Register registers the action with the given registry. -func (a *Action[In, Out, Stream, Init]) Register(r api.Registry) { +func (a *Action[In, Out, StreamOut, StreamIn]) Register(r api.Registry) { a.registry = r r.RegisterAction(a.desc.Key, a) } @@ -394,15 +397,15 @@ func (a *Action[In, Out, Stream, Init]) Register(r api.Registry) { // StreamBidi starts a bidirectional streaming connection. // Returns an error if the action is not a bidi action. // A trace span is created that remains open for the lifetime of the connection. -func (a *Action[In, Out, Stream, Init]) StreamBidi(ctx context.Context, init Init) (*BidiConnection[In, Out, Stream], error) { +func (a *Action[In, Out, StreamOut, StreamIn]) StreamBidi(ctx context.Context, in In) (*BidiConnection[StreamIn, Out, StreamOut], error) { if a.bidiFn == nil { return nil, NewError(FAILED_PRECONDITION, "StreamBidi called on non-bidi action %q", a.desc.Name) } ctx, cancel := context.WithCancel(ctx) - conn := &BidiConnection[In, Out, Stream]{ - inputCh: make(chan In, 1), - streamCh: make(chan Stream, 1), + conn := &BidiConnection[StreamIn, Out, StreamOut]{ + inputCh: make(chan StreamIn, 1), + streamCh: make(chan StreamOut, 1), doneCh: make(chan struct{}), ctx: ctx, cancel: cancel, @@ -421,10 +424,10 @@ func (a *Action[In, Out, Stream, Init]) StreamBidi(ctx context.Context, init Ini go func() { defer close(conn.doneCh) defer close(conn.streamCh) - output, err := tracing.RunInNewSpan(conn.ctx, spanMetadata, init, - func(ctx context.Context, init Init) (Out, error) { + output, err := tracing.RunInNewSpan(conn.ctx, spanMetadata, in, + func(ctx context.Context, in In) (Out, error) { start := time.Now() - output, err := a.bidiFn(ctx, init, conn.inputCh, conn.streamCh) + output, err := a.bidiFn(ctx, in, conn.inputCh, conn.streamCh) latency := time.Since(start) if err != nil { metrics.WriteActionFailure(ctx, a.desc.Name, latency, err) @@ -446,14 +449,14 @@ func (a *Action[In, Out, Stream, Init]) StreamBidi(ctx context.Context, init Ini // ResolveActionFor returns the action for the given key in the global registry, // or nil if there is none. // It panics if the action is of the wrong api. -func ResolveActionFor[In, Out, Stream, Init any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, Stream, Init] { +func ResolveActionFor[In, Out, StreamOut, StreamIn any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, StreamOut, StreamIn] { provider, id := api.ParseName(name) key := api.NewKey(atype, provider, id) a := r.ResolveAction(key) if a == nil { return nil } - return a.(*Action[In, Out, Stream, Init]) + return a.(*Action[In, Out, StreamOut, StreamIn]) } // LookupActionFor returns the action for the given key in the global registry, @@ -461,22 +464,24 @@ func ResolveActionFor[In, Out, Stream, Init any](r api.Registry, atype api.Actio // It panics if the action is of the wrong api. // // Deprecated: Use ResolveActionFor. -func LookupActionFor[In, Out, Stream, Init any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, Stream, Init] { +func LookupActionFor[In, Out, StreamOut, StreamIn any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, StreamOut, StreamIn] { provider, id := api.ParseName(name) key := api.NewKey(atype, provider, id) a := r.LookupAction(key) if a == nil { return nil } - return a.(*Action[In, Out, Stream, Init]) + return a.(*Action[In, Out, StreamOut, StreamIn]) } // wrapBidiAsStreaming wraps a BidiFunc into a StreamingFunc for use with Run/RunJSON. -// The input is sent as a single message, and stream chunks are forwarded to the callback. -func wrapBidiAsStreaming[In, Out, Stream, Init any](fn BidiFunc[In, Out, Stream, Init]) StreamingFunc[In, Out, Stream] { - return func(ctx context.Context, input In, cb StreamCallback[Stream]) (Out, error) { - inCh := make(chan In, 1) - outCh := make(chan Stream, 1) +// The input is passed as the initial input to the bidi func, and the input stream +// channel is closed immediately (no streaming inputs). Outgoing stream chunks are +// forwarded to the callback. +func wrapBidiAsStreaming[In, Out, StreamOut, StreamIn any](fn BidiFunc[In, Out, StreamOut, StreamIn]) StreamingFunc[In, Out, StreamOut] { + return func(ctx context.Context, input In, cb StreamCallback[StreamOut]) (Out, error) { + inCh := make(chan StreamIn, 1) + outCh := make(chan StreamOut, 1) doneCh := make(chan struct{}) var output Out @@ -485,12 +490,10 @@ func wrapBidiAsStreaming[In, Out, Stream, Init any](fn BidiFunc[In, Out, Stream, go func() { defer close(doneCh) defer close(outCh) - var init Init - output, fnErr = fn(ctx, init, inCh, outCh) + output, fnErr = fn(ctx, input, inCh, outCh) }() - // Send the single input and close. - inCh <- input + // No streaming inputs when used as a non-bidi streaming action. close(inCh) // Forward streamed chunks to the callback. @@ -512,9 +515,9 @@ func wrapBidiAsStreaming[In, Out, Stream, Init any](fn BidiFunc[In, Out, Stream, } // BidiConnection represents an active bidirectional streaming session. -type BidiConnection[In, Out, Stream any] struct { - inputCh chan In - streamCh chan Stream +type BidiConnection[StreamIn, Out, StreamOut any] struct { + inputCh chan StreamIn + streamCh chan StreamOut doneCh chan struct{} output Out err error @@ -526,7 +529,7 @@ type BidiConnection[In, Out, Stream any] struct { // Send sends an input message to the bidi action. // Returns an error if the connection is closed or the context is cancelled. -func (c *BidiConnection[In, Out, Stream]) Send(input In) (err error) { +func (c *BidiConnection[StreamIn, Out, StreamOut]) Send(input StreamIn) (err error) { defer func() { if r := recover(); r != nil { err = NewError(FAILED_PRECONDITION, "connection is closed") @@ -544,7 +547,7 @@ func (c *BidiConnection[In, Out, Stream]) Send(input In) (err error) { } // Close signals that no more inputs will be sent. -func (c *BidiConnection[In, Out, Stream]) Close() error { +func (c *BidiConnection[StreamIn, Out, StreamOut]) Close() error { c.mu.Lock() defer c.mu.Unlock() if c.closed { @@ -557,8 +560,8 @@ func (c *BidiConnection[In, Out, Stream]) Close() error { // Receive returns an iterator for receiving streamed response chunks. // The iterator completes when the action finishes. -func (c *BidiConnection[In, Out, Stream]) Receive() iter.Seq2[Stream, error] { - return func(yield func(Stream, error) bool) { +func (c *BidiConnection[StreamIn, Out, StreamOut]) Receive() iter.Seq2[StreamOut, error] { + return func(yield func(StreamOut, error) bool) { for { select { case chunk, ok := <-c.streamCh: @@ -570,7 +573,7 @@ func (c *BidiConnection[In, Out, Stream]) Receive() iter.Seq2[Stream, error] { return } case <-c.ctx.Done(): - var zero Stream + var zero StreamOut yield(zero, c.ctx.Err()) return } @@ -580,7 +583,7 @@ func (c *BidiConnection[In, Out, Stream]) Receive() iter.Seq2[Stream, error] { // Output returns the final output after the action completes. // Blocks until done or context cancelled. -func (c *BidiConnection[In, Out, Stream]) Output() (Out, error) { +func (c *BidiConnection[StreamIn, Out, StreamOut]) Output() (Out, error) { select { case <-c.doneCh: c.mu.Lock() @@ -593,6 +596,6 @@ func (c *BidiConnection[In, Out, Stream]) Output() (Out, error) { } // Done returns a channel that is closed when the connection completes. -func (c *BidiConnection[In, Out, Stream]) Done() <-chan struct{} { +func (c *BidiConnection[StreamIn, Out, StreamOut]) Done() <-chan struct{} { return c.doneCh } diff --git a/go/core/api/action.go b/go/core/api/action.go index a38958af51..336e6145ce 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -74,7 +74,7 @@ type ActionDesc struct { Description string `json:"description"` // Description of the action. InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. - StreamSchema map[string]any `json:"streamSchema,omitempty"` // JSON schema to validate against the action's streamed chunks. - InitSchema map[string]any `json:"initSchema,omitempty"` // JSON schema to validate against the action's initialization data. + StreamOutSchema map[string]any `json:"streamOutSchema,omitempty"` // JSON schema to validate against the action's outgoing streamed chunks. + StreamInSchema map[string]any `json:"streamInSchema,omitempty"` // JSON schema to validate against the action's incoming stream messages (bidi only). Metadata map[string]any `json:"metadata"` // Metadata for the action. } diff --git a/go/core/flow.go b/go/core/flow.go index 3ad8bba60d..e10fb34344 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -26,17 +26,17 @@ import ( "github.com/firebase/genkit/go/internal/base" ) -// A Flow is a user-defined Action. A Flow[In, Out, Stream, Init] represents a function from In to Out. -// The Stream parameter is for flows that support streaming: providing their results incrementally. The Init parameter is for bidi flows. -type Flow[In, Out, Stream, Init any] struct { - *Action[In, Out, Stream, Init] +// A Flow is a user-defined Action. A Flow[In, Out, StreamOut, StreamIn] represents a function from In to Out. +// The StreamOut parameter is for flows that support streaming: providing their results incrementally. The StreamIn parameter is for bidi flows. +type Flow[In, Out, StreamOut, StreamIn any] struct { + *Action[In, Out, StreamOut, StreamIn] } // StreamingFlowValue is either a streamed value or a final output of a flow. -type StreamingFlowValue[Out, Stream any] struct { +type StreamingFlowValue[Out, StreamOut any] struct { Done bool - Output Out // valid if Done is true - Stream Stream // valid if Done is false + Output Out // valid if Done is true + Stream StreamOut // valid if Done is false } // flowContextKey is a context key that indicates whether the current context is a flow context. @@ -59,14 +59,14 @@ func NewFlow[In, Out any](name string, fn Func[In, Out]) *Flow[In, Out, struct{} } // NewStreamingFlow creates a streaming Flow that runs fn without registering it. -func NewStreamingFlow[In, Out, Stream any](name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream, struct{}] { - return &Flow[In, Out, Stream, struct{}]{NewStreamingAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) { +func NewStreamingFlow[In, Out, StreamOut any](name string, fn StreamingFunc[In, Out, StreamOut]) *Flow[In, Out, StreamOut, struct{}] { + return &Flow[In, Out, StreamOut, struct{}]{NewStreamingAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, StreamOut) error) (Out, error) { fc := &flowContext{ flowName: name, } ctx = flowContextKey.NewContext(ctx, fc) if cb == nil { - cb = func(context.Context, Stream) error { return nil } + cb = func(context.Context, StreamOut) error { return nil } } return fn(ctx, input, cb) })} @@ -74,12 +74,12 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn StreamingFunc[In, Out // NewBidiFlow creates a bidirectional streaming Flow without registering it. // Flow context is injected so that [Run] works inside the bidi function. -func NewBidiFlow[In, Out, Stream, Init any](name string, fn BidiFunc[In, Out, Stream, Init]) *Flow[In, Out, Stream, Init] { - wrapped := func(ctx context.Context, init Init, inCh <-chan In, outCh chan<- Stream) (Out, error) { +func NewBidiFlow[In, Out, StreamOut, StreamIn any](name string, fn BidiFunc[In, Out, StreamOut, StreamIn]) *Flow[In, Out, StreamOut, StreamIn] { + wrapped := func(ctx context.Context, in In, inCh <-chan StreamIn, outCh chan<- StreamOut) (Out, error) { ctx = flowContextKey.NewContext(ctx, &flowContext{flowName: name}) - return fn(ctx, init, inCh, outCh) + return fn(ctx, in, inCh, outCh) } - return &Flow[In, Out, Stream, Init]{NewBidiAction(name, api.ActionTypeFlow, nil, wrapped)} + return &Flow[In, Out, StreamOut, StreamIn]{NewBidiAction(name, api.ActionTypeFlow, nil, wrapped)} } // DefineFlow creates a Flow that runs fn, and registers it as an action. fn takes an input of type In and returns an output of type Out. @@ -92,13 +92,13 @@ func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flo // DefineStreamingFlow creates a streaming Flow that runs fn, and registers it as an action. // // fn takes an input of type In and returns an output of type Out, optionally -// streaming values of type Stream incrementally by invoking a callback. +// streaming values of type StreamOut incrementally by invoking a callback. // // If the function supports streaming and the callback is non-nil, it should // stream the results by invoking the callback periodically, ultimately returning // with a final return value that includes all the streamed data. // Otherwise, it should ignore the callback and just return a result. -func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream, struct{}] { +func DefineStreamingFlow[In, Out, StreamOut any](r api.Registry, name string, fn StreamingFunc[In, Out, StreamOut]) *Flow[In, Out, StreamOut, struct{}] { f := NewStreamingFlow(name, fn) f.Register(r) return f @@ -106,7 +106,7 @@ func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn St // DefineBidiFlow creates a bidirectional streaming Flow that runs fn, and registers it as an action. // Flow context is injected so that [Run] works inside the bidi function. -func DefineBidiFlow[In, Out, Stream, Init any](r api.Registry, name string, fn BidiFunc[In, Out, Stream, Init]) *Flow[In, Out, Stream, Init] { +func DefineBidiFlow[In, Out, StreamOut, StreamIn any](r api.Registry, name string, fn BidiFunc[In, Out, StreamOut, StreamIn]) *Flow[In, Out, StreamOut, StreamIn] { f := NewBidiFlow(name, fn) f.Register(r) return f @@ -140,7 +140,7 @@ func Run[Out any](ctx context.Context, name string, fn func() (Out, error)) (Out } // Run runs the flow in the context of another flow. -func (f *Flow[In, Out, Stream, Init]) Run(ctx context.Context, input In) (Out, error) { +func (f *Flow[In, Out, StreamOut, StreamIn]) Run(ctx context.Context, input In) (Out, error) { return f.Action.Run(ctx, input, nil) } @@ -156,13 +156,13 @@ func (f *Flow[In, Out, Stream, Init]) Run(ctx context.Context, input In) (Out, e // again. // // Otherwise the Stream field of the passed [StreamingFlowValue] holds a streamed result. -func (f *Flow[In, Out, Stream, Init]) Stream(ctx context.Context, input In) func(func(*StreamingFlowValue[Out, Stream], error) bool) { - return func(yield func(*StreamingFlowValue[Out, Stream], error) bool) { - cb := func(ctx context.Context, s Stream) error { +func (f *Flow[In, Out, StreamOut, StreamIn]) Stream(ctx context.Context, input In) func(func(*StreamingFlowValue[Out, StreamOut], error) bool) { + return func(yield func(*StreamingFlowValue[Out, StreamOut], error) bool) { + cb := func(ctx context.Context, s StreamOut) error { if ctx.Err() != nil { return ctx.Err() } - if !yield(&StreamingFlowValue[Out, Stream]{Stream: s}, nil) { + if !yield(&StreamingFlowValue[Out, StreamOut]{Stream: s}, nil) { return errStop } return nil @@ -175,7 +175,7 @@ func (f *Flow[In, Out, Stream, Init]) Stream(ctx context.Context, input In) func if err != nil { yield(nil, err) } else { - yield(&StreamingFlowValue[Out, Stream]{Done: true, Output: output}, nil) + yield(&StreamingFlowValue[Out, StreamOut]{Done: true, Output: output}, nil) } } } diff --git a/go/core/flow_test.go b/go/core/flow_test.go index 7da8d31778..5810ac4625 100644 --- a/go/core/flow_test.go +++ b/go/core/flow_test.go @@ -264,6 +264,7 @@ func TestFlowNameFromContextOutsideFlow(t *testing.T) { func TestBidiActionEcho(t *testing.T) { ctx := context.Background() + // In=struct{} (no initial data), Out=string, OutStream=string, InStream=string action := NewBidiAction( "echo", api.ActionTypeCustom, nil, func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { @@ -315,18 +316,19 @@ func TestBidiActionEcho(t *testing.T) { } } -func TestBidiActionWithInit(t *testing.T) { +func TestBidiActionWithConfig(t *testing.T) { ctx := context.Background() type Config struct { Prefix string } + // In=Config (initial config), Out=string, OutStream=string, InStream=string action := NewBidiAction( "prefixed", api.ActionTypeCustom, nil, - func(ctx context.Context, init Config, inCh <-chan string, outCh chan<- string) (string, error) { + func(ctx context.Context, cfg Config, inCh <-chan string, outCh chan<- string) (string, error) { for input := range inCh { - outCh <- fmt.Sprintf("%s: %s", init.Prefix, input) + outCh <- fmt.Sprintf("%s: %s", cfg.Prefix, input) } return "done", nil }, diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 8adc19f69e..ed8be96776 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -355,14 +355,15 @@ func DefineStreamingFlow[In, Out, Stream any](g *Genkit, name string, fn core.St // DefineBidiFlow defines a bidirectional streaming flow, registers it as a [core.Action] of type Flow, // and returns a [core.Flow] capable of bidirectional streaming. // -// The provided function `fn` receives initialization data of type `Init`, reads -// inputs of type `In` from an input channel, and writes streamed outputs of type -// `Stream` to an output channel. It returns a final output of type `Out` when complete. +// The provided function `fn` receives an initial input of type `In`, reads +// incoming stream messages of type `StreamIn` from an input channel, and writes +// outgoing stream messages of type `StreamOut` to an output channel. It returns +// a final output of type `Out` when complete. // // Example: // // chatFlow := genkit.DefineBidiFlow(g, "chat", -// func(ctx context.Context, init struct{}, inCh <-chan string, outCh chan<- string) (string, error) { +// func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { // var count int // for input := range inCh { // count++ @@ -396,7 +397,7 @@ func DefineStreamingFlow[In, Out, Stream any](g *Genkit, name string, fn core.St // // Get the final output: // output, err := conn.Output() // fmt.Println(output) // Output: "processed 2 messages" -func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.BidiFunc[In, Out, Stream, Init]) *core.Flow[In, Out, Stream, Init] { +func DefineBidiFlow[In, Out, StreamOut, StreamIn any](g *Genkit, name string, fn core.BidiFunc[In, Out, StreamOut, StreamIn]) *core.Flow[In, Out, StreamOut, StreamIn] { return core.DefineBidiFlow(g.reg, name, fn) } From ec9d05d5e0a5b1f02425f93f8009c9507a880e26 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 13 Mar 2026 10:34:30 -0700 Subject: [PATCH 051/141] refactored order of type params --- go/ai/exp/session_flow.go | 2 +- go/core/action.go | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go/ai/exp/session_flow.go b/go/ai/exp/session_flow.go index e4f507e5a3..ae61403931 100644 --- a/go/ai/exp/session_flow.go +++ b/go/ai/exp/session_flow.go @@ -575,7 +575,7 @@ func newSessionFromInit[State any]( // It provides a Receive() iterator that supports multi-turn patterns: breaking out // of the iterator between turns does not cancel the underlying connection. type SessionFlowConnection[Stream, State any] struct { - conn *core.BidiConnection[*SessionFlowInput, *SessionFlowOutput[State], *SessionFlowStreamChunk[Stream]] + conn *core.BidiConnection[*SessionFlowInput, *SessionFlowStreamChunk[Stream], *SessionFlowOutput[State]] // chunks buffers stream chunks from the underlying connection so that // breaking from Receive() between turns doesn't cancel the context. diff --git a/go/core/action.go b/go/core/action.go index 4f8de572ca..f0cc1a6d63 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -397,13 +397,13 @@ func (a *Action[In, Out, StreamOut, StreamIn]) Register(r api.Registry) { // StreamBidi starts a bidirectional streaming connection. // Returns an error if the action is not a bidi action. // A trace span is created that remains open for the lifetime of the connection. -func (a *Action[In, Out, StreamOut, StreamIn]) StreamBidi(ctx context.Context, in In) (*BidiConnection[StreamIn, Out, StreamOut], error) { +func (a *Action[In, Out, StreamOut, StreamIn]) StreamBidi(ctx context.Context, in In) (*BidiConnection[StreamIn, StreamOut, Out], error) { if a.bidiFn == nil { return nil, NewError(FAILED_PRECONDITION, "StreamBidi called on non-bidi action %q", a.desc.Name) } ctx, cancel := context.WithCancel(ctx) - conn := &BidiConnection[StreamIn, Out, StreamOut]{ + conn := &BidiConnection[StreamIn, StreamOut, Out]{ inputCh: make(chan StreamIn, 1), streamCh: make(chan StreamOut, 1), doneCh: make(chan struct{}), @@ -515,7 +515,7 @@ func wrapBidiAsStreaming[In, Out, StreamOut, StreamIn any](fn BidiFunc[In, Out, } // BidiConnection represents an active bidirectional streaming session. -type BidiConnection[StreamIn, Out, StreamOut any] struct { +type BidiConnection[StreamIn, StreamOut, Out any] struct { inputCh chan StreamIn streamCh chan StreamOut doneCh chan struct{} @@ -529,7 +529,7 @@ type BidiConnection[StreamIn, Out, StreamOut any] struct { // Send sends an input message to the bidi action. // Returns an error if the connection is closed or the context is cancelled. -func (c *BidiConnection[StreamIn, Out, StreamOut]) Send(input StreamIn) (err error) { +func (c *BidiConnection[StreamIn, StreamOut, Out]) Send(input StreamIn) (err error) { defer func() { if r := recover(); r != nil { err = NewError(FAILED_PRECONDITION, "connection is closed") @@ -547,7 +547,7 @@ func (c *BidiConnection[StreamIn, Out, StreamOut]) Send(input StreamIn) (err err } // Close signals that no more inputs will be sent. -func (c *BidiConnection[StreamIn, Out, StreamOut]) Close() error { +func (c *BidiConnection[StreamIn, StreamOut, Out]) Close() error { c.mu.Lock() defer c.mu.Unlock() if c.closed { @@ -560,7 +560,7 @@ func (c *BidiConnection[StreamIn, Out, StreamOut]) Close() error { // Receive returns an iterator for receiving streamed response chunks. // The iterator completes when the action finishes. -func (c *BidiConnection[StreamIn, Out, StreamOut]) Receive() iter.Seq2[StreamOut, error] { +func (c *BidiConnection[StreamIn, StreamOut, Out]) Receive() iter.Seq2[StreamOut, error] { return func(yield func(StreamOut, error) bool) { for { select { @@ -583,7 +583,7 @@ func (c *BidiConnection[StreamIn, Out, StreamOut]) Receive() iter.Seq2[StreamOut // Output returns the final output after the action completes. // Blocks until done or context cancelled. -func (c *BidiConnection[StreamIn, Out, StreamOut]) Output() (Out, error) { +func (c *BidiConnection[StreamIn, StreamOut, Out]) Output() (Out, error) { select { case <-c.doneCh: c.mu.Lock() @@ -596,6 +596,6 @@ func (c *BidiConnection[StreamIn, Out, StreamOut]) Output() (Out, error) { } // Done returns a channel that is closed when the connection completes. -func (c *BidiConnection[StreamIn, Out, StreamOut]) Done() <-chan struct{} { +func (c *BidiConnection[StreamIn, StreamOut, Out]) Done() <-chan struct{} { return c.doneCh } From cf004557d7b7bf900fb2dc1a24e311e96f5714cb Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 13 Mar 2026 10:36:30 -0700 Subject: [PATCH 052/141] Update action.go --- go/core/api/action.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/go/core/api/action.go b/go/core/api/action.go index 336e6145ce..d0df3ee613 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -68,13 +68,13 @@ const ( // ActionDesc is a descriptor of an action. type ActionDesc struct { - Type ActionType `json:"type"` // Type of the action. - Key string `json:"key"` // Key of the action. - Name string `json:"name"` // Name of the action. - Description string `json:"description"` // Description of the action. - InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. - OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. + Type ActionType `json:"type"` // Type of the action. + Key string `json:"key"` // Key of the action. + Name string `json:"name"` // Name of the action. + Description string `json:"description"` // Description of the action. + InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. + OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. StreamOutSchema map[string]any `json:"streamOutSchema,omitempty"` // JSON schema to validate against the action's outgoing streamed chunks. StreamInSchema map[string]any `json:"streamInSchema,omitempty"` // JSON schema to validate against the action's incoming stream messages (bidi only). - Metadata map[string]any `json:"metadata"` // Metadata for the action. + Metadata map[string]any `json:"metadata"` // Metadata for the action. } From 00a2f8869d751d83d65540f27e55f363aeeef821 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 13 Mar 2026 10:36:30 -0700 Subject: [PATCH 053/141] Update action.go --- go/core/api/action.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/go/core/api/action.go b/go/core/api/action.go index 336e6145ce..d0df3ee613 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -68,13 +68,13 @@ const ( // ActionDesc is a descriptor of an action. type ActionDesc struct { - Type ActionType `json:"type"` // Type of the action. - Key string `json:"key"` // Key of the action. - Name string `json:"name"` // Name of the action. - Description string `json:"description"` // Description of the action. - InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. - OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. + Type ActionType `json:"type"` // Type of the action. + Key string `json:"key"` // Key of the action. + Name string `json:"name"` // Name of the action. + Description string `json:"description"` // Description of the action. + InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. + OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. StreamOutSchema map[string]any `json:"streamOutSchema,omitempty"` // JSON schema to validate against the action's outgoing streamed chunks. StreamInSchema map[string]any `json:"streamInSchema,omitempty"` // JSON schema to validate against the action's incoming stream messages (bidi only). - Metadata map[string]any `json:"metadata"` // Metadata for the action. + Metadata map[string]any `json:"metadata"` // Metadata for the action. } From 8867ea0132bd1a2e55121ad064b6b9919e98fd67 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 13 Mar 2026 10:37:01 -0700 Subject: [PATCH 054/141] Update action.go --- go/core/api/action.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/go/core/api/action.go b/go/core/api/action.go index d0cc605d18..0818c527e6 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -69,13 +69,13 @@ const ( // ActionDesc is a descriptor of an action. type ActionDesc struct { - Type ActionType `json:"type"` // Type of the action. - Key string `json:"key"` // Key of the action. - Name string `json:"name"` // Name of the action. - Description string `json:"description"` // Description of the action. - InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. - OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. + Type ActionType `json:"type"` // Type of the action. + Key string `json:"key"` // Key of the action. + Name string `json:"name"` // Name of the action. + Description string `json:"description"` // Description of the action. + InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. + OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. StreamOutSchema map[string]any `json:"streamOutSchema,omitempty"` // JSON schema to validate against the action's outgoing streamed chunks. StreamInSchema map[string]any `json:"streamInSchema,omitempty"` // JSON schema to validate against the action's incoming stream messages (bidi only). - Metadata map[string]any `json:"metadata"` // Metadata for the action. + Metadata map[string]any `json:"metadata"` // Metadata for the action. } From e95a360cbdb20f0253f3a720595c07ced905de3d Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 30 Mar 2026 14:23:02 -0700 Subject: [PATCH 055/141] Update session_flow_test.go --- go/ai/exp/session_flow_test.go | 133 ++++++++++++++++++++++----------- 1 file changed, 90 insertions(+), 43 deletions(-) diff --git a/go/ai/exp/session_flow_test.go b/go/ai/exp/session_flow_test.go index 1d8d0ddcb5..ea5ba01e81 100644 --- a/go/ai/exp/session_flow_test.go +++ b/go/ai/exp/session_flow_test.go @@ -1224,7 +1224,7 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { reg := registry.New() ai.ConfigureFormats(reg) - // Define a tool that the model will call. + // Define two tools so the model can call them across multiple rounds. ai.DefineTool(reg, "greet", "returns a greeting", func(ctx *ai.ToolContext, input struct { Name string `json:"name"` @@ -1232,38 +1232,66 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { return "hello " + input.Name, nil }, ) + ai.DefineTool(reg, "farewell", "returns a farewell", + func(ctx *ai.ToolContext, input struct { + Name string `json:"name"` + }) (string, error) { + return "goodbye " + input.Name, nil + }, + ) - // Model that requests a tool call on the first call, then returns - // a final text response once it sees the tool result. + // Model that drives a multi-round tool loop: + // Round 1: request "greet" tool + // Round 2: after seeing greet response, request "farewell" tool + // Round 3: after seeing farewell response, return final text ai.DefineModel(reg, "test/toolmodel", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true, Tools: true}}, func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { - // Check if we already got a tool response. + // Count tool responses to determine which round we're in. + toolResps := 0 for _, msg := range req.Messages { for _, p := range msg.Content { if p.IsToolResponse() { - resp := &ai.ModelResponse{ - Request: req, - Message: ai.NewModelTextMessage("done: " + fmt.Sprintf("%v", p.ToolResponse.Output)), - } - if cb != nil { - cb(ctx, &ai.ModelResponseChunk{Content: resp.Message.Content}) - } - return resp, nil + toolResps++ } } } - // First call: request the tool. - resp := &ai.ModelResponse{ - Request: req, - Message: &ai.Message{ - Role: ai.RoleModel, - Content: []*ai.Part{ai.NewToolRequestPart(&ai.ToolRequest{ - Name: "greet", - Input: map[string]any{"name": "world"}, - })}, - }, + + switch toolResps { + case 0: + // Round 1: request greet. + return &ai.ModelResponse{ + Request: req, + Message: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewToolRequestPart(&ai.ToolRequest{ + Name: "greet", + Input: map[string]any{"name": "world"}, + })}, + }, + }, nil + case 1: + // Round 2: saw greet response, now request farewell. + return &ai.ModelResponse{ + Request: req, + Message: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewToolRequestPart(&ai.ToolRequest{ + Name: "farewell", + Input: map[string]any{"name": "world"}, + })}, + }, + }, nil + default: + // Round 3: saw both tool responses, return final text. + resp := &ai.ModelResponse{ + Request: req, + Message: ai.NewModelTextMessage("done"), + } + if cb != nil { + cb(ctx, &ai.ModelResponseChunk{Content: resp.Message.Content}) + } + return resp, nil } - return resp, nil }, ) ai.DefineGenerateAction(ctx, reg) @@ -1271,7 +1299,7 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { ai.DefinePrompt(reg, "toolPrompt", ai.WithModelName("test/toolmodel"), ai.WithSystem("You are a test assistant."), - ai.WithTools(ai.ToolName("greet")), + ai.WithTools(ai.ToolName("greet"), ai.ToolName("farewell")), ) af := DefineSessionFlowFromPrompt[testState, any](reg, "toolPrompt", nil) @@ -1297,14 +1325,16 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { t.Fatalf("Output failed: %v", err) } - // Session should contain: + // Session should contain all messages from the multi-round tool loop: // 1. user message ("go") - // 2. model tool-call message - // 3. tool response message - // 4. final model text response + // 2. model tool-call message (greet request) + // 3. tool response message (greet result) + // 4. model tool-call message (farewell request) + // 5. tool response message (farewell result) + // 6. final model text response msgs := response.State.Messages - if got := len(msgs); got != 4 { - t.Errorf("expected 4 messages, got %d", got) + if got := len(msgs); got != 6 { + t.Errorf("expected 6 messages, got %d", got) for i, m := range msgs { t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) } @@ -1314,21 +1344,38 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { if msgs[0].Role != ai.RoleUser { t.Errorf("msg[0] role = %s, want user", msgs[0].Role) } - hasToolReq := false - for _, p := range msgs[1].Content { - if p.IsToolRequest() { - hasToolReq = true - break + + // Verify the two tool request/response pairs. + for _, pair := range []struct { + reqIdx int + respIdx int + tool string + }{ + {1, 2, "greet"}, + {3, 4, "farewell"}, + } { + reqMsg := msgs[pair.reqIdx] + if reqMsg.Role != ai.RoleModel { + t.Errorf("msg[%d] role = %s, want model", pair.reqIdx, reqMsg.Role) + } + hasReq := false + for _, p := range reqMsg.Content { + if p.IsToolRequest() && p.ToolRequest.Name == pair.tool { + hasReq = true + } + } + if !hasReq { + t.Errorf("msg[%d] should contain a %s tool request", pair.reqIdx, pair.tool) + } + + respMsg := msgs[pair.respIdx] + if respMsg.Role != ai.RoleTool { + t.Errorf("msg[%d] role = %s, want tool", pair.respIdx, respMsg.Role) } } - if msgs[1].Role != ai.RoleModel || !hasToolReq { - t.Errorf("msg[1] should be a model tool-call message") - } - if msgs[2].Role != ai.RoleTool { - t.Errorf("msg[2] role = %s, want tool", msgs[2].Role) - } - if msgs[3].Role != ai.RoleModel || !strings.Contains(msgs[3].Text(), "done:") { - t.Errorf("msg[3] should be final model response, got role=%s text=%s", msgs[3].Role, msgs[3].Text()) + + if msgs[5].Role != ai.RoleModel || msgs[5].Text() != "done" { + t.Errorf("msg[5] should be final model response, got role=%s text=%q", msgs[5].Role, msgs[5].Text()) } } From bf7b4ef4d73567929ae94f2a95443be5c6a0a3b0 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 17 Apr 2026 07:46:47 -0700 Subject: [PATCH 056/141] refactored EndTurn bool into TurnEnd struct --- genkit-tools/common/src/types/agent.ts | 27 ++++++- genkit-tools/genkit-schema.json | 12 ++- go/ai/exp/gen.go | 42 ++++++---- go/ai/exp/session_flow.go | 17 ++-- go/ai/exp/session_flow_test.go | 83 ++++++++++---------- go/core/schemas.config | 27 +++++-- go/genkit/genkit.go | 4 +- go/samples/custom-agent/main.go | 8 +- go/samples/prompt-agent/main.go | 8 +- py/packages/genkit/src/genkit/core/typing.py | 10 ++- 10 files changed, 146 insertions(+), 92 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 8046fa468c..8835fc140f 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -100,6 +100,23 @@ export const SessionFlowOutputSchema = z.object({ }); export type SessionFlowOutput = z.infer; +/** + * Zod schema for the turn-end signal emitted by a session flow. + * + * A TurnEnd value is emitted exactly once per turn, regardless of whether a + * snapshot was persisted. Grouping all turn-end signals here lets callers + * detect turn boundaries with a single field check and leaves room for + * additional turn-end metadata in the future. + */ +export const TurnEndSchema = z.object({ + /** + * ID of the snapshot persisted at the end of this turn. Empty if no + * snapshot was created (callback returned false or no store configured). + */ + snapshotId: z.string().optional(), +}); +export type TurnEnd = z.infer; + /** * Zod schema for session flow stream chunk. */ @@ -110,10 +127,12 @@ export const SessionFlowStreamChunkSchema = z.object({ status: z.any().optional(), /** A newly produced artifact. */ artifact: ArtifactSchema.optional(), - /** ID of a snapshot that was just persisted. */ - snapshotId: z.string().optional(), - /** Signals that the session flow has finished processing the current input. */ - endTurn: z.boolean().optional(), + /** + * Non-null when the session flow has finished processing the current + * input. Groups all turn-end signals; the client should stop iterating and + * may send the next input. + */ + turnEnd: TurnEndSchema.optional(), }); export type SessionFlowStreamChunk = z.infer< typeof SessionFlowStreamChunkSchema diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index c5d0f7377c..56fd2ad3e8 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -99,11 +99,17 @@ "artifact": { "$ref": "#/$defs/Artifact" }, + "turnEnd": { + "$ref": "#/$defs/TurnEnd" + } + }, + "additionalProperties": false + }, + "TurnEnd": { + "type": "object", + "properties": { "snapshotId": { "type": "string" - }, - "endTurn": { - "type": "boolean" } }, "additionalProperties": false diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index b27ace232f..f2344605cc 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -22,6 +22,17 @@ import ( "github.com/firebase/genkit/go/ai" ) +// Artifact represents a named collection of parts produced during a session. +// Examples: generated files, images, code snippets, diagrams, etc. +type Artifact struct { + // Metadata contains additional artifact-specific data. + Metadata map[string]any `json:"metadata,omitempty"` + // Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). + Name string `json:"name,omitempty"` + // Parts contains the artifact content (text, media, etc.). + Parts []*ai.Part `json:"parts"` +} + // SessionFlowInit is the input for starting an session flow invocation. // Provide either SnapshotID (to load from store) or State (direct state). type SessionFlowInit[State any] struct { @@ -73,27 +84,16 @@ type SessionFlowResult struct { type SessionFlowStreamChunk[Stream any] struct { // Artifact contains a newly produced artifact. Artifact *Artifact `json:"artifact,omitempty"` - // EndTurn signals that the session flow has finished processing the current input. - // When true, the client should stop iterating and may send the next input. - EndTurn bool `json:"endTurn,omitempty"` // ModelChunk contains generation tokens from the model. ModelChunk *ai.ModelResponseChunk `json:"modelChunk,omitempty"` - // SnapshotID contains the ID of a snapshot that was just persisted. - SnapshotID string `json:"snapshotId,omitempty"` // Status contains user-defined structured status information. // The Stream type parameter defines the shape of this data. Status Stream `json:"status,omitempty"` -} - -// Artifact represents a named collection of parts produced during a session. -// Examples: generated files, images, code snippets, diagrams, etc. -type Artifact struct { - // Metadata contains additional artifact-specific data. - Metadata map[string]any `json:"metadata,omitempty"` - // Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). - Name string `json:"name,omitempty"` - // Parts contains the artifact content (text, media, etc.). - Parts []*ai.Part `json:"parts"` + // TurnEnd is non-nil when the session flow has finished processing the current + // input. It groups all turn-end signals (snapshot ID, etc.) so callers can + // check a single field. When set, the client should stop iterating and may + // send the next input. + TurnEnd *TurnEnd `json:"turnEnd,omitempty"` } // SessionState is the portable conversation state that flows between client @@ -120,3 +120,13 @@ const ( // InvocationEnd indicates the snapshot was triggered at the end of the invocation. SnapshotEventInvocationEnd SnapshotEvent = "invocationEnd" ) + +// TurnEnd groups the signals emitted when a session flow turn finishes. +// A TurnEnd value is emitted exactly once per turn, regardless of whether a +// snapshot was persisted. +type TurnEnd struct { + // SnapshotID is the ID of the snapshot persisted at the end of this turn. + // Empty if no snapshot was created (callback returned false or no store + // configured). + SnapshotID string `json:"snapshotId,omitempty"` +} diff --git a/go/ai/exp/session_flow.go b/go/ai/exp/session_flow.go index ae61403931..7f571afbf1 100644 --- a/go/ai/exp/session_flow.go +++ b/go/ai/exp/session_flow.go @@ -63,8 +63,8 @@ type SessionRunner[State any] struct { // Run loops over the input channel, calling fn for each turn. Each turn is // wrapped in a trace span for observability. Input messages are automatically -// added to the session before fn is called. After fn returns successfully, an -// EndTurn chunk is sent and a snapshot check is triggered. +// added to the session before fn is called. After fn returns successfully, a +// TurnEnd chunk is sent and a snapshot check is triggered. func (a *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Context, input *SessionFlowInput) error) error { for input := range a.InputCh { spanMeta := &tracing.SpanMetadata{ @@ -281,8 +281,8 @@ func DefineSessionFlow[Stream, State any]( if chunk.Artifact != nil { session.AddArtifacts(chunk.Artifact) } - // Accumulate content chunks (exclude control signals from onEndTurn). - if !chunk.EndTurn && chunk.SnapshotID == "" { + // Accumulate content chunks (exclude the TurnEnd control signal). + if chunk.TurnEnd == nil { turnMu.Lock() turnChunks = append(turnChunks, chunk) turnMu.Unlock() @@ -291,14 +291,13 @@ func DefineSessionFlow[Stream, State any]( } }() - // Wire up onEndTurn: triggers snapshot + sends EndTurn chunk. + // Wire up onEndTurn: triggers snapshot + sends TurnEnd chunk. // Writes through respCh to preserve ordering with user chunks. agentSess.onEndTurn = func(turnCtx context.Context) { snapshotID := agentSess.maybeSnapshot(turnCtx, SnapshotEventTurnEnd) - if snapshotID != "" { - respCh <- &SessionFlowStreamChunk[Stream]{SnapshotID: snapshotID} + respCh <- &SessionFlowStreamChunk[Stream]{ + TurnEnd: &TurnEnd{SnapshotID: snapshotID}, } - respCh <- &SessionFlowStreamChunk[Stream]{EndTurn: true} } result, fnErr := fn(ctx, Responder[Stream](respCh), agentSess) @@ -634,7 +633,7 @@ func (c *SessionFlowConnection[Stream, State]) Close() error { // Receive returns an iterator for receiving stream chunks. // Unlike the underlying BidiConnection.Receive, breaking out of this iterator // does not cancel the connection. This enables multi-turn patterns where the -// caller breaks on EndTurn, sends the next input, then calls Receive again. +// caller breaks on TurnEnd, sends the next input, then calls Receive again. func (c *SessionFlowConnection[Stream, State]) Receive() iter.Seq2[*SessionFlowStreamChunk[Stream], error] { c.initReceiver() return func(yield func(*SessionFlowStreamChunk[Stream], error) bool) { diff --git a/go/ai/exp/session_flow_test.go b/go/ai/exp/session_flow_test.go index ea5ba01e81..fb80555fb8 100644 --- a/go/ai/exp/session_flow_test.go +++ b/go/ai/exp/session_flow_test.go @@ -78,11 +78,11 @@ func TestSessionFlow_BasicMultiTurn(t *testing.T) { t.Fatalf("Receive error: %v", err) } turn1Chunks++ - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } - if turn1Chunks < 2 { // at least status + endTurn + if turn1Chunks < 2 { // at least status + TurnEnd t.Errorf("expected at least 2 chunks in turn 1, got %d", turn1Chunks) } @@ -94,7 +94,7 @@ func TestSessionFlow_BasicMultiTurn(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -148,10 +148,10 @@ func TestSessionFlow_WithSessionStore(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.SnapshotID != "" { - snapshotIDs = append(snapshotIDs, chunk.SnapshotID) - } - if chunk.EndTurn { + if chunk.TurnEnd != nil { + if chunk.TurnEnd.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, chunk.TurnEnd.SnapshotID) + } break } } @@ -216,7 +216,7 @@ func TestSessionFlow_ResumeFromSnapshot(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -239,7 +239,7 @@ func TestSessionFlow_ResumeFromSnapshot(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -313,7 +313,7 @@ func TestSessionFlow_ClientManagedState(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -387,7 +387,7 @@ func TestSessionFlow_Artifacts(t *testing.T) { if chunk.Artifact != nil { receivedArtifacts = append(receivedArtifacts, chunk.Artifact) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -445,10 +445,10 @@ func TestSessionFlow_SnapshotCallback(t *testing.T) { if err != nil { t.Fatalf("Receive error on turn %d: %v", i, err) } - if chunk.SnapshotID != "" { - snapshotIDs = append(snapshotIDs, chunk.SnapshotID) - } - if chunk.EndTurn { + if chunk.TurnEnd != nil { + if chunk.TurnEnd.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, chunk.TurnEnd.SnapshotID) + } break } } @@ -496,7 +496,7 @@ func TestSessionFlow_SendMessages(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -547,7 +547,7 @@ func TestSessionFlow_SessionContext(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -609,7 +609,7 @@ func TestSessionFlow_SetMessages(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -656,7 +656,7 @@ func TestSessionFlow_SnapshotIDInMessageMetadata(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -767,7 +767,7 @@ func TestSessionFlow_TurnSpanOutput(t *testing.T) { if err != nil { t.Fatalf("Receive error on turn %d: %v", turn, err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -793,11 +793,8 @@ func TestSessionFlow_TurnSpanOutput(t *testing.T) { t.Errorf("turn %d: expected 3 chunks, got %d", i, len(chunks)) } for j, chunk := range chunks { - if chunk.EndTurn { - t.Errorf("turn %d, chunk %d: EndTurn should not be in turn output", i, j) - } - if chunk.SnapshotID != "" { - t.Errorf("turn %d, chunk %d: SnapshotID should not be in turn output", i, j) + if chunk.TurnEnd != nil { + t.Errorf("turn %d, chunk %d: TurnEnd should not be in turn output", i, j) } } } @@ -839,10 +836,10 @@ func TestSessionFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.SnapshotID != "" { - sawSnapshot = true - } - if chunk.EndTurn { + if chunk.TurnEnd != nil { + if chunk.TurnEnd.SnapshotID != "" { + sawSnapshot = true + } break } } @@ -850,10 +847,10 @@ func TestSessionFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { conn.Output() if !sawSnapshot { - t.Fatal("expected a snapshot chunk on the stream") + t.Fatal("expected a snapshot ID on the turn-end chunk") } - // Turn output should contain only the status chunk, not the snapshot/endTurn. + // Turn output should contain only the status chunk, not the TurnEnd signal. if len(capturedOutputs) != 1 { t.Fatalf("expected 1 captured output, got %d", len(capturedOutputs)) } @@ -938,7 +935,7 @@ func TestPromptAgent_Basic(t *testing.T) { if chunk.ModelChunk != nil { gotChunk = true } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -954,7 +951,7 @@ func TestPromptAgent_Basic(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -1007,7 +1004,7 @@ func TestPromptAgent_PromptInputOverride(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -1085,7 +1082,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { if chunk.ModelChunk != nil { turn1Response += chunk.ModelChunk.Text() } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -1106,7 +1103,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { if chunk.ModelChunk != nil { turn2Response += chunk.ModelChunk.Text() } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -1160,7 +1157,7 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -1195,7 +1192,7 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -1314,7 +1311,7 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -1633,10 +1630,10 @@ func TestSessionFlow_MultiTurnSnapshotDedup(t *testing.T) { if err != nil { t.Fatalf("Receive error on turn %d: %v", i, err) } - if chunk.SnapshotID != "" { - snapshotIDs = append(snapshotIDs, chunk.SnapshotID) - } - if chunk.EndTurn { + if chunk.TurnEnd != nil { + if chunk.TurnEnd.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, chunk.TurnEnd.SnapshotID) + } break } } diff --git a/go/core/schemas.config b/go/core/schemas.config index e6be6e0f99..fa9976e1a0 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1268,13 +1268,30 @@ SessionFlowStreamChunk.artifact doc Artifact contains a newly produced artifact. . -SessionFlowStreamChunk.snapshotId doc -SnapshotID contains the ID of a snapshot that was just persisted. +SessionFlowStreamChunk.turnEnd type *TurnEnd +SessionFlowStreamChunk.turnEnd doc +TurnEnd is non-nil when the session flow has finished processing the current +input. It groups all turn-end signals (snapshot ID, etc.) so callers can +check a single field. When set, the client should stop iterating and may +send the next input. . -SessionFlowStreamChunk.endTurn doc -EndTurn signals that the session flow has finished processing the current input. -When true, the client should stop iterating and may send the next input. +# ---------------------------------------------------------------------------- +# TurnEnd +# ---------------------------------------------------------------------------- + +TurnEnd pkg ai/exp + +TurnEnd doc +TurnEnd groups the signals emitted when a session flow turn finishes. +A TurnEnd value is emitted exactly once per turn, regardless of whether a +snapshot was persisted. +. + +TurnEnd.snapshotId doc +SnapshotID is the ID of the snapshot persisted at the end of this turn. +Empty if no snapshot was created (callback returned false or no store +configured). . # ---------------------------------------------------------------------------- diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 01fbcdc3fe..f588a55f42 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -472,7 +472,7 @@ func DefineBidiFlow[In, Out, StreamOut, StreamIn any](g *Genkit, name string, fn // // Send a message and stream the response: // conn.SendText("Hello!") // for chunk, err := range conn.Receive() { -// if chunk.EndTurn { +// if chunk.TurnEnd != nil { // break // } // fmt.Print(chunk.ModelChunk.Text()) @@ -540,7 +540,7 @@ func DefineSessionFlow[Stream, State any]( // // Send a message and stream the response: // conn.SendText("Hello!") // for chunk, err := range conn.Receive() { -// if chunk.EndTurn { +// if chunk.TurnEnd != nil { // break // } // fmt.Print(chunk.ModelChunk.Text()) diff --git a/go/samples/custom-agent/main.go b/go/samples/custom-agent/main.go index b8e2274e0c..a05c908557 100644 --- a/go/samples/custom-agent/main.go +++ b/go/samples/custom-agent/main.go @@ -104,10 +104,10 @@ func main() { if chunk.ModelChunk != nil { fmt.Print(chunk.ModelChunk.Text()) } - if chunk.SnapshotID != "" { - fmt.Printf("\n[snapshot: %s]", chunk.SnapshotID) - } - if chunk.EndTurn { + if chunk.TurnEnd != nil { + if chunk.TurnEnd.SnapshotID != "" { + fmt.Printf("\n[snapshot: %s]", chunk.TurnEnd.SnapshotID) + } fmt.Println() fmt.Println() break diff --git a/go/samples/prompt-agent/main.go b/go/samples/prompt-agent/main.go index e46ff2067e..b501872520 100644 --- a/go/samples/prompt-agent/main.go +++ b/go/samples/prompt-agent/main.go @@ -84,10 +84,10 @@ func main() { if chunk.ModelChunk != nil { fmt.Print(chunk.ModelChunk.Text()) } - if chunk.SnapshotID != "" { - fmt.Printf("\n[snapshot: %s]", chunk.SnapshotID) - } - if chunk.EndTurn { + if chunk.TurnEnd != nil { + if chunk.TurnEnd.SnapshotID != "" { + fmt.Printf("\n[snapshot: %s]", chunk.TurnEnd.SnapshotID) + } fmt.Println() fmt.Println() break diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index b207718212..64197201a1 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -1006,6 +1006,13 @@ class SessionFlowResult(BaseModel): artifacts: list[Artifact] | None = None +class TurnEnd(BaseModel): + """Model for turnend data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + snapshot_id: str | None = Field(default=None) + + class SessionFlowStreamChunk(BaseModel): """Model for sessionflowstreamchunk data.""" @@ -1013,8 +1020,7 @@ class SessionFlowStreamChunk(BaseModel): model_chunk: ModelResponseChunk | None = Field(default=None) status: Any | None = None artifact: Artifact | None = None - snapshot_id: str | None = Field(default=None) - end_turn: bool | None = Field(default=None) + turn_end: TurnEnd | None = Field(default=None) class SessionState(BaseModel): From 9ef39c044dc818ee9b03fcdf907496cc15daf54f Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 17 Apr 2026 07:52:51 -0700 Subject: [PATCH 057/141] fix python --- py/packages/genkit/src/genkit/core/typing.py | 14 +- py/uv.lock | 557 +------------------ 2 files changed, 30 insertions(+), 541 deletions(-) diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index 64197201a1..403372e500 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -39,6 +39,13 @@ class Model(RootModel[Any]): root: Any +class TurnEnd(BaseModel): + """Model for turnend data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + snapshot_id: str | None = Field(default=None) + + class SnapshotEvent(StrEnum): """SnapshotEvent data type class.""" @@ -1006,13 +1013,6 @@ class SessionFlowResult(BaseModel): artifacts: list[Artifact] | None = None -class TurnEnd(BaseModel): - """Model for turnend data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - snapshot_id: str | None = Field(default=None) - - class SessionFlowStreamChunk(BaseModel): """Model for sessionflowstreamchunk data.""" diff --git a/py/uv.lock b/py/uv.lock index fa2c1d41a4..35ecf6bd45 100644 --- a/py/uv.lock +++ b/py/uv.lock @@ -554,20 +554,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b9/fa/123043af240e49752f1c4bd24da5053b6bd00cad78c2be53c0d1e8b975bc/backports.tarfile-1.2.0-py3-none-any.whl", hash = "sha256:77e284d754527b01fb1e6fa8a1afe577858ebe4e9dad8919e34c862cb399bc34", size = 30181, upload-time = "2024-05-28T17:01:53.112Z" }, ] -[[package]] -name = "backrefs" -version = "6.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/86/e3/bb3a439d5cb255c4774724810ad8073830fac9c9dee123555820c1bcc806/backrefs-6.1.tar.gz", hash = "sha256:3bba1749aafe1db9b915f00e0dd166cba613b6f788ffd63060ac3485dc9be231", size = 7011962, upload-time = "2025-11-15T14:52:08.323Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/ee/c216d52f58ea75b5e1841022bbae24438b19834a29b163cb32aa3a2a7c6e/backrefs-6.1-py310-none-any.whl", hash = "sha256:2a2ccb96302337ce61ee4717ceacfbf26ba4efb1d55af86564b8bbaeda39cac1", size = 381059, upload-time = "2025-11-15T14:51:59.758Z" }, - { url = "https://files.pythonhosted.org/packages/e6/9a/8da246d988ded941da96c7ed945d63e94a445637eaad985a0ed88787cb89/backrefs-6.1-py311-none-any.whl", hash = "sha256:e82bba3875ee4430f4de4b6db19429a27275d95a5f3773c57e9e18abc23fd2b7", size = 392854, upload-time = "2025-11-15T14:52:01.194Z" }, - { url = "https://files.pythonhosted.org/packages/37/c9/fd117a6f9300c62bbc33bc337fd2b3c6bfe28b6e9701de336b52d7a797ad/backrefs-6.1-py312-none-any.whl", hash = "sha256:c64698c8d2269343d88947c0735cb4b78745bd3ba590e10313fbf3f78c34da5a", size = 398770, upload-time = "2025-11-15T14:52:02.584Z" }, - { url = "https://files.pythonhosted.org/packages/eb/95/7118e935b0b0bd3f94dfec2d852fd4e4f4f9757bdb49850519acd245cd3a/backrefs-6.1-py313-none-any.whl", hash = "sha256:4c9d3dc1e2e558965202c012304f33d4e0e477e1c103663fd2c3cc9bb18b0d05", size = 400726, upload-time = "2025-11-15T14:52:04.093Z" }, - { url = "https://files.pythonhosted.org/packages/1d/72/6296bad135bfafd3254ae3648cd152980a424bd6fed64a101af00cc7ba31/backrefs-6.1-py314-none-any.whl", hash = "sha256:13eafbc9ccd5222e9c1f0bec563e6d2a6d21514962f11e7fc79872fd56cbc853", size = 412584, upload-time = "2025-11-15T14:52:05.233Z" }, - { url = "https://files.pythonhosted.org/packages/02/e3/a4fa1946722c4c7b063cc25043a12d9ce9b4323777f89643be74cef2993c/backrefs-6.1-py39-none-any.whl", hash = "sha256:a9e99b8a4867852cad177a6430e31b0f6e495d65f8c6c134b68c14c3c95bf4b0", size = 381058, upload-time = "2025-11-15T14:52:06.698Z" }, -] - [[package]] name = "bandit" version = "1.9.3" @@ -735,15 +721,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ea/92/26d8d98de4c1676305e03ec2be67850afaf883b507bf71b917d852585ec8/bpython-0.26-py3-none-any.whl", hash = "sha256:91bdbbe667078677dc6b236493fc03e47a04cd099630a32ca3f72d6d49b71e20", size = 175988, upload-time = "2025-10-28T07:19:40.114Z" }, ] -[[package]] -name = "bracex" -version = "2.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/63/9a/fec38644694abfaaeca2798b58e276a8e61de49e2e37494ace423395febc/bracex-2.6.tar.gz", hash = "sha256:98f1347cd77e22ee8d967a30ad4e310b233f7754dbf31ff3fceb76145ba47dc7", size = 26642, upload-time = "2025-06-22T19:12:31.254Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9d/2a/9186535ce58db529927f6cf5990a849aa9e052eea3e2cfefe20b9e1802da/bracex-2.6-py3-none-any.whl", hash = "sha256:0b0049264e7340b3ec782b5cb99beb325f36c3782a32e36e876452fd49a09952", size = 11508, upload-time = "2025-06-22T19:12:29.781Z" }, -] - [[package]] name = "cachecontrol" version = "0.14.4" @@ -1607,15 +1584,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/09/d09dfaa2110884284be6006b7586ea519f7391de58ed5428f2bf457bcd03/dotpromptz_handlebars-0.1.8-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f23498821610d443a67c860922aba00d20bdd80b8421bfef0ceff07b713f8198", size = 666257, upload-time = "2026-01-30T06:44:46.929Z" }, ] -[[package]] -name = "editorconfig" -version = "0.17.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/88/3a/a61d9a1f319a186b05d14df17daea42fcddea63c213bcd61a929fb3a6796/editorconfig-0.17.1.tar.gz", hash = "sha256:23c08b00e8e08cc3adcddb825251c497478df1dada6aefeb01e626ad37303745", size = 14695, upload-time = "2025-06-09T08:21:37.097Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/96/fd/a40c621ff207f3ce8e484aa0fc8ba4eb6e3ecf52e15b42ba764b457a9550/editorconfig-0.17.1-py3-none-any.whl", hash = "sha256:1eda9c2c0db8c16dbd50111b710572a5e6de934e39772de1959d41f64fc17c82", size = 16360, upload-time = "2025-06-09T08:21:35.654Z" }, -] - [[package]] name = "email-validator" version = "2.3.0" @@ -2222,7 +2190,7 @@ wheels = [ [[package]] name = "genkit" -version = "0.5.0" +version = "0.5.1" source = { editable = "packages/genkit" } dependencies = [ { name = "anyio" }, @@ -2306,7 +2274,7 @@ provides-extras = ["dev-local-vectorstore", "flask", "google-cloud", "google-gen [[package]] name = "genkit-plugin-amazon-bedrock" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/amazon-bedrock" } dependencies = [ { name = "aioboto3" }, @@ -2333,7 +2301,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-anthropic" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/anthropic" } dependencies = [ { name = "anthropic" }, @@ -2348,7 +2316,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-checks" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/checks" } dependencies = [ { name = "genkit" }, @@ -2369,7 +2337,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-cloudflare-workers-ai" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/cloudflare-workers-ai" } dependencies = [ { name = "genkit" }, @@ -2394,7 +2362,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-cohere" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/cohere" } dependencies = [ { name = "cohere" }, @@ -2409,7 +2377,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-compat-oai" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/compat-oai" } dependencies = [ { name = "genkit" }, @@ -2426,7 +2394,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-deepseek" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/deepseek" } dependencies = [ { name = "genkit" }, @@ -2443,7 +2411,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-dev-local-vectorstore" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/dev-local-vectorstore" } dependencies = [ { name = "aiofiles" }, @@ -2464,7 +2432,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-evaluators" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/evaluators" } dependencies = [ { name = "aiofiles" }, @@ -2485,7 +2453,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-fastapi" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/fastapi" } dependencies = [ { name = "fastapi" }, @@ -2502,7 +2470,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-firebase" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/firebase" } dependencies = [ { name = "genkit" }, @@ -2528,7 +2496,7 @@ provides-extras = ["telemetry"] [[package]] name = "genkit-plugin-flask" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/flask" } dependencies = [ { name = "flask" }, @@ -2547,7 +2515,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-google-cloud" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/google-cloud" } dependencies = [ { name = "genkit" }, @@ -2568,7 +2536,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-google-genai" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/google-genai" } dependencies = [ { name = "genkit" }, @@ -2589,7 +2557,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-huggingface" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/huggingface" } dependencies = [ { name = "genkit" }, @@ -2604,7 +2572,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-mcp" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/mcp" } dependencies = [ { name = "genkit" }, @@ -2619,7 +2587,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-microsoft-foundry" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/microsoft-foundry" } dependencies = [ { name = "azure-identity" }, @@ -2640,7 +2608,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-mistral" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/mistral" } dependencies = [ { name = "genkit" }, @@ -2655,7 +2623,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-observability" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/observability" } dependencies = [ { name = "genkit" }, @@ -2685,7 +2653,7 @@ provides-extras = ["sentry"] [[package]] name = "genkit-plugin-ollama" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/ollama" } dependencies = [ { name = "genkit" }, @@ -2702,7 +2670,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-vertex-ai" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/vertex-ai" } dependencies = [ { name = "anthropic" }, @@ -2731,7 +2699,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-xai" -version = "0.5.0" +version = "0.5.1" source = { editable = "plugins/xai" } dependencies = [ { name = "genkit" }, @@ -2906,7 +2874,6 @@ lint = [ { name = "opentelemetry-instrumentation-fastapi" }, { name = "opentelemetry-instrumentation-grpc" }, { name = "pip-audit" }, - { name = "pydantic-settings" }, { name = "pypdf" }, { name = "pyrefly" }, { name = "pyright" }, @@ -2991,7 +2958,6 @@ lint = [ { name = "opentelemetry-instrumentation-fastapi", specifier = ">=0.41b0" }, { name = "opentelemetry-instrumentation-grpc", specifier = ">=0.41b0" }, { name = "pip-audit", specifier = ">=2.7.0" }, - { name = "pydantic-settings", specifier = ">=2.0.0" }, { name = "pypdf", specifier = ">=6.6.2" }, { name = "pyrefly", specifier = ">=0.15.0" }, { name = "pyright", specifier = ">=1.1.392" }, @@ -3016,18 +2982,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/5c/e226de133afd8bb267ec27eead9ae3d784b95b39a287ed404caab39a5f50/genson-1.3.0-py3-none-any.whl", hash = "sha256:468feccd00274cc7e4c09e84b08704270ba8d95232aa280f65b986139cec67f7", size = 21470, upload-time = "2024-05-15T22:08:47.056Z" }, ] -[[package]] -name = "ghp-import" -version = "2.1.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "python-dateutil" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d9/29/d40217cbe2f6b1359e00c6c307bb3fc876ba74068cbab3dde77f03ca0dc4/ghp-import-2.1.0.tar.gz", hash = "sha256:9c535c4c61193c2df8871222567d7fd7e5014d835f97dc7b7439069e2413d343", size = 10943, upload-time = "2022-05-02T15:47:16.11Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034, upload-time = "2022-05-02T15:47:14.552Z" }, -] - [[package]] name = "gitdb" version = "4.0.12" @@ -3423,38 +3377,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/2b/98c7f93e6db9977aaee07eb1e51ca63bd5f779b900d362791d3252e60558/greenlet-3.3.1-cp314-cp314t-win_amd64.whl", hash = "sha256:301860987846c24cb8964bdec0e31a96ad4a2a801b41b4ef40963c1b44f33451", size = 233181, upload-time = "2026-01-23T15:33:00.29Z" }, ] -[[package]] -name = "griffe" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "griffecli" }, - { name = "griffelib" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/94/ee21d41e7eb4f823b94603b9d40f86d3c7fde80eacc2c3c71845476dddaa/griffe-2.0.0-py3-none-any.whl", hash = "sha256:5418081135a391c3e6e757a7f3f156f1a1a746cc7b4023868ff7d5e2f9a980aa", size = 5214, upload-time = "2026-02-09T19:09:44.105Z" }, -] - -[[package]] -name = "griffecli" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama" }, - { name = "griffelib" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/ed/d93f7a447bbf7a935d8868e9617cbe1cadf9ee9ee6bd275d3040fbf93d60/griffecli-2.0.0-py3-none-any.whl", hash = "sha256:9f7cd9ee9b21d55e91689358978d2385ae65c22f307a63fb3269acf3f21e643d", size = 9345, upload-time = "2026-02-09T19:09:42.554Z" }, -] - -[[package]] -name = "griffelib" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/51/c936033e16d12b627ea334aaaaf42229c37620d0f15593456ab69ab48161/griffelib-2.0.0-py3-none-any.whl", hash = "sha256:01284878c966508b6d6f1dbff9b6fa607bc062d8261c5c7253cb285b06422a7f", size = 142004, upload-time = "2026-02-09T19:09:40.561Z" }, -] - [[package]] name = "grpc-google-iam-v1" version = "0.14.3" @@ -3557,69 +3479,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/83/8a/1241ec22c41028bddd4a052ae9369267b4475265ad0ce7140974548dc3fa/grpcio_status-1.78.0-py3-none-any.whl", hash = "sha256:b492b693d4bf27b47a6c32590701724f1d3b9444b36491878fb71f6208857f34", size = 14523, upload-time = "2026-02-06T10:01:32.584Z" }, ] -[[package]] -name = "grpcio-tools" -version = "1.78.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "grpcio" }, - { name = "protobuf" }, - { name = "setuptools" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/8b/d1/cbefe328653f746fd319c4377836a25ba64226e41c6a1d7d5cdbc87a459f/grpcio_tools-1.78.0.tar.gz", hash = "sha256:4b0dd86560274316e155d925158276f8564508193088bc43e20d3f5dff956b2b", size = 5393026, upload-time = "2026-02-06T09:59:59.53Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5e/70/2118a814a62ab205c905d221064bc09021db83fceeb84764d35c00f0f633/grpcio_tools-1.78.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:ea64e38d1caa2b8468b08cb193f5a091d169b6dbfe1c7dac37d746651ab9d84e", size = 2545568, upload-time = "2026-02-06T09:57:30.308Z" }, - { url = "https://files.pythonhosted.org/packages/2b/a9/68134839dd1a00f964185ead103646d6dd6a396b92ed264eaf521431b793/grpcio_tools-1.78.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:4003fcd5cbb5d578b06176fd45883a72a8f9203152149b7c680ce28653ad9e3a", size = 5708704, upload-time = "2026-02-06T09:57:33.512Z" }, - { url = "https://files.pythonhosted.org/packages/36/1b/b6135aa9534e22051c53e5b9c0853d18024a41c50aaff464b7b47c1ed379/grpcio_tools-1.78.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fe6b0081775394c61ec633c9ff5dbc18337100eabb2e946b5c83967fe43b2748", size = 2591905, upload-time = "2026-02-06T09:57:35.338Z" }, - { url = "https://files.pythonhosted.org/packages/41/2b/6380df1390d62b1d18ae18d4d790115abf4997fa29498aa50ba644ecb9d8/grpcio_tools-1.78.0-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:7e989ad2cd93db52d7f1a643ecaa156ac55bf0484f1007b485979ce8aef62022", size = 2905271, upload-time = "2026-02-06T09:57:37.932Z" }, - { url = "https://files.pythonhosted.org/packages/3a/07/9b369f37c8f4956b68778c044d57390a8f0f3b1cca590018809e75a4fce2/grpcio_tools-1.78.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b874991797e96c41a37e563236c3317ed41b915eff25b292b202d6277d30da85", size = 2656234, upload-time = "2026-02-06T09:57:41.157Z" }, - { url = "https://files.pythonhosted.org/packages/51/61/40eee40e7a54f775a0d4117536532713606b6b177fff5e327f33ad18746e/grpcio_tools-1.78.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:daa8c288b728228377aaf758925692fc6068939d9fa32f92ca13dedcbeb41f33", size = 3105770, upload-time = "2026-02-06T09:57:43.373Z" }, - { url = "https://files.pythonhosted.org/packages/b6/ac/81ee4b728e70e8ba66a589f86469925ead02ed6f8973434e4a52e3576148/grpcio_tools-1.78.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:87e648759b06133199f4bc0c0053e3819f4ec3b900dc399e1097b6065db998b5", size = 3654896, upload-time = "2026-02-06T09:57:45.402Z" }, - { url = "https://files.pythonhosted.org/packages/be/b9/facb3430ee427c800bb1e39588c85685677ea649491d6e0874bd9f3a1c0e/grpcio_tools-1.78.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f3d3ced52bfe39eba3d24f5a8fab4e12d071959384861b41f0c52ca5399d6920", size = 3322529, upload-time = "2026-02-06T09:57:47.292Z" }, - { url = "https://files.pythonhosted.org/packages/c7/de/d7a011df9abfed8c30f0d2077b0562a6e3edc57cb3e5514718e2a81f370a/grpcio_tools-1.78.0-cp310-cp310-win32.whl", hash = "sha256:4bb6ed690d417b821808796221bde079377dff98fdc850ac157ad2f26cda7a36", size = 993518, upload-time = "2026-02-06T09:57:48.836Z" }, - { url = "https://files.pythonhosted.org/packages/c8/5e/f7f60c3ae2281c6b438c3a8455f4a5d5d2e677cf20207864cbee3763da22/grpcio_tools-1.78.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c676d8342fd53bd85a5d5f0d070cd785f93bc040510014708ede6fcb32fada1", size = 1158505, upload-time = "2026-02-06T09:57:50.633Z" }, - { url = "https://files.pythonhosted.org/packages/75/78/280184d19242ed6762bf453c47a70b869b3c5c72a24dc5bf2bf43909faa3/grpcio_tools-1.78.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:6a8b8b7b49f319d29dbcf507f62984fa382d1d10437d75c3f26db5f09c4ac0af", size = 2545904, upload-time = "2026-02-06T09:57:52.769Z" }, - { url = "https://files.pythonhosted.org/packages/5b/51/3c46dea5113f68fe879961cae62d34bb7a3c308a774301b45d614952ee98/grpcio_tools-1.78.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:d62cf3b68372b0c6d722a6165db41b976869811abeabc19c8522182978d8db10", size = 5709078, upload-time = "2026-02-06T09:57:56.389Z" }, - { url = "https://files.pythonhosted.org/packages/e0/2c/dc1ae9ec53182c96d56dfcbf3bcd3e55a8952ad508b188c75bf5fc8993d4/grpcio_tools-1.78.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fa9056742efeaf89d5fe14198af71e5cbc4fbf155d547b89507e19d6025906c6", size = 2591744, upload-time = "2026-02-06T09:57:58.341Z" }, - { url = "https://files.pythonhosted.org/packages/04/63/9b53fc9a9151dd24386785171a4191ee7cb5afb4d983b6a6a87408f41b28/grpcio_tools-1.78.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:e3191af125dcb705aa6bc3856ba81ba99b94121c1b6ebee152e66ea084672831", size = 2905113, upload-time = "2026-02-06T09:58:00.38Z" }, - { url = "https://files.pythonhosted.org/packages/96/b2/0ad8d789f3a2a00893131c140865605fa91671a6e6fcf9da659e1fabba10/grpcio_tools-1.78.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:283239ddbb67ae83fac111c61b25d8527a1dbd355b377cbc8383b79f1329944d", size = 2656436, upload-time = "2026-02-06T09:58:03.038Z" }, - { url = "https://files.pythonhosted.org/packages/09/4d/580f47ce2fc61b093ade747b378595f51b4f59972dd39949f7444b464a03/grpcio_tools-1.78.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ac977508c0db15301ef36d6c79769ec1a6cc4e3bc75735afca7fe7e360cead3a", size = 3106128, upload-time = "2026-02-06T09:58:05.064Z" }, - { url = "https://files.pythonhosted.org/packages/c9/29/d83b2d89f8d10e438bad36b1eb29356510fb97e81e6a608b22ae1890e8e6/grpcio_tools-1.78.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4ff605e25652a0bd13aa8a73a09bc48669c68170902f5d2bf1468a57d5e78771", size = 3654953, upload-time = "2026-02-06T09:58:07.15Z" }, - { url = "https://files.pythonhosted.org/packages/08/71/917ce85633311e54fefd7e6eb1224fb780ef317a4d092766f5630c3fc419/grpcio_tools-1.78.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0197d7b561c79be78ab93d0fe2836c8def470683df594bae3ac89dd8e5c821b2", size = 3322630, upload-time = "2026-02-06T09:58:10.305Z" }, - { url = "https://files.pythonhosted.org/packages/b2/55/3fbf6b26ab46fc79e1e6f7f4e0993cf540263dad639290299fad374a0829/grpcio_tools-1.78.0-cp311-cp311-win32.whl", hash = "sha256:28f71f591f7f39555863ced84fcc209cbf4454e85ef957232f43271ee99af577", size = 993804, upload-time = "2026-02-06T09:58:13.698Z" }, - { url = "https://files.pythonhosted.org/packages/73/86/4affe006d9e1e9e1c6653d6aafe2f8b9188acb2b563cd8ed3a2c7c0e8aec/grpcio_tools-1.78.0-cp311-cp311-win_amd64.whl", hash = "sha256:5a6de495dabf86a3b40b9a7492994e1232b077af9d63080811838b781abbe4e8", size = 1158566, upload-time = "2026-02-06T09:58:15.721Z" }, - { url = "https://files.pythonhosted.org/packages/0c/ae/5b1fa5dd8d560a6925aa52de0de8731d319f121c276e35b9b2af7cc220a2/grpcio_tools-1.78.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:9eb122da57d4cad7d339fc75483116f0113af99e8d2c67f3ef9cae7501d806e4", size = 2546823, upload-time = "2026-02-06T09:58:17.944Z" }, - { url = "https://files.pythonhosted.org/packages/a7/ed/d33ccf7fa701512efea7e7e23333b748848a123e9d3bbafde4e126784546/grpcio_tools-1.78.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:d0c501b8249940b886420e6935045c44cb818fa6f265f4c2b97d5cff9cb5e796", size = 5706776, upload-time = "2026-02-06T09:58:20.944Z" }, - { url = "https://files.pythonhosted.org/packages/c6/69/4285583f40b37af28277fc6b867d636e3b10e1b6a7ebd29391a856e1279b/grpcio_tools-1.78.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:77e5aa2d2a7268d55b1b113f958264681ef1994c970f69d48db7d4683d040f57", size = 2593972, upload-time = "2026-02-06T09:58:23.29Z" }, - { url = "https://files.pythonhosted.org/packages/d7/eb/ecc1885bd6b3147f0a1b7dff5565cab72f01c8f8aa458f682a1c77a9fb08/grpcio_tools-1.78.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:8e3c0b0e6ba5275322ba29a97bf890565a55f129f99a21b121145e9e93a22525", size = 2905531, upload-time = "2026-02-06T09:58:25.406Z" }, - { url = "https://files.pythonhosted.org/packages/ae/a9/511d0040ced66960ca10ba0f082d6b2d2ee6dd61837b1709636fdd8e23b4/grpcio_tools-1.78.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:975d4cb48694e20ebd78e1643e5f1cd94cdb6a3d38e677a8e84ae43665aa4790", size = 2656909, upload-time = "2026-02-06T09:58:28.022Z" }, - { url = "https://files.pythonhosted.org/packages/06/a3/3d2c707e7dee8df842c96fbb24feb2747e506e39f4a81b661def7fed107c/grpcio_tools-1.78.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:553ff18c5d52807dedecf25045ae70bad7a3dbba0b27a9a3cdd9bcf0a1b7baec", size = 3109778, upload-time = "2026-02-06T09:58:30.091Z" }, - { url = "https://files.pythonhosted.org/packages/1f/4b/646811ba241bf05da1f0dc6f25764f1c837f78f75b4485a4210c84b79eae/grpcio_tools-1.78.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:8c7f5e4af5a84d2e96c862b1a65e958a538237e268d5f8203a3a784340975b51", size = 3658763, upload-time = "2026-02-06T09:58:32.875Z" }, - { url = "https://files.pythonhosted.org/packages/45/de/0a5ef3b3e79d1011375f5580dfee3a9c1ccb96c5f5d1c74c8cee777a2483/grpcio_tools-1.78.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:96183e2b44afc3f9a761e9d0f985c3b44e03e8bb98e626241a6cbfb3b6f7e88f", size = 3325116, upload-time = "2026-02-06T09:58:34.894Z" }, - { url = "https://files.pythonhosted.org/packages/95/d2/6391b241ad571bc3e71d63f957c0b1860f0c47932d03c7f300028880f9b8/grpcio_tools-1.78.0-cp312-cp312-win32.whl", hash = "sha256:2250e8424c565a88573f7dc10659a0b92802e68c2a1d57e41872c9b88ccea7a6", size = 993493, upload-time = "2026-02-06T09:58:37.242Z" }, - { url = "https://files.pythonhosted.org/packages/7c/8f/7d0d3a39ecad76ccc136be28274daa660569b244fa7d7d0bbb24d68e5ece/grpcio_tools-1.78.0-cp312-cp312-win_amd64.whl", hash = "sha256:217d1fa29de14d9c567d616ead7cb0fef33cde36010edff5a9390b00d52e5094", size = 1158423, upload-time = "2026-02-06T09:58:40.072Z" }, - { url = "https://files.pythonhosted.org/packages/53/ce/17311fb77530420e2f441e916b347515133e83d21cd6cc77be04ce093d5b/grpcio_tools-1.78.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:2d6de1cc23bdc1baafc23e201b1e48c617b8c1418b4d8e34cebf72141676e5fb", size = 2546284, upload-time = "2026-02-06T09:58:43.073Z" }, - { url = "https://files.pythonhosted.org/packages/1d/d3/79e101483115f0e78223397daef71751b75eba7e92a32060c10aae11ca64/grpcio_tools-1.78.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:2afeaad88040894c76656202ff832cb151bceb05c0e6907e539d129188b1e456", size = 5705653, upload-time = "2026-02-06T09:58:45.533Z" }, - { url = "https://files.pythonhosted.org/packages/8b/a7/52fa3ccb39ceeee6adc010056eadfbca8198651c113e418dafebbdf2b306/grpcio_tools-1.78.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:33cc593735c93c03d63efe7a8ba25f3c66f16c52f0651910712490244facad72", size = 2592788, upload-time = "2026-02-06T09:58:48.918Z" }, - { url = "https://files.pythonhosted.org/packages/68/08/682ff6bb548225513d73dc9403742d8975439d7469c673bc534b9bbc83a7/grpcio_tools-1.78.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:2921d7989c4d83b71f03130ab415fa4d66e6693b8b8a1fcbb7a1c67cff19b812", size = 2905157, upload-time = "2026-02-06T09:58:51.478Z" }, - { url = "https://files.pythonhosted.org/packages/b2/66/264f3836a96423b7018e5ada79d62576a6401f6da4e1f4975b18b2be1265/grpcio_tools-1.78.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e6a0df438e82c804c7b95e3f311c97c2f876dcc36376488d5b736b7bcf5a9b45", size = 2656166, upload-time = "2026-02-06T09:58:54.117Z" }, - { url = "https://files.pythonhosted.org/packages/f3/6b/f108276611522e03e98386b668cc7e575eff6952f2db9caa15b2a3b3e883/grpcio_tools-1.78.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e9c6070a9500798225191ef25d0055a15d2c01c9c8f2ee7b681fffa99c98c822", size = 3109110, upload-time = "2026-02-06T09:58:56.891Z" }, - { url = "https://files.pythonhosted.org/packages/6f/c7/cf048dbcd64b3396b3c860a2ffbcc67a8f8c87e736aaa74c2e505a7eee4c/grpcio_tools-1.78.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:394e8b57d85370a62e5b0a4d64c96fcf7568345c345d8590c821814d227ecf1d", size = 3657863, upload-time = "2026-02-06T09:58:59.176Z" }, - { url = "https://files.pythonhosted.org/packages/b6/37/e2736912c8fda57e2e57a66ea5e0bc8eb9a5fb7ded00e866ad22d50afb08/grpcio_tools-1.78.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a3ef700293ab375e111a2909d87434ed0a0b086adf0ce67a8d9cf12ea7765e63", size = 3324748, upload-time = "2026-02-06T09:59:01.242Z" }, - { url = "https://files.pythonhosted.org/packages/1c/5d/726abc75bb5bfc2841e88ea05896e42f51ca7c30cb56da5c5b63058b3867/grpcio_tools-1.78.0-cp313-cp313-win32.whl", hash = "sha256:6993b960fec43a8d840ee5dc20247ef206c1a19587ea49fe5e6cc3d2a09c1585", size = 993074, upload-time = "2026-02-06T09:59:03.085Z" }, - { url = "https://files.pythonhosted.org/packages/c5/68/91b400bb360faf9b177ffb5540ec1c4d06ca923691ddf0f79e2c9683f4da/grpcio_tools-1.78.0-cp313-cp313-win_amd64.whl", hash = "sha256:275ce3c2978842a8cf9dd88dce954e836e590cf7029649ad5d1145b779039ed5", size = 1158185, upload-time = "2026-02-06T09:59:05.036Z" }, - { url = "https://files.pythonhosted.org/packages/cf/5e/278f3831c8d56bae02e3acc570465648eccf0a6bbedcb1733789ac966803/grpcio_tools-1.78.0-cp314-cp314-linux_armv7l.whl", hash = "sha256:8b080d0d072e6032708a3a91731b808074d7ab02ca8fb9847b6a011fdce64cd9", size = 2546270, upload-time = "2026-02-06T09:59:07.426Z" }, - { url = "https://files.pythonhosted.org/packages/a3/d9/68582f2952b914b60dddc18a2e3f9c6f09af9372b6f6120d6cf3ec7f8b4e/grpcio_tools-1.78.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:8c0ad8f8f133145cd7008b49cb611a5c6a9d89ab276c28afa17050516e801f79", size = 5705731, upload-time = "2026-02-06T09:59:09.856Z" }, - { url = "https://files.pythonhosted.org/packages/70/68/feb0f9a48818ee1df1e8b644069379a1e6ef5447b9b347c24e96fd258e5d/grpcio_tools-1.78.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2f8ea092a7de74c6359335d36f0674d939a3c7e1a550f4c2c9e80e0226de8fe4", size = 2593896, upload-time = "2026-02-06T09:59:12.23Z" }, - { url = "https://files.pythonhosted.org/packages/1f/08/a430d8d06e1b8d33f3e48d3f0cc28236723af2f35e37bd5c8db05df6c3aa/grpcio_tools-1.78.0-cp314-cp314-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:da422985e0cac822b41822f43429c19ecb27c81ffe3126d0b74e77edec452608", size = 2905298, upload-time = "2026-02-06T09:59:14.458Z" }, - { url = "https://files.pythonhosted.org/packages/71/0a/348c36a3eae101ca0c090c9c3bc96f2179adf59ee0c9262d11cdc7bfe7db/grpcio_tools-1.78.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4fab1faa3fbcb246263e68da7a8177d73772283f9db063fb8008517480888d26", size = 2656186, upload-time = "2026-02-06T09:59:16.949Z" }, - { url = "https://files.pythonhosted.org/packages/1d/3f/18219f331536fad4af6207ade04142292faa77b5cb4f4463787988963df8/grpcio_tools-1.78.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:dd9c094f73f734becae3f20f27d4944d3cd8fb68db7338ee6c58e62fc5c3d99f", size = 3109859, upload-time = "2026-02-06T09:59:19.202Z" }, - { url = "https://files.pythonhosted.org/packages/5b/d9/341ea20a44c8e5a3a18acc820b65014c2e3ea5b4f32a53d14864bcd236bc/grpcio_tools-1.78.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:2ed51ce6b833068f6c580b73193fc2ec16468e6bc18354bc2f83a58721195a58", size = 3657915, upload-time = "2026-02-06T09:59:21.839Z" }, - { url = "https://files.pythonhosted.org/packages/fb/f4/5978b0f91611a64371424c109dd0027b247e5b39260abad2eaee66b6aa37/grpcio_tools-1.78.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:05803a5cdafe77c8bdf36aa660ad7a6a1d9e49bc59ce45c1bade2a4698826599", size = 3324724, upload-time = "2026-02-06T09:59:24.402Z" }, - { url = "https://files.pythonhosted.org/packages/b2/80/96a324dba99cfbd20e291baf0b0ae719dbb62b76178c5ce6c788e7331cb1/grpcio_tools-1.78.0-cp314-cp314-win32.whl", hash = "sha256:f7c722e9ce6f11149ac5bddd5056e70aaccfd8168e74e9d34d8b8b588c3f5c7c", size = 1015505, upload-time = "2026-02-06T09:59:26.3Z" }, - { url = "https://files.pythonhosted.org/packages/3b/d1/909e6a05bfd44d46327dc4b8a78beb2bae4fb245ffab2772e350081aaf7e/grpcio_tools-1.78.0-cp314-cp314-win_amd64.whl", hash = "sha256:7d58ade518b546120ec8f0a8e006fc8076ae5df151250ebd7e82e9b5e152c229", size = 1190196, upload-time = "2026-02-06T09:59:28.359Z" }, -] - [[package]] name = "gunicorn" version = "25.1.0" @@ -4227,19 +4086,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/14/2f/967ba146e6d58cf6a652da73885f52fc68001525b4197effc174321d70b4/jmespath-1.1.0-py3-none-any.whl", hash = "sha256:a5663118de4908c91729bea0acadca56526eb2698e83de10cd116ae0f4e97c64", size = 20419, upload-time = "2026-01-22T16:35:24.919Z" }, ] -[[package]] -name = "jsbeautifier" -version = "1.15.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "editorconfig" }, - { name = "six" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ea/98/d6cadf4d5a1c03b2136837a435682418c29fdeb66be137128544cecc5b7a/jsbeautifier-1.15.4.tar.gz", hash = "sha256:5bb18d9efb9331d825735fbc5360ee8f1aac5e52780042803943aa7f854f7592", size = 75257, upload-time = "2025-02-27T17:53:53.252Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2d/14/1c65fccf8413d5f5c6e8425f84675169654395098000d8bddc4e9d3390e1/jsbeautifier-1.15.4-py3-none-any.whl", hash = "sha256:72f65de312a3f10900d7685557f84cb61a9733c50dcc27271a39f5b0051bf528", size = 94707, upload-time = "2025-02-27T17:53:46.152Z" }, -] - [[package]] name = "json5" version = "0.13.0" @@ -4679,15 +4525,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f2/24/8d99982f0aa9c1cd82073c6232b54a0dbe6797c7d63c0583a6c68ee3ddf2/litestar_htmx-0.5.0-py3-none-any.whl", hash = "sha256:92833aa47e0d0e868d2a7dbfab75261f124f4b83d4f9ad12b57b9a68f86c50e6", size = 9970, upload-time = "2025-06-11T21:19:44.465Z" }, ] -[[package]] -name = "markdown" -version = "3.10.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2b/f4/69fa6ed85ae003c2378ffa8f6d2e3234662abd02c10d216c0ba96081a238/markdown-3.10.2.tar.gz", hash = "sha256:994d51325d25ad8aa7ce4ebaec003febcce822c3f8c911e3b17c52f7f589f950", size = 368805, upload-time = "2026-02-09T14:57:26.942Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/de/1f/77fa3081e4f66ca3576c896ae5d31c3002ac6607f9747d2e3aa49227e464/markdown-3.10.2-py3-none-any.whl", hash = "sha256:e91464b71ae3ee7afd3017d9f358ef0baf158fd9a298db92f1d4761133824c36", size = 108180, upload-time = "2026-02-09T14:57:25.787Z" }, -] - [[package]] name = "markdown-it-py" version = "4.0.0" @@ -4831,15 +4668,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] -[[package]] -name = "mergedeep" -version = "1.3.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3a/41/580bb4006e3ed0361b8151a01d324fb03f420815446c7def45d02f74c270/mergedeep-1.3.4.tar.gz", hash = "sha256:0096d52e9dad9939c3d975a774666af186eda617e6ca84df4c94dec30004f2a8", size = 4661, upload-time = "2021-02-05T18:55:30.623Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/19/04f9b178c2d8a15b076c8b5140708fa6ffc5601fb6f1e975537072df5b2a/mergedeep-1.3.4-py3-none-any.whl", hash = "sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307", size = 6354, upload-time = "2021-02-05T18:55:29.583Z" }, -] - [[package]] name = "mistralai" version = "1.12.2" @@ -4873,157 +4701,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9b/f7/4a5e785ec9fbd65146a27b6b70b6cdc161a66f2024e4b04ac06a67f5578b/mistune-3.2.0-py3-none-any.whl", hash = "sha256:febdc629a3c78616b94393c6580551e0e34cc289987ec6c35ed3f4be42d0eee1", size = 53598, upload-time = "2025-12-23T11:36:33.211Z" }, ] -[[package]] -name = "mkdocs" -version = "1.6.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, - { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "ghp-import" }, - { name = "jinja2" }, - { name = "markdown" }, - { name = "markupsafe" }, - { name = "mergedeep" }, - { name = "mkdocs-get-deps" }, - { name = "packaging" }, - { name = "pathspec" }, - { name = "pyyaml" }, - { name = "pyyaml-env-tag" }, - { name = "watchdog" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/bc/c6/bbd4f061bd16b378247f12953ffcb04786a618ce5e904b8c5a01a0309061/mkdocs-1.6.1.tar.gz", hash = "sha256:7b432f01d928c084353ab39c57282f29f92136665bdd6abf7c1ec8d822ef86f2", size = 3889159, upload-time = "2024-08-30T12:24:06.899Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/22/5b/dbc6a8cddc9cfa9c4971d59fb12bb8d42e161b7e7f8cc89e49137c5b279c/mkdocs-1.6.1-py3-none-any.whl", hash = "sha256:db91759624d1647f3f34aa0c3f327dd2601beae39a366d6e064c03468d35c20e", size = 3864451, upload-time = "2024-08-30T12:24:05.054Z" }, -] - -[[package]] -name = "mkdocs-autorefs" -version = "1.4.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "markdown" }, - { name = "markupsafe" }, - { name = "mkdocs" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/52/c0/f641843de3f612a6b48253f39244165acff36657a91cc903633d456ae1ac/mkdocs_autorefs-1.4.4.tar.gz", hash = "sha256:d54a284f27a7346b9c38f1f852177940c222da508e66edc816a0fa55fc6da197", size = 56588, upload-time = "2026-02-10T15:23:55.105Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/28/de/a3e710469772c6a89595fc52816da05c1e164b4c866a89e3cb82fb1b67c5/mkdocs_autorefs-1.4.4-py3-none-any.whl", hash = "sha256:834ef5408d827071ad1bc69e0f39704fa34c7fc05bc8e1c72b227dfdc5c76089", size = 25530, upload-time = "2026-02-10T15:23:53.817Z" }, -] - -[[package]] -name = "mkdocs-awesome-pages-plugin" -version = "2.10.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mkdocs" }, - { name = "natsort" }, - { name = "wcmatch" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/92/e8/6ae9c18d8174a5d74ce4ade7a7f4c350955063968bc41ff1e5833cff4a2b/mkdocs_awesome_pages_plugin-2.10.1.tar.gz", hash = "sha256:cda2cb88c937ada81a4785225f20ef77ce532762f4500120b67a1433c1cdbb2f", size = 16303, upload-time = "2024-12-22T21:13:49.19Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/73/61/19fc1e9c579dbfd4e8a402748f1d63cab7aabe8f8d91eb0235e45b32d040/mkdocs_awesome_pages_plugin-2.10.1-py3-none-any.whl", hash = "sha256:c6939dbea37383fc3cf8c0a4e892144ec3d2f8a585e16fdc966b34e7c97042a7", size = 15118, upload-time = "2024-12-22T21:13:46.945Z" }, -] - -[[package]] -name = "mkdocs-get-deps" -version = "0.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mergedeep" }, - { name = "platformdirs" }, - { name = "pyyaml" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/98/f5/ed29cd50067784976f25ed0ed6fcd3c2ce9eb90650aa3b2796ddf7b6870b/mkdocs_get_deps-0.2.0.tar.gz", hash = "sha256:162b3d129c7fad9b19abfdcb9c1458a651628e4b1dea628ac68790fb3061c60c", size = 10239, upload-time = "2023-11-20T17:51:09.981Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/d4/029f984e8d3f3b6b726bd33cafc473b75e9e44c0f7e80a5b29abc466bdea/mkdocs_get_deps-0.2.0-py3-none-any.whl", hash = "sha256:2bf11d0b133e77a0dd036abeeb06dec8775e46efa526dc70667d8863eefc6134", size = 9521, upload-time = "2023-11-20T17:51:08.587Z" }, -] - -[[package]] -name = "mkdocs-material" -version = "9.7.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "babel" }, - { name = "backrefs" }, - { name = "colorama" }, - { name = "jinja2" }, - { name = "markdown" }, - { name = "mkdocs" }, - { name = "mkdocs-material-extensions" }, - { name = "paginate" }, - { name = "pygments" }, - { name = "pymdown-extensions" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/27/e2/2ffc356cd72f1473d07c7719d82a8f2cbd261666828614ecb95b12169f41/mkdocs_material-9.7.1.tar.gz", hash = "sha256:89601b8f2c3e6c6ee0a918cc3566cb201d40bf37c3cd3c2067e26fadb8cce2b8", size = 4094392, upload-time = "2025-12-18T09:49:00.308Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3e/32/ed071cb721aca8c227718cffcf7bd539620e9799bbf2619e90c757bfd030/mkdocs_material-9.7.1-py3-none-any.whl", hash = "sha256:3f6100937d7d731f87f1e3e3b021c97f7239666b9ba1151ab476cabb96c60d5c", size = 9297166, upload-time = "2025-12-18T09:48:56.664Z" }, -] - -[[package]] -name = "mkdocs-material-extensions" -version = "1.3.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/79/9b/9b4c96d6593b2a541e1cb8b34899a6d021d208bb357042823d4d2cabdbe7/mkdocs_material_extensions-1.3.1.tar.gz", hash = "sha256:10c9511cea88f568257f960358a467d12b970e1f7b2c0e5fb2bb48cab1928443", size = 11847, upload-time = "2023-11-22T19:09:45.208Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5b/54/662a4743aa81d9582ee9339d4ffa3c8fd40a4965e033d77b9da9774d3960/mkdocs_material_extensions-1.3.1-py3-none-any.whl", hash = "sha256:adff8b62700b25cb77b53358dad940f3ef973dd6db797907c49e3c2ef3ab4e31", size = 8728, upload-time = "2023-11-22T19:09:43.465Z" }, -] - -[[package]] -name = "mkdocs-mermaid2-plugin" -version = "1.2.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "beautifulsoup4" }, - { name = "jsbeautifier" }, - { name = "mkdocs" }, - { name = "pymdown-extensions" }, - { name = "requests" }, - { name = "setuptools" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/2a/6d/308f443a558b6a97ce55782658174c0d07c414405cfc0a44d36ad37e36f9/mkdocs_mermaid2_plugin-1.2.3.tar.gz", hash = "sha256:fb6f901d53e5191e93db78f93f219cad926ccc4d51e176271ca5161b6cc5368c", size = 16220, upload-time = "2025-10-17T19:38:53.047Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1a/4b/6fd6dd632019b7f522f1b1f794ab6115cd79890330986614be56fd18f0eb/mkdocs_mermaid2_plugin-1.2.3-py3-none-any.whl", hash = "sha256:33f60c582be623ed53829a96e19284fc7f1b74a1dbae78d4d2e47fe00c3e190d", size = 17299, upload-time = "2025-10-17T19:38:51.874Z" }, -] - -[[package]] -name = "mkdocstrings" -version = "1.0.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "jinja2" }, - { name = "markdown" }, - { name = "markupsafe" }, - { name = "mkdocs" }, - { name = "mkdocs-autorefs" }, - { name = "pymdown-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/46/62/0dfc5719514115bf1781f44b1d7f2a0923fcc01e9c5d7990e48a05c9ae5d/mkdocstrings-1.0.3.tar.gz", hash = "sha256:ab670f55040722b49bb45865b2e93b824450fb4aef638b00d7acb493a9020434", size = 100946, upload-time = "2026-02-07T14:31:40.973Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/41/1cf02e3df279d2dd846a1bf235a928254eba9006dd22b4a14caa71aed0f7/mkdocstrings-1.0.3-py3-none-any.whl", hash = "sha256:0d66d18430c2201dc7fe85134277382baaa15e6b30979f3f3bdbabd6dbdb6046", size = 35523, upload-time = "2026-02-07T14:31:39.27Z" }, -] - -[package.optional-dependencies] -python = [ - { name = "mkdocstrings-python" }, -] - -[[package]] -name = "mkdocstrings-python" -version = "2.0.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "griffe" }, - { name = "mkdocs-autorefs" }, - { name = "mkdocstrings" }, - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/25/84/78243847ad9d5c21d30a2842720425b17e880d99dfe824dee11d6b2149b4/mkdocstrings_python-2.0.2.tar.gz", hash = "sha256:4a32ccfc4b8d29639864698e81cfeb04137bce76bb9f3c251040f55d4b6e1ad8", size = 199124, upload-time = "2026-02-09T15:12:01.543Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f3/31/7ee938abbde2322e553a2cb5f604cdd1e4728e08bba39c7ee6fae9af840b/mkdocstrings_python-2.0.2-py3-none-any.whl", hash = "sha256:31241c0f43d85a69306d704d5725786015510ea3f3c4bdfdb5a5731d83cdc2b0", size = 104900, upload-time = "2026-02-09T15:12:00.166Z" }, -] - [[package]] name = "more-itertools" version = "10.8.0" @@ -5403,15 +5080,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/03/cc/7cb74758e6df95e0c4e1253f203b6dd7f348bf2f29cf89e9210a2416d535/narwhals-2.16.0-py3-none-any.whl", hash = "sha256:846f1fd7093ac69d63526e50732033e86c30ea0026a44d9b23991010c7d1485d", size = 443951, upload-time = "2026-02-02T10:30:58.635Z" }, ] -[[package]] -name = "natsort" -version = "8.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e2/a9/a0c57aee75f77794adaf35322f8b6404cbd0f89ad45c87197a937764b7d0/natsort-8.4.0.tar.gz", hash = "sha256:45312c4a0e5507593da193dedd04abb1469253b601ecaf63445ad80f0a1ea581", size = 76575, upload-time = "2023-06-20T04:17:19.925Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ef/82/7a9d0550484a62c6da82858ee9419f3dd1ccc9aa1c26a1e43da3ecd20b0d/natsort-8.4.0-py3-none-any.whl", hash = "sha256:4732914fb471f56b5cce04d7bae6f164a592c7712e1c85f9ef585e197299521c", size = 38268, upload-time = "2023-06-20T04:17:17.522Z" }, -] - [[package]] name = "nbclient" version = "0.10.4" @@ -6049,15 +5717,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, ] -[[package]] -name = "paginate" -version = "0.5.7" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ec/46/68dde5b6bc00c1296ec6466ab27dddede6aec9af1b99090e1107091b3b84/paginate-0.5.7.tar.gz", hash = "sha256:22bd083ab41e1a8b4f3690544afb2c60c25e5c9a63a30fa2f483f6c60c8e5945", size = 19252, upload-time = "2024-08-25T14:17:24.139Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/90/96/04b8e52da071d28f5e21a805b19cb9390aa17a47462ac87f5e2696b9566d/paginate-0.5.7-py2.py3-none-any.whl", hash = "sha256:b885e2af73abcf01d9559fd5216b57ef722f8c42affbb63942377668e35c7591", size = 13746, upload-time = "2024-08-25T14:17:22.55Z" }, -] - [[package]] name = "pandas" version = "2.3.3" @@ -6317,19 +5976,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/be/f3/4888f895c02afa085630a3a3329d1b18b998874642ad4c530e9a4d7851fe/pip_audit-2.10.0-py3-none-any.whl", hash = "sha256:16e02093872fac97580303f0848fa3ad64f7ecf600736ea7835a2b24de49613f", size = 61518, upload-time = "2025-12-01T23:42:39.193Z" }, ] -[[package]] -name = "pip-licenses" -version = "5.5.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "prettytable" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/44/4c/b4be9024dae3b5b3c0a6c58cc1d4a35fffe51c3adb835350cb7dcd43b5cd/pip_licenses-5.5.1.tar.gz", hash = "sha256:7df370e6e5024a3f7449abf8e4321ef868ba9a795698ad24ab6851f3e7fc65a7", size = 49108, upload-time = "2026-01-27T21:46:41.432Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/a3/0b369cdffef3746157712804f1ded9856c75aa060217ee206f742c74e753/pip_licenses-5.5.1-py3-none-any.whl", hash = "sha256:ed5e229a93760e529cfa7edaec6630b5a2cd3874c1bddb8019e5f18a723fdead", size = 22108, upload-time = "2026-01-27T21:46:39.766Z" }, -] - [[package]] name = "pip-requirements-parser" version = "32.0.1" @@ -6388,18 +6034,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/21/93363d7b802aa904f8d4169bc33e0e316d06d26ee68d40fe0355057da98c/polyfactory-3.2.0-py3-none-any.whl", hash = "sha256:5945799cce4c56cd44ccad96fb0352996914553cc3efaa5a286930599f569571", size = 62181, upload-time = "2025-12-21T11:18:49.311Z" }, ] -[[package]] -name = "prettytable" -version = "3.17.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "wcwidth" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/79/45/b0847d88d6cfeb4413566738c8bbf1e1995fad3d42515327ff32cc1eb578/prettytable-3.17.0.tar.gz", hash = "sha256:59f2590776527f3c9e8cf9fe7b66dd215837cca96a9c39567414cbc632e8ddb0", size = 67892, upload-time = "2025-11-14T17:33:20.212Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/8c/83087ebc47ab0396ce092363001fa37c17153119ee282700c0713a195853/prettytable-3.17.0-py3-none-any.whl", hash = "sha256:aad69b294ddbe3e1f95ef8886a060ed1666a0b83018bbf56295f6f226c43d287", size = 34433, upload-time = "2025-11-14T17:33:19.093Z" }, -] - [[package]] name = "priority" version = "2.0.0" @@ -7603,19 +7237,6 @@ crypto = [ { name = "cryptography" }, ] -[[package]] -name = "pymdown-extensions" -version = "10.20.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "markdown" }, - { name = "pyyaml" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1e/6c/9e370934bfa30e889d12e61d0dae009991294f40055c238980066a7fbd83/pymdown_extensions-10.20.1.tar.gz", hash = "sha256:e7e39c865727338d434b55f1dd8da51febcffcaebd6e1a0b9c836243f660740a", size = 852860, upload-time = "2026-01-24T05:56:56.758Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/40/6d/b6ee155462a0156b94312bdd82d2b92ea56e909740045a87ccb98bf52405/pymdown_extensions-10.20.1-py3-none-any.whl", hash = "sha256:24af7feacbca56504b313b7b418c4f5e1317bb5fea60f03d57be7fcc40912aa0", size = 268768, upload-time = "2026-01-24T05:56:54.537Z" }, -] - [[package]] name = "pyparsing" version = "3.3.2" @@ -8006,18 +7627,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, ] -[[package]] -name = "pyyaml-env-tag" -version = "1.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyyaml" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/eb/2e/79c822141bfd05a853236b504869ebc6b70159afc570e1d5a20641782eaa/pyyaml_env_tag-1.1.tar.gz", hash = "sha256:2eb38b75a2d21ee0475d6d97ec19c63287a7e140231e4214969d0eac923cd7ff", size = 5737, upload-time = "2025-05-13T15:24:01.64Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl", hash = "sha256:17109e1a528561e32f026364712fee1264bc2ea6715120891174ed1b980d2e04", size = 4722, upload-time = "2025-05-13T15:23:59.629Z" }, -] - [[package]] name = "pyzmq" version = "27.1.0" @@ -8497,18 +8106,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ca/63/2c6daf59d86b1c30600bff679d039f57fd1932af82c43c0bde1cbc55e8d4/sentry_sdk-2.52.0-py2.py3-none-any.whl", hash = "sha256:931c8f86169fc6f2752cb5c4e6480f0d516112e78750c312e081ababecbaf2ed", size = 435547, upload-time = "2026-02-04T15:03:51.567Z" }, ] -[package.optional-dependencies] -fastapi = [ - { name = "fastapi" }, -] -litestar = [ - { name = "litestar" }, -] -quart = [ - { name = "blinker" }, - { name = "quart" }, -] - [[package]] name = "setuptools" version = "81.0.0" @@ -9320,18 +8917,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6e/d4/ed38dd3b1767193de971e694aa544356e63353c33a85d948166b5ff58b9e/watchfiles-1.1.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e6f39af2eab0118338902798b5aa6664f46ff66bc0280de76fca67a7f262a49", size = 457546, upload-time = "2025-10-14T15:06:13.372Z" }, ] -[[package]] -name = "wcmatch" -version = "10.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "bracex" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/79/3e/c0bdc27cf06f4e47680bd5803a07cb3dfd17de84cde92dd217dcb9e05253/wcmatch-10.1.tar.gz", hash = "sha256:f11f94208c8c8484a16f4f48638a85d771d9513f4ab3f37595978801cb9465af", size = 117421, upload-time = "2025-06-22T19:14:02.49Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/eb/d8/0d1d2e9d3fabcf5d6840362adcf05f8cf3cd06a73358140c3a97189238ae/wcmatch-10.1-py3-none-any.whl", hash = "sha256:5848ace7dbb0476e5e55ab63c6bbd529745089343427caa5537f230cc01beb8a", size = 39854, upload-time = "2025-06-22T19:14:00.978Z" }, -] - [[package]] name = "wcwidth" version = "0.6.0" @@ -9341,102 +8926,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/5a/199c59e0a824a3db2b89c5d2dade7ab5f9624dbf6448dc291b46d5ec94d3/wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad", size = 94189, upload-time = "2026-02-06T19:19:39.646Z" }, ] -[package.optional-dependencies] -aws = [ - { name = "genkit-plugin-amazon-bedrock" }, -] -azure = [ - { name = "genkit-plugin-microsoft-foundry" }, -] -dev = [ - { name = "liccheck" }, - { name = "pip-audit" }, - { name = "pip-licenses" }, - { name = "pyrefly" }, - { name = "pyright" }, - { name = "pysentry-rs" }, - { name = "ruff" }, - { name = "sentry-sdk", extra = ["fastapi", "litestar", "quart"] }, - { name = "ty" }, - { name = "watchdog" }, -] -docs = [ - { name = "mkdocs-awesome-pages-plugin" }, - { name = "mkdocs-material" }, - { name = "mkdocs-mermaid2-plugin" }, - { name = "mkdocstrings", extra = ["python"] }, -] -gcp = [ - { name = "genkit-plugin-google-cloud" }, -] -observability = [ - { name = "genkit-plugin-observability" }, -] -sentry = [ - { name = "sentry-sdk", extra = ["fastapi", "litestar", "quart"] }, -] -test = [ - { name = "httpx" }, - { name = "opentelemetry-api" }, - { name = "opentelemetry-instrumentation-fastapi" }, - { name = "opentelemetry-sdk" }, - { name = "pytest" }, - { name = "pytest-asyncio" }, -] - -[package.metadata] -requires-dist = [ - { name = "fastapi", specifier = ">=0.115.0" }, - { name = "genkit", editable = "packages/genkit" }, - { name = "genkit-plugin-amazon-bedrock", marker = "extra == 'aws'", editable = "plugins/amazon-bedrock" }, - { name = "genkit-plugin-google-cloud", marker = "extra == 'gcp'", editable = "plugins/google-cloud" }, - { name = "genkit-plugin-google-genai", editable = "plugins/google-genai" }, - { name = "genkit-plugin-microsoft-foundry", marker = "extra == 'azure'", editable = "plugins/microsoft-foundry" }, - { name = "genkit-plugin-observability", marker = "extra == 'observability'", editable = "plugins/observability" }, - { name = "grpcio", specifier = ">=1.68.0" }, - { name = "grpcio-reflection", specifier = ">=1.68.0" }, - { name = "grpcio-tools", specifier = ">=1.68.0" }, - { name = "gunicorn", specifier = ">=22.0.0" }, - { name = "httpx", marker = "extra == 'test'", specifier = ">=0.27.0" }, - { name = "hypercorn", specifier = ">=0.17.0" }, - { name = "liccheck", marker = "extra == 'dev'", specifier = ">=0.9.2" }, - { name = "litestar", specifier = ">=2.20.0" }, - { name = "mkdocs-awesome-pages-plugin", marker = "extra == 'docs'", specifier = ">=2.9.0" }, - { name = "mkdocs-material", marker = "extra == 'docs'", specifier = ">=9.6.0" }, - { name = "mkdocs-mermaid2-plugin", marker = "extra == 'docs'", specifier = ">=1.1.0" }, - { name = "mkdocstrings", extras = ["python"], marker = "extra == 'docs'", specifier = ">=0.27.0" }, - { name = "opentelemetry-api", specifier = ">=1.20.0" }, - { name = "opentelemetry-api", marker = "extra == 'test'", specifier = ">=1.20.0" }, - { name = "opentelemetry-exporter-otlp-proto-grpc", specifier = ">=1.20.0" }, - { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.20.0" }, - { name = "opentelemetry-instrumentation-asgi", specifier = ">=0.41b0" }, - { name = "opentelemetry-instrumentation-fastapi", specifier = ">=0.41b0" }, - { name = "opentelemetry-instrumentation-fastapi", marker = "extra == 'test'", specifier = ">=0.41b0" }, - { name = "opentelemetry-instrumentation-grpc", specifier = ">=0.41b0" }, - { name = "opentelemetry-sdk", specifier = ">=1.20.0" }, - { name = "opentelemetry-sdk", marker = "extra == 'test'", specifier = ">=1.20.0" }, - { name = "pip-audit", marker = "extra == 'dev'", specifier = ">=2.7.0" }, - { name = "pip-licenses", marker = "extra == 'dev'", specifier = ">=5.0.0" }, - { name = "pydantic-settings", specifier = ">=2.0.0" }, - { name = "pyrefly", marker = "extra == 'dev'", specifier = ">=0.15.0" }, - { name = "pyright", marker = "extra == 'dev'", specifier = ">=1.1.392" }, - { name = "pysentry-rs", marker = "extra == 'dev'", specifier = ">=0.3.14" }, - { name = "pytest", marker = "extra == 'test'", specifier = ">=8.0.0" }, - { name = "pytest-asyncio", marker = "extra == 'test'", specifier = ">=0.24.0" }, - { name = "quart", specifier = ">=0.19.0" }, - { name = "rich", specifier = ">=13.0.0" }, - { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.11.0" }, - { name = "secure", specifier = ">=1.0.0" }, - { name = "sentry-sdk", extras = ["fastapi", "litestar", "quart", "grpc"], marker = "extra == 'dev'", specifier = ">=2.0.0" }, - { name = "sentry-sdk", extras = ["fastapi", "litestar", "quart", "grpc"], marker = "extra == 'sentry'", specifier = ">=2.0.0" }, - { name = "structlog", specifier = ">=24.0.0" }, - { name = "ty", marker = "extra == 'dev'", specifier = ">=0.0.1" }, - { name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" }, - { name = "uvloop", specifier = ">=0.21.0" }, - { name = "watchdog", marker = "extra == 'dev'", specifier = ">=6.0.0" }, -] -provides-extras = ["aws", "azure", "dev", "docs", "gcp", "observability", "sentry", "test"] - [[package]] name = "web-fastapi-bugbot" version = "0.2.0" From 51d9d306e0236f1604d86b19d3049d6492c23422 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 17 Apr 2026 08:22:41 -0700 Subject: [PATCH 058/141] regenerate genkit-schema.json --- genkit-tools/genkit-schema.json | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 56fd2ad3e8..2ceb970ae8 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -105,15 +105,6 @@ }, "additionalProperties": false }, - "TurnEnd": { - "type": "object", - "properties": { - "snapshotId": { - "type": "string" - } - }, - "additionalProperties": false - }, "SessionState": { "type": "object", "properties": { @@ -141,6 +132,15 @@ "invocationEnd" ] }, + "TurnEnd": { + "type": "object", + "properties": { + "snapshotId": { + "type": "string" + } + }, + "additionalProperties": false + }, "DocumentData": { "type": "object", "properties": { From 7fd8f610a0fe5f1f95af6e266f8febb539cc8e76 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 17 Apr 2026 08:33:36 -0700 Subject: [PATCH 059/141] regenerate Python schema typing --- py/packages/genkit/src/genkit/core/typing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index 403372e500..585aa310ee 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -39,13 +39,6 @@ class Model(RootModel[Any]): root: Any -class TurnEnd(BaseModel): - """Model for turnend data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - snapshot_id: str | None = Field(default=None) - - class SnapshotEvent(StrEnum): """SnapshotEvent data type class.""" @@ -53,6 +46,13 @@ class SnapshotEvent(StrEnum): INVOCATION_END = 'invocationEnd' +class TurnEnd(BaseModel): + """Model for turnend data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + snapshot_id: str | None = Field(default=None) + + class Embedding(BaseModel): """Model for embedding data.""" From 402d64e3ba9cfa72b30b2746ee388952f7d93975 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 21 Apr 2026 09:41:28 -0700 Subject: [PATCH 060/141] Update session_flow.go --- go/ai/exp/session_flow.go | 70 +++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/go/ai/exp/session_flow.go b/go/ai/exp/session_flow.go index 7f571afbf1..d7c6b8e89a 100644 --- a/go/ai/exp/session_flow.go +++ b/go/ai/exp/session_flow.go @@ -65,27 +65,27 @@ type SessionRunner[State any] struct { // wrapped in a trace span for observability. Input messages are automatically // added to the session before fn is called. After fn returns successfully, a // TurnEnd chunk is sent and a snapshot check is triggered. -func (a *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Context, input *SessionFlowInput) error) error { - for input := range a.InputCh { +func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Context, input *SessionFlowInput) error) error { + for input := range s.InputCh { spanMeta := &tracing.SpanMetadata{ - Name: fmt.Sprintf("sessionFlow/turn/%d", a.TurnIndex), + Name: fmt.Sprintf("sessionFlow/turn/%d", s.TurnIndex), Type: "flowStep", Subtype: "flowStep", } _, err := tracing.RunInNewSpan(ctx, spanMeta, input, func(ctx context.Context, input *SessionFlowInput) (any, error) { - a.AddMessages(input.Messages...) + s.AddMessages(input.Messages...) if err := fn(ctx, input); err != nil { return nil, err } - a.onEndTurn(ctx) - a.TurnIndex++ + s.onEndTurn(ctx) + s.TurnIndex++ - if a.collectTurnOutput != nil { - return a.collectTurnOutput(), nil + if s.collectTurnOutput != nil { + return s.collectTurnOutput(), nil } return nil, nil }, @@ -101,17 +101,17 @@ func (a *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont // the last message in the conversation history and all artifacts. // It is a convenience for custom session flows that don't need to construct the // result manually. -func (a *SessionRunner[State]) Result() *SessionFlowResult { - a.mu.RLock() - defer a.mu.RUnlock() +func (s *SessionRunner[State]) Result() *SessionFlowResult { + s.mu.RLock() + defer s.mu.RUnlock() result := &SessionFlowResult{} - if msgs := a.state.Messages; len(msgs) > 0 { + if msgs := s.state.Messages; len(msgs) > 0 { result.Message = msgs[len(msgs)-1] } - if len(a.state.Artifacts) > 0 { - arts := make([]*Artifact, len(a.state.Artifacts)) - copy(arts, a.state.Artifacts) + if len(s.state.Artifacts) > 0 { + arts := make([]*Artifact, len(s.state.Artifacts)) + copy(arts, s.state.Artifacts) result.Artifacts = arts } return result @@ -119,33 +119,33 @@ func (a *SessionRunner[State]) Result() *SessionFlowResult { // maybeSnapshot creates a snapshot if conditions are met (store configured, // callback approves, state changed). Returns the snapshot ID or empty string. -func (a *SessionRunner[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { - if a.store == nil { +func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { + if s.store == nil { return "" } - a.mu.RLock() - currentVersion := a.version - currentState := a.copyStateLocked() - a.mu.RUnlock() + s.mu.RLock() + currentVersion := s.version + currentState := s.copyStateLocked() + s.mu.RUnlock() // Skip if state hasn't changed since the last snapshot. This avoids // redundant snapshots, e.g. the invocation-end snapshot after a // single-turn Run where the turn-end snapshot already captured the // same state. - if a.lastSnapshot != nil && currentVersion == a.lastSnapshotVersion { + if s.lastSnapshot != nil && currentVersion == s.lastSnapshotVersion { return "" } - if a.snapshotCallback != nil { + if s.snapshotCallback != nil { var prevState *SessionState[State] - if a.lastSnapshot != nil { - prevState = &a.lastSnapshot.State + if s.lastSnapshot != nil { + prevState = &s.lastSnapshot.State } - if !a.snapshotCallback(ctx, &SnapshotContext[State]{ + if !s.snapshotCallback(ctx, &SnapshotContext[State]{ State: ¤tState, PrevState: prevState, - TurnIndex: a.TurnIndex, + TurnIndex: s.TurnIndex, Event: event, }) { return "" @@ -158,28 +158,28 @@ func (a *SessionRunner[State]) maybeSnapshot(ctx context.Context, event Snapshot Event: event, State: currentState, } - if a.lastSnapshot != nil { - snapshot.ParentID = a.lastSnapshot.SnapshotID + if s.lastSnapshot != nil { + snapshot.ParentID = s.lastSnapshot.SnapshotID } - if err := a.store.SaveSnapshot(ctx, snapshot); err != nil { + if err := s.store.SaveSnapshot(ctx, snapshot); err != nil { logger.FromContext(ctx).Error("session flow: failed to save snapshot", "err", err) return "" } // Set snapshotId in last message metadata. - a.mu.Lock() - if msgs := a.state.Messages; len(msgs) > 0 { + s.mu.Lock() + if msgs := s.state.Messages; len(msgs) > 0 { lastMsg := msgs[len(msgs)-1] if lastMsg.Metadata == nil { lastMsg.Metadata = make(map[string]any) } lastMsg.Metadata["snapshotId"] = snapshot.SnapshotID } - a.mu.Unlock() + s.mu.Unlock() - a.lastSnapshot = snapshot - a.lastSnapshotVersion = currentVersion + s.lastSnapshot = snapshot + s.lastSnapshotVersion = currentVersion return snapshot.SnapshotID } From a5b97b02aaa2014e543e46a84f07155986e3d032 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 11 May 2026 17:38:45 -0500 Subject: [PATCH 061/141] feat(go): add background session flows via `Detach` (#5193) --- genkit-tools/common/src/types/agent.ts | 176 ++- genkit-tools/genkit-schema.json | 172 ++- go/ai/exp/gen.go | 44 +- go/ai/exp/option.go | 40 +- go/ai/exp/session.go | 464 ++++++- go/ai/exp/session_flow.go | 1313 +++++++++++++++----- go/ai/exp/session_flow_test.go | 1564 ++++++++++++++++++++++-- go/core/action.go | 15 +- go/core/api/action.go | 2 + go/core/flow.go | 14 + go/core/schemas.config | 78 +- go/genkit/gen.go | 4 +- go/genkit/reflection_test.go | 33 +- 13 files changed, 3491 insertions(+), 428 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 8835fc140f..999cccadcd 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -33,10 +33,42 @@ export type Artifact = z.infer; /** * Zod schema for snapshot event. + * + * - `turnEnd`: snapshot was triggered at the end of a turn. + * - `invocationEnd`: snapshot was triggered at the end of the invocation. + * - `detach`: snapshot was created when the client detached the invocation + * and the flow continues in the background. Initially written with + * `pending` status (and empty state) and rewritten with a terminal + * status and the final cumulative state once the background work + * finishes. */ -export const SnapshotEventSchema = z.enum(['turnEnd', 'invocationEnd']); +export const SnapshotEventSchema = z.enum([ + 'turnEnd', + 'invocationEnd', + 'detach', +]); export type SnapshotEvent = z.infer; +/** + * Zod schema for a snapshot's lifecycle status. + * + * - `pending`: a detached invocation is still processing the queued inputs. + * The snapshot's state is empty until the flow exits, at which point it + * is rewritten with the cumulative final state and a terminal status. + * - `complete`: the snapshot captures a settled state. + * - `canceled`: the snapshot's invocation was aborted via the + * `abortSnapshot` companion action while detached. + * - `error`: the invocation terminated with an error. The snapshot's `error` + * field describes the failure and resume is rejected with that same error. + */ +export const SnapshotStatusSchema = z.enum([ + 'pending', + 'complete', + 'canceled', + 'error', +]); +export type SnapshotStatus = z.infer; + /** * Zod schema for session state. */ @@ -56,6 +88,16 @@ export type SessionState = z.infer; * Zod schema for session flow input (per-turn). */ export const SessionFlowInputSchema = z.object({ + /** + * Detach signals that the client wishes to disconnect after this input is + * accepted. The server writes a single pending snapshot (with empty + * state), returns SessionFlowOutput with that snapshot ID, and continues + * processing any already-buffered inputs in a background context. + * Streamed chunks emitted after detach are not forwarded over the wire; + * only the final cumulative state is captured when the snapshot is + * finalized (or the snapshot is aborted via `abortSnapshot`). + */ + detach: z.boolean().optional(), /** User's input messages for this turn. */ messages: z.array(MessageSchema).optional(), /** Tool request parts to re-execute interrupted tools. */ @@ -111,7 +153,8 @@ export type SessionFlowOutput = z.infer; export const TurnEndSchema = z.object({ /** * ID of the snapshot persisted at the end of this turn. Empty if no - * snapshot was created (callback returned false or no store configured). + * snapshot was created (callback returned false, no store configured, or + * snapshots were suspended after detach). */ snapshotId: z.string().optional(), }); @@ -137,3 +180,132 @@ export const SessionFlowStreamChunkSchema = z.object({ export type SessionFlowStreamChunk = z.infer< typeof SessionFlowStreamChunkSchema >; + +/** + * Zod schema for the metadata projection of a session snapshot. It exists + * so callers can identify a snapshot and check its lifecycle status without + * paying for a full state read. + */ +export const SnapshotMetadataSchema = z.object({ + /** Unique identifier for this snapshot (UUID). */ + snapshotId: z.string(), + /** ID of the previous snapshot in this timeline. */ + parentId: z.string().optional(), + /** When the snapshot was first written (RFC 3339). */ + createdAt: z.string(), + /** When the snapshot was last written (RFC 3339). */ + updatedAt: z.string().optional(), + /** What triggered this snapshot. */ + event: SnapshotEventSchema, + /** Lifecycle state of this snapshot. Empty is treated as `complete`. */ + status: SnapshotStatusSchema.optional(), + /** Failure message for a snapshot in `error` status. */ + error: z.string().optional(), +}); +export type SnapshotMetadata = z.infer; + +/** + * Zod schema for a persisted point-in-time capture of session state. + */ +export const SessionSnapshotSchema = SnapshotMetadataSchema.extend({ + /** + * Conversation state. Empty on a pending snapshot (the live state is + * not yet committed; the background invocation is still processing + * queued inputs); populated on terminal snapshots with the cumulative + * final state. + */ + state: SessionStateSchema, +}); +export type SessionSnapshot = z.infer; + +/** + * Zod schema for the input of a session flow's `getSnapshot` companion + * action. The action is registered at `{flowName}/getSnapshot` when the + * flow is defined. + */ +export const GetSnapshotRequestSchema = z.object({ + /** Identifies the snapshot to fetch. */ + snapshotId: z.string(), +}); +export type GetSnapshotRequest = z.infer; + +/** + * Zod schema for the output of the `getSnapshot` companion action. It is a + * client-facing view of the stored snapshot: identifying metadata plus the + * session state, with `WithSnapshotTransform` applied if configured. + */ +export const GetSnapshotResponseSchema = z.object({ + /** Echoes the requested snapshot ID. */ + snapshotId: z.string(), + /** When the snapshot record was first written (RFC 3339). */ + createdAt: z.string().optional(), + /** When the snapshot record was last written (RFC 3339). */ + updatedAt: z.string().optional(), + /** Lifecycle state of the snapshot. */ + status: SnapshotStatusSchema.optional(), + /** Populated when status is `error`. */ + error: z.string().optional(), + /** + * Session state captured by the snapshot, after any configured transform. + * Empty when status is `pending` or `error`. + */ + state: SessionStateSchema.optional(), +}); +export type GetSnapshotResponse = z.infer; + +/** + * Zod schema for the input of the `abortSnapshot` companion action. + */ +export const AbortSnapshotRequestSchema = z.object({ + /** Identifies the snapshot whose invocation should be aborted. */ + snapshotId: z.string(), +}); +export type AbortSnapshotRequest = z.infer; + +/** + * Zod schema for the output of the `abortSnapshot` companion action. + */ +export const AbortSnapshotResponseSchema = z.object({ + /** Echoes the requested snapshot ID. */ + snapshotId: z.string(), + /** + * Snapshot's status after the abort attempt. For a pending snapshot + * this is `canceled`. For an already-terminal snapshot this is the + * existing terminal status (the abort is a no-op). + */ + status: SnapshotStatusSchema.optional(), +}); +export type AbortSnapshotResponse = z.infer< + typeof AbortSnapshotResponseSchema +>; + +/** + * Who owns session state for an agent. + * + * - `server`: a session store is configured and snapshots are persisted + * server-side. + * - `client`: no store; state flows through the agent's invocation init + * and output payloads. + */ +export const AgentMetadataStateManagementSchema = z.enum(['server', 'client']); +export type AgentMetadataStateManagement = z.infer< + typeof AgentMetadataStateManagementSchema +>; + +/** + * Zod schema for the agent capability metadata placed under + * `metadata.agent` on a session flow's action descriptor. Lets the Dev + * UI and other reflective callers render the right surface (e.g. hide + * the Abort button when the configured store doesn't support it) + * without round-tripping through the reflection API. + */ +export const AgentMetadataSchema = z.object({ + /** Who owns session state for this agent. */ + stateManagement: AgentMetadataStateManagementSchema, + /** + * Whether the agent's invocations can be aborted. True only when the + * configured store implements the abort lifecycle. + */ + abortable: z.boolean(), +}); +export type AgentMetadata = z.infer; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index b25f5c12d6..607237e3a6 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -1,6 +1,56 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "$defs": { + "AbortSnapshotRequest": { + "type": "object", + "properties": { + "snapshotId": { + "type": "string" + } + }, + "required": [ + "snapshotId" + ], + "additionalProperties": false + }, + "AbortSnapshotResponse": { + "type": "object", + "properties": { + "snapshotId": { + "type": "string" + }, + "status": { + "$ref": "#/$defs/SnapshotStatus" + } + }, + "required": [ + "snapshotId" + ], + "additionalProperties": false + }, + "AgentMetadata": { + "type": "object", + "properties": { + "stateManagement": { + "$ref": "#/$defs/AgentMetadataStateManagement" + }, + "abortable": { + "type": "boolean" + } + }, + "required": [ + "stateManagement", + "abortable" + ], + "additionalProperties": false + }, + "AgentMetadataStateManagement": { + "type": "string", + "enum": [ + "server", + "client" + ] + }, "Artifact": { "type": "object", "properties": { @@ -23,6 +73,45 @@ ], "additionalProperties": false }, + "GetSnapshotRequest": { + "type": "object", + "properties": { + "snapshotId": { + "type": "string" + } + }, + "required": [ + "snapshotId" + ], + "additionalProperties": false + }, + "GetSnapshotResponse": { + "type": "object", + "properties": { + "snapshotId": { + "type": "string" + }, + "createdAt": { + "type": "string" + }, + "updatedAt": { + "type": "string" + }, + "status": { + "$ref": "#/$defs/SnapshotStatus" + }, + "error": { + "type": "string" + }, + "state": { + "$ref": "#/$defs/SessionState" + } + }, + "required": [ + "snapshotId" + ], + "additionalProperties": false + }, "SessionFlowInit": { "type": "object", "properties": { @@ -38,6 +127,9 @@ "SessionFlowInput": { "type": "object", "properties": { + "detach": { + "type": "boolean" + }, "messages": { "type": "array", "items": { @@ -105,6 +197,42 @@ }, "additionalProperties": false }, + "SessionSnapshot": { + "type": "object", + "properties": { + "snapshotId": { + "type": "string" + }, + "parentId": { + "type": "string" + }, + "createdAt": { + "type": "string" + }, + "updatedAt": { + "type": "string" + }, + "event": { + "$ref": "#/$defs/SnapshotEvent" + }, + "status": { + "$ref": "#/$defs/SnapshotStatus" + }, + "error": { + "type": "string" + }, + "state": { + "$ref": "#/$defs/SessionState" + } + }, + "required": [ + "snapshotId", + "createdAt", + "event", + "state" + ], + "additionalProperties": false + }, "SessionState": { "type": "object", "properties": { @@ -129,7 +257,49 @@ "type": "string", "enum": [ "turnEnd", - "invocationEnd" + "invocationEnd", + "detach" + ] + }, + "SnapshotMetadata": { + "type": "object", + "properties": { + "snapshotId": { + "$ref": "#/$defs/SessionSnapshot/properties/snapshotId" + }, + "parentId": { + "$ref": "#/$defs/SessionSnapshot/properties/parentId" + }, + "createdAt": { + "$ref": "#/$defs/SessionSnapshot/properties/createdAt" + }, + "updatedAt": { + "$ref": "#/$defs/SessionSnapshot/properties/updatedAt" + }, + "event": { + "$ref": "#/$defs/SnapshotEvent" + }, + "status": { + "$ref": "#/$defs/SessionSnapshot/properties/status" + }, + "error": { + "$ref": "#/$defs/SessionSnapshot/properties/error" + } + }, + "required": [ + "snapshotId", + "createdAt", + "event" + ], + "additionalProperties": false + }, + "SnapshotStatus": { + "type": "string", + "enum": [ + "pending", + "complete", + "canceled", + "error" ] }, "TurnEnd": { diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index f2344605cc..1b28cdcd74 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -22,6 +22,35 @@ import ( "github.com/firebase/genkit/go/ai" ) +// AgentMetadata is the value placed under metadata["agent"] on a session +// flow's action descriptor. It exposes capability information so the Dev +// UI and other reflective callers can render the right surface (e.g. +// hide the Abort button when the configured store doesn't support it) +// without round-tripping through the reflection API. +type AgentMetadata struct { + // Abortable reports whether the agent's invocations can be aborted + // (true when the store implements [SnapshotAborter]). + Abortable bool `json:"abortable,omitempty"` + // StateManagement reports who owns session state. + StateManagement AgentMetadataStateManagement `json:"stateManagement,omitempty"` +} + +// AgentMetadataStateManagement enumerates who owns session state for an +// agent: "server" (a [SessionStore] is configured and snapshots are +// persisted server-side) or "client" (no store; state flows through +// invocation init / output). +type AgentMetadataStateManagement string + +const ( + // AgentMetadataStateManagementServer indicates the agent is wired with + // a [SessionStore] and persists snapshots server-side. + AgentMetadataStateManagementServer AgentMetadataStateManagement = "server" + // AgentMetadataStateManagementClient indicates the agent has no store; + // session state is client-managed and round-trips through invocation + // init and output. + AgentMetadataStateManagementClient AgentMetadataStateManagement = "client" +) + // Artifact represents a named collection of parts produced during a session. // Examples: generated files, images, code snippets, diagrams, etc. type Artifact struct { @@ -46,6 +75,14 @@ type SessionFlowInit[State any] struct { // SessionFlowInput is the input sent to an session flow during a conversation turn. type SessionFlowInput struct { + // Detach signals that the client wishes to disconnect after this input is + // accepted. The server writes a single pending snapshot (with empty + // state), returns [SessionFlowOutput] with that snapshot ID, and + // continues processing any already-buffered inputs in a background + // context. The pending snapshot is finalized with the cumulative final + // state once all queued inputs are processed (or the snapshot is + // cancelled via cancelSnapshot). + Detach bool `json:"detach,omitempty"` // Messages contains the user's input for this turn. Messages []*ai.Message `json:"messages,omitempty"` // ToolRestarts contains tool request parts to re-execute interrupted tools. @@ -119,6 +156,11 @@ const ( SnapshotEventTurnEnd SnapshotEvent = "turnEnd" // InvocationEnd indicates the snapshot was triggered at the end of the invocation. SnapshotEventInvocationEnd SnapshotEvent = "invocationEnd" + // Detach indicates the snapshot was created when the client detached the + // invocation and the flow continues in the background. The snapshot is + // initially written with [SnapshotStatusPending] and rewritten with a + // terminal status once the background work finishes. + SnapshotEventDetach SnapshotEvent = "detach" ) // TurnEnd groups the signals emitted when a session flow turn finishes. @@ -127,6 +169,6 @@ const ( type TurnEnd struct { // SnapshotID is the ID of the snapshot persisted at the end of this turn. // Empty if no snapshot was created (callback returned false or no store - // configured). + // configured, or snapshots were suspended after detach). SnapshotID string `json:"snapshotId,omitempty"` } diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index c22e63956f..630d58edaa 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -28,9 +28,24 @@ type SessionFlowOption[State any] interface { applySessionFlow(*sessionFlowOptions[State]) error } +// StateTransform rewrites session state on its way out to a client. It +// is applied to the State returned by the getSnapshot companion action +// and to [SessionFlowOutput.State] when state is client-managed (no +// store). It is not applied to state persisted in the store or to +// state passed to the user flow function. +// +// ctx is the request or invocation context: cancellation, deadlines, +// and context-scoped values (e.g. the caller's identity for RBAC-aware +// redaction) flow through here. +// +// The state input is a deep copy owned by the caller; the transform +// may mutate and return it, or return a freshly-constructed value. +type StateTransform[State any] = func(ctx context.Context, state SessionState[State]) SessionState[State] + type sessionFlowOptions[State any] struct { - store SessionStore[State] - callback SnapshotCallback[State] + store SessionStore[State] + callback SnapshotCallback[State] + transform StateTransform[State] } func (o *sessionFlowOptions[State]) applySessionFlow(opts *sessionFlowOptions[State]) error { @@ -46,10 +61,19 @@ func (o *sessionFlowOptions[State]) applySessionFlow(opts *sessionFlowOptions[St } opts.callback = o.callback } + if o.transform != nil { + if opts.transform != nil { + return errors.New("cannot set state transform more than once (WithStateTransform)") + } + opts.transform = o.transform + } return nil } -// WithSessionStore sets the store for persisting snapshots. +// WithSessionStore sets the store for persisting snapshots. The store must +// implement [SnapshotReader] and [SnapshotWriter] at minimum. Detach +// support also requires [SnapshotAborter]; detach attempts on a store +// that lacks that interface are rejected at runtime. func WithSessionStore[State any](store SessionStore[State]) SessionFlowOption[State] { return &sessionFlowOptions[State]{store: store} } @@ -74,6 +98,16 @@ func WithSnapshotOn[State any](events ...SnapshotEvent) SessionFlowOption[State] }) } +// WithStateTransform registers a transform applied to session state on +// its way out to a client via the getSnapshot companion action or via +// [SessionFlowOutput.State] when state is client-managed. Typical use +// is PII redaction or stripping secrets. The transform is not applied +// to state persisted in the store or to state passed to the user flow +// function. +func WithStateTransform[State any](transform StateTransform[State]) SessionFlowOption[State] { + return &sessionFlowOptions[State]{transform: transform} +} + // --- InvocationOption --- // InvocationOption configures an session flow invocation (StreamBidi, Run, or RunText). diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index e65d33d426..4b57282efa 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -20,15 +20,48 @@ import ( "context" "encoding/json" "fmt" + "slices" "sync" "time" "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/internal/base" + "github.com/google/uuid" ) // --- Snapshot --- +// SnapshotStatus describes the lifecycle state of a snapshot. Snapshots +// written for synchronous turns or invocations are always [SnapshotStatusComplete] +// (an empty value is also treated as complete for backwards compatibility). +// +// When a client sets [SessionFlowInput.Detach], the server writes a single +// snapshot with [SnapshotStatusPending] (and empty state) and returns its +// ID immediately. Background processing then either rewrites that snapshot +// with the cumulative final state and [SnapshotStatusComplete] / +// [SnapshotStatusError] when the flow finishes, or with +// [SnapshotStatusCanceled] if the client called abortSnapshot in the +// meantime. +type SnapshotStatus string + +const ( + // SnapshotStatusPending indicates a detached invocation is still + // processing the queued inputs. The snapshot will be rewritten with a + // terminal status once the flow exits. + SnapshotStatusPending SnapshotStatus = "pending" + // SnapshotStatusComplete indicates the snapshot captures a settled state. + SnapshotStatusComplete SnapshotStatus = "complete" + // SnapshotStatusCanceled indicates the snapshot's invocation was + // aborted via the abortSnapshot companion action while detached. + SnapshotStatusCanceled SnapshotStatus = "canceled" + // SnapshotStatusError indicates the invocation terminated with an error. + // The snapshot's Error field describes the failure and resume is + // rejected with that same error. + SnapshotStatusError SnapshotStatus = "error" +) + // SessionSnapshot is a persisted point-in-time capture of session state. type SessionSnapshot[State any] struct { // SnapshotID is the unique identifier for this snapshot (UUID). @@ -37,9 +70,22 @@ type SessionSnapshot[State any] struct { ParentID string `json:"parentId,omitempty"` // CreatedAt is when the snapshot was created. CreatedAt time.Time `json:"createdAt"` + // UpdatedAt is when the snapshot was last written. For pending snapshots + // it equals CreatedAt; once the snapshot is finalized it reflects the + // terminal write. + UpdatedAt time.Time `json:"updatedAt,omitempty"` // Event is what triggered this snapshot. Event SnapshotEvent `json:"event"` - // State is the actual conversation state. + // Status is the lifecycle state of this snapshot. Empty is treated as + // [SnapshotStatusComplete] for backwards compatibility. + Status SnapshotStatus `json:"status,omitempty"` + // Error is the failure message for a snapshot in [SnapshotStatusError]. + // Empty otherwise. + Error string `json:"error,omitempty"` + // State is the actual conversation state. Empty on a pending snapshot + // (the live state is not yet committed; the background invocation is + // still processing queued inputs); populated on terminal snapshots + // with the cumulative final state. State SessionState[State] `json:"state"` } @@ -59,26 +105,152 @@ type SnapshotContext[State any] struct { // If not provided and a store is configured, snapshots are always created. type SnapshotCallback[State any] = func(ctx context.Context, sc *SnapshotContext[State]) bool +// applyTransform returns the result of applying t to *state, or state +// unchanged if t is nil. A nil state is returned as-is. +func applyTransform[State any](ctx context.Context, t StateTransform[State], state *SessionState[State]) *SessionState[State] { + if t == nil || state == nil { + return state + } + transformed := t(ctx, *state) + return &transformed +} + // --- Session store --- -// SessionStore persists and retrieves snapshots. -type SessionStore[State any] interface { +// SnapshotMetadata is the metadata-only projection of a [SessionSnapshot]: +// identifying fields, lifecycle timestamps, and status. Returned by store +// operations that surface a snapshot's lifecycle state without paying for +// a full state read. +type SnapshotMetadata struct { + // SnapshotID is the unique identifier for this snapshot. + SnapshotID string `json:"snapshotId"` + // ParentID is the ID of the previous snapshot in this timeline. + ParentID string `json:"parentId,omitempty"` + // CreatedAt is when the snapshot was first written. + CreatedAt time.Time `json:"createdAt"` + // UpdatedAt is when the snapshot was last written. + UpdatedAt time.Time `json:"updatedAt,omitempty"` + // Event is what triggered this snapshot. + Event SnapshotEvent `json:"event"` + // Status is the lifecycle state of this snapshot. + Status SnapshotStatus `json:"status,omitempty"` + // Error is the failure message for a snapshot in [SnapshotStatusError]. + Error string `json:"error,omitempty"` +} + +// SnapshotReader retrieves snapshots. The minimum any session store must +// implement to be used with [WithSessionStore]. +type SnapshotReader[State any] interface { // GetSnapshot retrieves a snapshot by ID. Returns nil if not found. GetSnapshot(ctx context.Context, snapshotID string) (*SessionSnapshot[State], error) - // SaveSnapshot persists a snapshot. - SaveSnapshot(ctx context.Context, snapshot *SessionSnapshot[State]) error } -// InMemorySessionStore provides a thread-safe in-memory snapshot store. +// SnapshotWriter persists snapshots. The minimum any session store must +// implement to be used with [WithSessionStore]. +type SnapshotWriter[State any] interface { + // SaveSnapshot atomically reads the snapshot at id (if any), applies + // fn, and persists the result. The store owns identity and + // lifecycle-timestamp fields: + // + // - SnapshotID: if id is empty, the store generates a fresh ID; + // otherwise the store uses id (any SnapshotID populated by fn is + // overridden). + // - CreatedAt: stamped to the wall clock on first write; preserved + // from the existing row on update. + // - UpdatedAt: stamped to the wall clock on every commit. + // - Status: if the snapshot returned by fn has Status="", it is + // defaulted to [SnapshotStatusComplete] (the common case for + // synchronous turn-end and invocation-end writes). Callers + // writing a pending row must set Status explicitly. + // + // fn receives the existing snapshot (or nil if id is empty or the + // row does not exist) and returns the snapshot to commit, or + // (nil, nil) to skip the write without changing the row. + // + // Under contention, stores that use optimistic concurrency or + // transaction retries may call fn multiple times. fn must therefore + // be a pure function of its input: no side effects (channel sends, + // logging, external I/O) inside fn. + // + // Returns the snapshot as persisted (with the store-owned fields + // populated), or nil if fn declined to write. + SaveSnapshot( + ctx context.Context, + id string, + fn func(existing *SessionSnapshot[State]) (*SessionSnapshot[State], error), + ) (*SessionSnapshot[State], error) +} + +// SnapshotAborter is the optional capability layered on [SessionStore] +// that lets a session flow's invocations be aborted. It bundles the two +// methods that must be implemented together for the abort lifecycle to +// function: +// +// - [SnapshotAborter.AbortSnapshot] flips a pending snapshot's status +// to canceled (typically called by the abortSnapshot companion +// action or directly by a Go caller holding the store). +// +// - [SnapshotAborter.OnSnapshotStatusChange] lets the session flow +// runtime observe the flip without polling, so it can promptly +// cancel the work context. +// +// They are bundled because neither is useful alone: flipping status +// with no observer means the running fn never learns it was aborted; +// observing without a way to trigger the flip means no abort can +// happen. Splitting them into separate interfaces made the +// "implemented one, not the other" footgun too easy to hit. +type SnapshotAborter interface { + // AbortSnapshot atomically transitions a snapshot from + // [SnapshotStatusPending] to [SnapshotStatusCanceled] and returns the + // resulting metadata. If the snapshot is in any other status the + // operation is a no-op and the existing metadata is returned. Returns + // nil if the snapshot is not found. + // + // Implementations must perform the read-and-write atomically (e.g., a + // transaction or a compare-and-swap). The session flow's abortSnapshot + // action and finalizer rely on this to avoid a pending row being + // clobbered by a racing terminal write. + AbortSnapshot(ctx context.Context, snapshotID string) (*SnapshotMetadata, error) + + // OnSnapshotStatusChange returns a channel that yields the snapshot's + // status whenever it changes. The first value (if any) reflects the + // status at subscription time. The channel is closed when ctx is + // cancelled. If the snapshot does not exist when the subscription is + // established, the channel is closed without yielding a value. + // + // Implementations may push changes from a transaction log, a CDC + // feed, or fall back to polling internally; the contract just spares + // callers the choice. + OnSnapshotStatusChange(ctx context.Context, snapshotID string) <-chan SnapshotStatus +} + +// SessionStore is the minimum store interface required by +// [WithSessionStore]. The abort lifecycle is layered as the optional +// [SnapshotAborter] capability and checked at runtime: a store wired +// into a flow that intends to support detach must also implement +// [SnapshotAborter], or the runtime will reject detach attempts. +type SessionStore[State any] interface { + SnapshotReader[State] + SnapshotWriter[State] +} + +// InMemorySessionStore provides a thread-safe in-memory snapshot store. It +// implements the full set of optional store interfaces (reader, writer, +// aborter, status subscriber). type InMemorySessionStore[State any] struct { - snapshots map[string]*SessionSnapshot[State] + // mu is RWMutex so GetSnapshot (which JSON-marshals while holding the + // lock) can run concurrently with other readers. All writers (Save, + // Abort, OnSnapshotStatusChange, removeSub) take the full Lock(). mu sync.RWMutex + snapshots map[string]*SessionSnapshot[State] + subs map[string][]chan SnapshotStatus } // NewInMemorySessionStore creates a new in-memory snapshot store. func NewInMemorySessionStore[State any]() *InMemorySessionStore[State] { return &InMemorySessionStore[State]{ snapshots: make(map[string]*SessionSnapshot[State]), + subs: make(map[string][]chan SnapshotStatus), } } @@ -86,30 +258,155 @@ func NewInMemorySessionStore[State any]() *InMemorySessionStore[State] { func (s *InMemorySessionStore[State]) GetSnapshot(_ context.Context, snapshotID string) (*SessionSnapshot[State], error) { s.mu.RLock() defer s.mu.RUnlock() + snap, ok := s.snapshots[snapshotID] + if !ok { + return nil, nil + } + return copySnapshot(snap) +} + +// AbortSnapshot atomically flips a pending snapshot to canceled. If the +// snapshot is already terminal the existing metadata is returned unchanged. +// Returns nil if the snapshot is not found. +func (s *InMemorySessionStore[State]) AbortSnapshot(_ context.Context, snapshotID string) (*SnapshotMetadata, error) { + s.mu.Lock() + defer s.mu.Unlock() + snap, ok := s.snapshots[snapshotID] + if !ok { + return nil, nil + } + if snap.Status == SnapshotStatusPending { + snap.Status = SnapshotStatusCanceled + snap.UpdatedAt = time.Now() + s.notifyLocked(snapshotID, snap.Status) + } + return snapshotMetadata(snap), nil +} - snap, exists := s.snapshots[snapshotID] - if !exists { +// SaveSnapshot atomically reads, applies fn, and persists. See the +// [SnapshotWriter] interface for the full contract; this implementation +// satisfies it by holding s.mu for the entire read-modify-write so fn +// is called exactly once per SaveSnapshot call. +func (s *InMemorySessionStore[State]) SaveSnapshot( + _ context.Context, + id string, + fn func(existing *SessionSnapshot[State]) (*SessionSnapshot[State], error), +) (*SessionSnapshot[State], error) { + s.mu.Lock() + defer s.mu.Unlock() + + if id == "" { + id = uuid.New().String() + } + + var existing *SessionSnapshot[State] + if stored, ok := s.snapshots[id]; ok { + copied, err := copySnapshot(stored) + if err != nil { + return nil, err + } + existing = copied + } + + next, err := fn(existing) + if err != nil { + return nil, err + } + if next == nil { return nil, nil } - copied, err := copySnapshot(snap) + next.SnapshotID = id + now := time.Now() + if existing != nil { + next.CreatedAt = existing.CreatedAt + } else { + next.CreatedAt = now + } + next.UpdatedAt = now + if next.Status == "" { + next.Status = SnapshotStatusComplete + } + + copied, err := copySnapshot(next) if err != nil { return nil, err } - return copied, nil + s.snapshots[id] = copied + if existing == nil || existing.Status != next.Status { + s.notifyLocked(id, next.Status) + } + // Return next (the freshly-allocated struct from fn) rather than + // copied: copied is the pointer the store retains, so returning it + // would alias the caller's view with the stored row and let future + // in-place mutations (e.g. AbortSnapshot updating UpdatedAt) leak + // through. + return next, nil } -// SaveSnapshot persists a snapshot. -func (s *InMemorySessionStore[State]) SaveSnapshot(_ context.Context, snapshot *SessionSnapshot[State]) error { +// OnSnapshotStatusChange subscribes to status changes for a snapshot. The +// returned channel yields the current status (if any) and any subsequent +// changes, until ctx is cancelled. +func (s *InMemorySessionStore[State]) OnSnapshotStatusChange(ctx context.Context, snapshotID string) <-chan SnapshotStatus { + ch := make(chan SnapshotStatus, 1) + + s.mu.Lock() + snap, ok := s.snapshots[snapshotID] + if !ok { + s.mu.Unlock() + close(ch) + return ch + } + ch <- snap.Status + s.subs[snapshotID] = append(s.subs[snapshotID], ch) + s.mu.Unlock() + + context.AfterFunc(ctx, func() { s.removeSub(snapshotID, ch) }) + return ch +} + +// removeSub detaches a subscriber and closes its channel. +func (s *InMemorySessionStore[State]) removeSub(snapshotID string, ch chan SnapshotStatus) { s.mu.Lock() defer s.mu.Unlock() + subs := s.subs[snapshotID] + i := slices.Index(subs, ch) + if i < 0 { + return + } + subs = slices.Delete(subs, i, i+1) + if len(subs) == 0 { + delete(s.subs, snapshotID) + } else { + s.subs[snapshotID] = subs + } + close(ch) +} - copied, err := copySnapshot(snapshot) - if err != nil { - return err +// notifyLocked publishes status to all live subscribers of snapshotID. +// Caller must hold s.mu. Sends are best-effort: a slow subscriber may miss +// intermediate values, but the store guarantees the latest value visible +// to the subscription is the one persisted at notify time. +func (s *InMemorySessionStore[State]) notifyLocked(snapshotID string, status SnapshotStatus) { + for _, ch := range s.subs[snapshotID] { + select { + case ch <- status: + default: + } + } +} + +// snapshotMetadata projects the metadata fields of a snapshot. +func snapshotMetadata[State any](snap *SessionSnapshot[State]) *SnapshotMetadata { + return &SnapshotMetadata{ + SnapshotID: snap.SnapshotID, + ParentID: snap.ParentID, + CreatedAt: snap.CreatedAt, + UpdatedAt: snap.UpdatedAt, + Event: snap.Event, + Status: snap.Status, + Error: snap.Error, } - s.snapshots[copied.SnapshotID] = copied - return nil } // copySnapshot creates a deep copy of a snapshot using JSON marshaling. @@ -128,6 +425,137 @@ func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[Sta return &copied, nil } +// --- Snapshot companion actions --- + +// GetSnapshotRequest is the input for a session flow's getSnapshot companion +// action. The action is registered at `{flowName}/getSnapshot` when the flow +// is defined and is intended for Dev UI and client-side reconnect flows. +type GetSnapshotRequest struct { + // SnapshotID identifies the snapshot to fetch. + SnapshotID string `json:"snapshotId"` +} + +// GetSnapshotResponse is the output of the getSnapshot companion action. It +// is a client-facing view of the stored snapshot: identifying metadata plus +// the session state, with [WithStateTransform] applied if configured. +// +// Unlike the raw [SessionSnapshot], this response intentionally omits +// internal fields (parent ID, event) and does not leak the snapshot +// envelope beyond what callers need to repopulate a UI. +type GetSnapshotResponse[State any] struct { + // SnapshotID echoes the requested snapshot ID. + SnapshotID string `json:"snapshotId"` + // CreatedAt is when the snapshot record was first written. + CreatedAt time.Time `json:"createdAt,omitempty"` + // UpdatedAt is when the snapshot record was last written. Equals + // CreatedAt for snapshots that have not been rewritten. + UpdatedAt time.Time `json:"updatedAt,omitempty"` + // Status is the lifecycle state of the snapshot. See [SnapshotStatus]. + Status SnapshotStatus `json:"status,omitempty"` + // Error is populated when Status is [SnapshotStatusError]. + Error string `json:"error,omitempty"` + // State is the session state captured by the snapshot, after any + // configured transform. Empty when Status is pending or error. + State *SessionState[State] `json:"state,omitempty"` +} + +// AbortSnapshotRequest is the input for the abortSnapshot companion action. +type AbortSnapshotRequest struct { + // SnapshotID identifies the snapshot whose invocation should be aborted. + SnapshotID string `json:"snapshotId"` +} + +// AbortSnapshotResponse is the output of the abortSnapshot companion action. +type AbortSnapshotResponse struct { + // SnapshotID echoes the requested snapshot ID. + SnapshotID string `json:"snapshotId"` + // Status is the snapshot's status after the abort attempt. For a + // pending snapshot this is [SnapshotStatusCanceled]. For an + // already-terminal snapshot this is the existing terminal status (the + // abort is a no-op). + Status SnapshotStatus `json:"status,omitempty"` +} + +// registerSnapshotActions registers the session flow's companion actions: +// +// - The flow's name under [api.ActionTypeAgentSnapshot] — getSnapshot, +// registered whenever a [SessionStore] is configured. The action is +// the remote counterpart to [SessionStore.GetSnapshot] for Dev UI and +// non-Go clients; local Go callers use the store reference directly. +// +// - The flow's name under [api.ActionTypeAgentAbort] — abortSnapshot, +// registered only when the store implements [SnapshotAborter] +// (which bundles both the abort trigger and the status-change +// subscription needed for the runtime to react). Surfacing the +// action only when the capability is present keeps the reflected +// API aligned with what the store can actually do. +func registerSnapshotActions[State any]( + r api.Registry, + flowName string, + store SessionStore[State], + transform StateTransform[State], +) { + core.DefineAction(r, flowName, api.ActionTypeAgentSnapshot, nil, nil, + func(ctx context.Context, req *GetSnapshotRequest) (*GetSnapshotResponse[State], error) { + if store == nil { + return nil, core.NewError(core.FAILED_PRECONDITION, + "getSnapshot: session flow %q has no session store configured", flowName) + } + if req == nil || req.SnapshotID == "" { + return nil, core.NewError(core.INVALID_ARGUMENT, "getSnapshot: snapshotId is required") + } + snap, err := store.GetSnapshot(ctx, req.SnapshotID) + if err != nil { + return nil, core.NewError(core.INTERNAL, "getSnapshot: %v", err) + } + if snap == nil { + return nil, core.NewError(core.NOT_FOUND, "getSnapshot: snapshot %q not found", req.SnapshotID) + } + + status := snap.Status + if status == "" { + status = SnapshotStatusComplete + } + updatedAt := snap.UpdatedAt + if updatedAt.IsZero() { + updatedAt = snap.CreatedAt + } + + resp := &GetSnapshotResponse[State]{ + SnapshotID: snap.SnapshotID, + CreatedAt: snap.CreatedAt, + UpdatedAt: updatedAt, + Status: status, + Error: snap.Error, + } + if status != SnapshotStatusError && status != SnapshotStatusPending { + resp.State = applyTransform(ctx, transform, &snap.State) + } + return resp, nil + }) + + aborter, ok := store.(SnapshotAborter) + if !ok { + // Store doesn't support the abort lifecycle. Don't surface the + // action. + return + } + core.DefineAction(r, flowName, api.ActionTypeAgentAbort, nil, nil, + func(ctx context.Context, req *AbortSnapshotRequest) (*AbortSnapshotResponse, error) { + if req == nil || req.SnapshotID == "" { + return nil, core.NewError(core.INVALID_ARGUMENT, "abortSnapshot: snapshotId is required") + } + meta, err := aborter.AbortSnapshot(ctx, req.SnapshotID) + if err != nil { + return nil, core.NewError(core.INTERNAL, "abortSnapshot: %v", err) + } + if meta == nil { + return nil, core.NewError(core.NOT_FOUND, "abortSnapshot: snapshot %q not found", req.SnapshotID) + } + return &AbortSnapshotResponse{SnapshotID: meta.SnapshotID, Status: meta.Status}, nil + }) +} + // --- Session --- // Session holds conversation state and provides thread-safe read/write access to messages, diff --git a/go/ai/exp/session_flow.go b/go/ai/exp/session_flow.go index d7c6b8e89a..5f3136b7f3 100644 --- a/go/ai/exp/session_flow.go +++ b/go/ai/exp/session_flow.go @@ -25,7 +25,7 @@ import ( "fmt" "iter" "sync" - "time" + "sync/atomic" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core" @@ -33,9 +33,435 @@ import ( "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal/base" - "github.com/google/uuid" ) +// --- SessionFlow --- + +// SessionFlowFunc is the function signature for session flows. +// Type parameters: +// - Stream: Type for status updates sent via the responder +// - State: Type for user-defined state in snapshots +type SessionFlowFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *SessionRunner[State]) (*SessionFlowResult, error) + +// SessionFlow is a bidirectional streaming flow with automatic snapshot management. +type SessionFlow[Stream, State any] struct { + action *core.Action[*SessionFlowInit[State], *SessionFlowOutput[State], *SessionFlowStreamChunk[Stream], *SessionFlowInput] +} + +// DefineSessionFlow creates an SessionFlow with automatic snapshot +// management and registers it. The underlying action is created via +// [core.DefineBidiAction] (rather than [core.DefineBidiFlow]) so the +// agent capability metadata can be set at construction time — actions +// must be immutable once registered. The flow-context wrapping that +// makes [core.Run] work inside fn is preserved via +// [core.WithFlowContext]. +func DefineSessionFlow[Stream, State any]( + r api.Registry, + name string, + fn SessionFlowFunc[Stream, State], + opts ...SessionFlowOption[State], +) *SessionFlow[Stream, State] { + cfg := &sessionFlowOptions[State]{} + for _, opt := range opts { + if err := opt.applySessionFlow(cfg); err != nil { + panic(fmt.Errorf("DefineSessionFlow %q: %w", name, err)) + } + } + + action := core.DefineBidiAction(r, name, api.ActionTypeFlow, + &core.ActionOptions{ + Metadata: map[string]any{"agent": agentMetadataFor(cfg.store)}, + }, + func( + ctx context.Context, + in *SessionFlowInit[State], + inCh <-chan *SessionFlowInput, + outCh chan<- *SessionFlowStreamChunk[Stream], + ) (*SessionFlowOutput[State], error) { + ctx = core.WithFlowContext(ctx, name) + rt, err := newSessionFlowRuntime(ctx, name, cfg, in, inCh, outCh) + if err != nil { + return nil, err + } + return rt.run(ctx, fn) + }) + + registerSnapshotActions(r, name, cfg.store, cfg.transform) + + return &SessionFlow[Stream, State]{action: action} +} + +// agentMetadataFor derives the [AgentMetadata] value attached to the +// session flow's action descriptor under the "agent" key. [AgentMetadata] +// itself is generated from agent.ts; this constructor is hand-written +// because it inspects the configured store's optional capabilities. +func agentMetadataFor[State any](store SessionStore[State]) AgentMetadata { + mgmt := AgentMetadataStateManagementClient + abortable := false + if store != nil { + mgmt = AgentMetadataStateManagementServer + _, abortable = store.(SnapshotAborter) + } + return AgentMetadata{ + StateManagement: mgmt, + Abortable: abortable, + } +} + +// --- sessionFlowRuntime --- + +// sessionFlowRuntime owns the per-invocation wiring of a session flow: +// session, runner, output router, input intake, and the goroutine that runs +// the user fn. Its methods implement the three terminal paths the flow can +// take: detach, fn-completion, and client-cancel. +type sessionFlowRuntime[Stream, State any] struct { + name string + cfg *sessionFlowOptions[State] + + session *Session[State] + runner *SessionRunner[State] + router *chunkRouter[Stream, State] + intake *detachIntake + + fnDone chan fnDoneResult[State] +} + +// fnDoneResult carries the user fn's return values across the goroutine +// boundary that runs it. A named type keeps the channel signatures readable. +type fnDoneResult[State any] struct { + result *SessionFlowResult + err error +} + +func newSessionFlowRuntime[Stream, State any]( + ctx context.Context, + name string, + cfg *sessionFlowOptions[State], + in *SessionFlowInit[State], + inCh <-chan *SessionFlowInput, + outCh chan<- *SessionFlowStreamChunk[Stream], +) (*sessionFlowRuntime[Stream, State], error) { + session, parent, err := loadSession(ctx, in, cfg.store) + if err != nil { + return nil, err + } + + rt := &sessionFlowRuntime[Stream, State]{ + name: name, + cfg: cfg, + session: session, + router: startChunkRouter(session, outCh), + intake: startDetachIntake(inCh), + fnDone: make(chan fnDoneResult[State], 1), + } + + rt.runner = &SessionRunner[State]{ + Session: session, + InputCh: rt.intake.out(), + snapshotCallback: cfg.callback, + lastSnapshot: parent, + intake: rt.intake, + } + rt.runner.collectTurnOutput = func() any { return rt.router.collectTurnChunks() } + rt.runner.onEndTurn = rt.emitTurnEnd + + return rt, nil +} + +// emitTurnEnd is called by the runner after each successful turn. It writes +// a turn-end snapshot (if applicable) and forwards the resulting [TurnEnd] +// chunk through the router so clients see it on the output stream. +func (rt *sessionFlowRuntime[Stream, State]) emitTurnEnd(ctx context.Context) { + snapshotID := rt.runner.maybeSnapshot(ctx, SnapshotEventTurnEnd) + rt.router.send() <- &SessionFlowStreamChunk[Stream]{TurnEnd: &TurnEnd{ + SnapshotID: snapshotID, + }} +} + +// run drives the user fn to completion and returns the flow output. +// +// workCtx carries the session and is decoupled from clientCtx: pre-detach a +// watcher mirrors clientCtx so a disconnect cancels the work; on detach the +// watcher exits and the finalizer goroutine owns workCtx until fn returns. +func (rt *sessionFlowRuntime[Stream, State]) run( + clientCtx context.Context, + fn SessionFlowFunc[Stream, State], +) (*SessionFlowOutput[State], error) { + workCtx, cancelWork := context.WithCancel(context.WithoutCancel(clientCtx)) + workCtx = NewSessionContext(workCtx, rt.session) + + var detachOnce sync.Once + detached := make(chan struct{}) + markDetached := func() { detachOnce.Do(func() { close(detached) }) } + defer markDetached() // ensure the watcher exits on every return path + + go func() { + select { + case <-clientCtx.Done(): + cancelWork() + case <-detached: + } + }() + + go func() { + result, err := fn(workCtx, rt.router.responder(), rt.runner) + rt.fnDone <- fnDoneResult[State]{result: result, err: err} + }() + + select { + case <-rt.intake.detachSignal(): + if err := rt.checkDetachCapabilities(); err != nil { + rt.drainAndWait(cancelWork) + return nil, err + } + return rt.handleDetach(clientCtx, workCtx, cancelWork, markDetached) + + case res := <-rt.fnDone: + return rt.handleFnDone(clientCtx, cancelWork, res) + + case <-clientCtx.Done(): + res := rt.drainAndWait(cancelWork) + if res.err != nil { + return nil, res.err + } + return nil, clientCtx.Err() + } +} + +// checkDetachCapabilities reports whether the configured store is capable +// of supporting detach. Detach requires a writable store (to persist the +// pending snapshot) and a [SnapshotAborter] (which bundles both abort +// triggering and status-change subscription so the runtime can react to +// the abort without polling). +func (rt *sessionFlowRuntime[Stream, State]) checkDetachCapabilities() error { + if rt.cfg.store == nil { + return core.NewError(core.FAILED_PRECONDITION, + "session flow %q: detach requires a session store", rt.name) + } + if _, ok := rt.cfg.store.(SnapshotAborter); !ok { + return core.NewError(core.FAILED_PRECONDITION, + "session flow %q: detach requires a session store implementing SnapshotAborter", rt.name) + } + return nil +} + +// drainAndWait performs a synchronous shutdown: cancel work, wait for the +// intake reader/forwarder to finish, drain fnDone, and close the router. +// Returns the fn's result for callers that need to surface its error. +func (rt *sessionFlowRuntime[Stream, State]) drainAndWait(cancelWork context.CancelFunc) fnDoneResult[State] { + cancelWork() + rt.intake.stopAndWait() + res := <-rt.fnDone + rt.router.close() + return res +} + +// handleFnDone is the synchronous-completion path: fn returned before any +// detach signal. Capture an invocation-end snapshot if state advanced past +// the last turn-end snapshot, then assemble the output. +func (rt *sessionFlowRuntime[Stream, State]) handleFnDone( + ctx context.Context, + cancelWork context.CancelFunc, + res fnDoneResult[State], +) (*SessionFlowOutput[State], error) { + cancelWork() + rt.intake.stopAndWait() + rt.router.close() + + if res.err != nil { + return nil, res.err + } + + snapshotID := rt.runner.maybeSnapshot(ctx, SnapshotEventInvocationEnd) + if snapshotID == "" && rt.runner.lastSnapshot != nil { + // State unchanged since the last turn-end snapshot — reuse it so + // the response always carries an ID when a store is configured. + snapshotID = rt.runner.lastSnapshot.SnapshotID + } + + out := &SessionFlowOutput[State]{SnapshotID: snapshotID} + if res.result != nil { + out.Message = res.result.Message + out.Artifacts = res.result.Artifacts + } + if rt.cfg.store == nil { + out.State = applyTransform(ctx, rt.cfg.transform, rt.session.State()) + } + return out, nil +} + +// handleDetach commits the pending snapshot, returns its ID, and spawns the +// status-subscriber and finalizer goroutines that own the rest of the +// invocation. Per-turn snapshots are suspended for the remainder so the +// queued inputs roll into a single finalize rewrite; the chunk router +// stops writing to outCh but keeps applying in-process side effects +// (e.g. artifacts added via Responder.SendArtifact) so user code does +// not have to branch on detach. +func (rt *sessionFlowRuntime[Stream, State]) handleDetach( + clientCtx, workCtx context.Context, + cancelWork context.CancelFunc, + markDetached func(), +) (*SessionFlowOutput[State], error) { + // Stop mirroring clientCtx. From here, only the abort subscription or + // fn completion can cancel workCtx. + markDetached() + + rt.intake.suspend() + + parentID := rt.runner.parentSnapshotID() + + // Detach intends to outlive the client connection. If clientCtx was + // already cancelled (or cancels mid-write), we still want the pending + // row durable so observers can find it later. Decouple this write. + pending, err := rt.cfg.store.SaveSnapshot(context.WithoutCancel(clientCtx), "", + func(_ *SessionSnapshot[State]) (*SessionSnapshot[State], error) { + return &SessionSnapshot[State]{ + ParentID: parentID, + Event: SnapshotEventDetach, + Status: SnapshotStatusPending, + }, nil + }) + if err != nil { + rt.drainAndWait(cancelWork) + return nil, core.NewError(core.INTERNAL, + "session flow %q: detach: save pending snapshot: %v", rt.name, err) + } + + // The router can no longer write to outCh once we return; the bidi + // framework closes it shortly after. The router stops writing and + // trashes any further chunks. + rt.router.stopAndWait() + + canceledByUser := &atomic.Bool{} + subCtx, stopSub := context.WithCancel(workCtx) + aborter := rt.cfg.store.(SnapshotAborter) // safe: checkDetachCapabilities ran already + statusCh := aborter.OnSnapshotStatusChange(subCtx, pending.SnapshotID) + go func() { + for status := range statusCh { + if status == SnapshotStatusCanceled { + canceledByUser.Store(true) + cancelWork() + return + } + } + }() + + finalizeCtx := context.WithoutCancel(clientCtx) + go func() { + res := <-rt.fnDone + stopSub() + rt.intake.stopAndWait() + rt.router.close() + rt.finalizePendingSnapshot(finalizeCtx, pending, res.err, canceledByUser.Load()) + cancelWork() + }() + + return &SessionFlowOutput[State]{SnapshotID: pending.SnapshotID}, nil +} + +// finalizePendingSnapshot rewrites the pending snapshot row with the +// terminal state and status. canceledByUser distinguishes a context +// cancellation from abortSnapshot (status=canceled) from an internal +// failure (status=error). The write is funneled through SaveSnapshot +// so the read-and-rewrite is one atomic step: if the row has already +// transitioned to canceled (a late abort racing this finalize), +// SaveSnapshot sees it inside fn and we leave the row untouched. +func (rt *sessionFlowRuntime[Stream, State]) finalizePendingSnapshot( + ctx context.Context, + pending *SessionSnapshot[State], + fnErr error, + canceledByUser bool, +) { + finalState := *rt.session.State() + + _, err := rt.cfg.store.SaveSnapshot(ctx, pending.SnapshotID, + func(existing *SessionSnapshot[State]) (*SessionSnapshot[State], error) { + // Late abort wins over the terminal we were about to land. + if existing != nil && existing.Status == SnapshotStatusCanceled { + return nil, nil + } + + status := SnapshotStatusComplete + errMsg := "" + switch { + case canceledByUser: + status = SnapshotStatusCanceled + if fnErr != nil { + errMsg = fnErr.Error() // canceled wins, preserve text + } + case fnErr != nil: + status = SnapshotStatusError + errMsg = fnErr.Error() + } + + return &SessionSnapshot[State]{ + ParentID: pending.ParentID, + Event: SnapshotEventDetach, + Status: status, + Error: errMsg, + State: finalState, + }, nil + }) + if err != nil { + logger.FromContext(ctx).Error("session flow: failed to finalize pending snapshot", + "snapshotId", pending.SnapshotID, "err", err) + } +} + +// loadSession constructs a Session from the invocation's init payload, +// loading from the store when a snapshot ID is provided. Returns the +// snapshot too so the runtime can chain ParentID off it. +func loadSession[State any]( + ctx context.Context, + init *SessionFlowInit[State], + store SessionStore[State], +) (*Session[State], *SessionSnapshot[State], error) { + s := &Session[State]{store: store} + if init == nil { + return s, nil, nil + } + + if init.SnapshotID != "" && init.State != nil { + return nil, nil, core.NewError(core.INVALID_ARGUMENT, "snapshot ID and state are mutually exclusive") + } + + if init.SnapshotID == "" { + if init.State != nil { + s.state = *init.State + } + return s, nil, nil + } + + if store == nil { + return nil, nil, core.NewError(core.FAILED_PRECONDITION, + "snapshot ID %q provided but no session store configured", init.SnapshotID) + } + snap, err := store.GetSnapshot(ctx, init.SnapshotID) + if err != nil { + return nil, nil, core.NewError(core.INTERNAL, "failed to load snapshot %q: %v", init.SnapshotID, err) + } + if snap == nil { + return nil, nil, core.NewError(core.NOT_FOUND, "snapshot %q not found", init.SnapshotID) + } + switch snap.Status { + case SnapshotStatusError: + msg := snap.Error + if msg == "" { + msg = "snapshot recorded an error" + } + return nil, nil, core.NewError(core.FAILED_PRECONDITION, + "snapshot %q terminated with error: %s", init.SnapshotID, msg) + case SnapshotStatusPending: + return nil, nil, core.NewError(core.FAILED_PRECONDITION, + "snapshot %q is still pending; wait for it to finalize before resuming", init.SnapshotID) + case SnapshotStatusCanceled: + return nil, nil, core.NewError(core.FAILED_PRECONDITION, + "snapshot %q was canceled", init.SnapshotID) + } + s.state = snap.State + return s, snap, nil +} + // --- SessionRunner --- // SessionRunner extends Session with session-flow-specific functionality: @@ -59,6 +485,22 @@ type SessionRunner[State any] struct { lastSnapshot *SessionSnapshot[State] lastSnapshotVersion uint64 collectTurnOutput func() any + + // intake is the source of truth for in-flight tracking, queue state, + // and suspended state. The runner consults it via beginTurnEnd (in + // maybeSnapshot) so per-turn snapshot writes and detach captures + // cannot race over the same input. + intake *detachIntake +} + +// parentSnapshotID returns the ID of the most recent snapshot in this +// invocation (used to chain new snapshots via ParentID), or "" if no +// snapshot has been written yet. +func (s *SessionRunner[State]) parentSnapshotID() string { + if s.lastSnapshot == nil { + return "" + } + return s.lastSnapshot.SnapshotID } // Run loops over the input channel, calling fn for each turn. Each turn is @@ -72,18 +514,14 @@ func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont Type: "flowStep", Subtype: "flowStep", } - _, err := tracing.RunInNewSpan(ctx, spanMeta, input, func(ctx context.Context, input *SessionFlowInput) (any, error) { s.AddMessages(input.Messages...) - if err := fn(ctx, input); err != nil { return nil, err } - s.onEndTurn(ctx) s.TurnIndex++ - if s.collectTurnOutput != nil { return s.collectTurnOutput(), nil } @@ -118,8 +556,21 @@ func (s *SessionRunner[State]) Result() *SessionFlowResult { } // maybeSnapshot creates a snapshot if conditions are met (store configured, -// callback approves, state changed). Returns the snapshot ID or empty string. +// callback approves, state changed, detach has not suspended snapshots). +// Returns the snapshot ID or empty string. +// +// For turn-end events, the runner asks the intake whether snapshots +// have been suspended (i.e. detach has landed). If so, the runner skips +// the turn-end snapshot — the pending row already captures the +// invocation and a single finalize rewrite will record the cumulative +// state once the queued inputs drain. func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { + if event == SnapshotEventTurnEnd && s.intake != nil { + if suspended := s.intake.beginTurnEnd(); suspended { + return "" + } + } + if s.store == nil { return "" } @@ -152,36 +603,25 @@ func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event Snapshot } } - snapshot := &SessionSnapshot[State]{ - SnapshotID: uuid.New().String(), - CreatedAt: time.Now(), - Event: event, - State: currentState, - } - if s.lastSnapshot != nil { - snapshot.ParentID = s.lastSnapshot.SnapshotID - } - - if err := s.store.SaveSnapshot(ctx, snapshot); err != nil { + parentID := s.parentSnapshotID() + + saved, err := s.store.SaveSnapshot(ctx, "", + func(_ *SessionSnapshot[State]) (*SessionSnapshot[State], error) { + return &SessionSnapshot[State]{ + ParentID: parentID, + Event: event, + Status: SnapshotStatusComplete, + State: currentState, + }, nil + }) + if err != nil { logger.FromContext(ctx).Error("session flow: failed to save snapshot", "err", err) return "" } - // Set snapshotId in last message metadata. - s.mu.Lock() - if msgs := s.state.Messages; len(msgs) > 0 { - lastMsg := msgs[len(msgs)-1] - if lastMsg.Metadata == nil { - lastMsg.Metadata = make(map[string]any) - } - lastMsg.Metadata["snapshotId"] = snapshot.SnapshotID - } - s.mu.Unlock() - - s.lastSnapshot = snapshot + s.lastSnapshot = saved s.lastSnapshotVersion = currentVersion - - return snapshot.SnapshotID + return saved.SnapshotID } // --- Responder --- @@ -202,251 +642,408 @@ func (r Responder[Stream]) SendStatus(status Stream) { } // SendArtifact sends an artifact to the stream and adds it to the session. -// If an artifact with the same name already exists in the session, it is replaced. +// If an artifact with the same name already exists in the session, it is +// replaced. The session-level side effect happens whether or not detach +// has landed; only the wire forward to the client is suppressed +// post-detach, when there is no longer a client to receive it. func (r Responder[Stream]) SendArtifact(artifact *Artifact) { r <- &SessionFlowStreamChunk[Stream]{Artifact: artifact} } -// --- SessionFlow --- - -// SessionFlowFunc is the function signature for session flows. -// Type parameters: -// - Stream: Type for status updates sent via the responder -// - State: Type for user-defined state in snapshots -type SessionFlowFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *SessionRunner[State]) (*SessionFlowResult, error) - -// SessionFlow is a bidirectional streaming flow with automatic snapshot management. -type SessionFlow[Stream, State any] struct { - flow *core.Flow[*SessionFlowInit[State], *SessionFlowOutput[State], *SessionFlowStreamChunk[Stream], *SessionFlowInput] +// --- chunkRouter --- +// +// chunkRouter owns the intermediate stream channel that all chunks flow +// through on their way to outCh. Every chunk gets the same in-process +// side effects (adding artifacts to the session, accumulating turn +// chunks for span output) regardless of whether detach has landed; the +// wire forward to outCh is the only thing detach suppresses, since the +// bidi framework closes outCh shortly after bidiFn returns. The router +// commits to not writing before we return so that close is safe, and +// keeps draining its input so the user fn never blocks on a responder +// send. + +type chunkRouter[Stream, State any] struct { + in chan *SessionFlowStreamChunk[Stream] + out chan<- *SessionFlowStreamChunk[Stream] + session *Session[State] + + turnMu sync.Mutex + turnChunks []*SessionFlowStreamChunk[Stream] + + done chan struct{} + stopWriting chan struct{} + writerStopped chan struct{} } -// DefineSessionFlow creates an SessionFlow with automatic snapshot management and registers it. -func DefineSessionFlow[Stream, State any]( - r api.Registry, - name string, - fn SessionFlowFunc[Stream, State], - opts ...SessionFlowOption[State], -) *SessionFlow[Stream, State] { - afOpts := &sessionFlowOptions[State]{} - for _, opt := range opts { - if err := opt.applySessionFlow(afOpts); err != nil { - panic(fmt.Errorf("DefineSessionFlow %q: %w", name, err)) - } +func startChunkRouter[Stream, State any]( + session *Session[State], + out chan<- *SessionFlowStreamChunk[Stream], +) *chunkRouter[Stream, State] { + r := &chunkRouter[Stream, State]{ + in: make(chan *SessionFlowStreamChunk[Stream]), + out: out, + session: session, + done: make(chan struct{}), + stopWriting: make(chan struct{}), + writerStopped: make(chan struct{}), } + go r.run() + return r +} - store := afOpts.store - snapshotCallback := afOpts.callback - - flow := core.DefineBidiFlow(r, name, func( - ctx context.Context, - in *SessionFlowInit[State], - inCh <-chan *SessionFlowInput, - outCh chan<- *SessionFlowStreamChunk[Stream], - ) (*SessionFlowOutput[State], error) { - session, snapshot, err := newSessionFromInit(ctx, in, store) - if err != nil { - return nil, err - } - ctx = NewSessionContext(ctx, session) - - agentSess := &SessionRunner[State]{ - Session: session, - snapshotCallback: snapshotCallback, - InputCh: inCh, - lastSnapshot: snapshot, - } - - // Turn output accumulator: collects content chunks per turn for span output. - var ( - turnMu sync.Mutex - turnChunks []*SessionFlowStreamChunk[Stream] - ) +func (r *chunkRouter[Stream, State]) run() { + defer close(r.done) + if !r.forward() { + // r.in closed before detach; nothing left to do. + return + } + close(r.writerStopped) + // Detached: keep applying side effects so the user fn's + // SendArtifact/SendModelChunk calls behave the same way they did + // pre-detach. Only the wire forward to outCh is suppressed. + for chunk := range r.in { + r.applySideEffects(chunk) + } +} - agentSess.collectTurnOutput = func() any { - turnMu.Lock() - defer turnMu.Unlock() - result := turnChunks - turnChunks = nil - return result - } +// applySideEffects records the chunk's effect on session state and turn +// span output. Invoked from both forward (pre-detach) and the post-detach +// drain so a Send call is observably the same in either mode. +func (r *chunkRouter[Stream, State]) applySideEffects(chunk *SessionFlowStreamChunk[Stream]) { + if chunk.Artifact != nil { + r.session.AddArtifacts(chunk.Artifact) + } + if chunk.TurnEnd == nil { + r.turnMu.Lock() + r.turnChunks = append(r.turnChunks, chunk) + r.turnMu.Unlock() + } +} - // Intermediary channel: intercepts artifacts, accumulates turn output, - // and forwards to outCh. - respCh := make(chan *SessionFlowStreamChunk[Stream]) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - for chunk := range respCh { - if chunk.Artifact != nil { - session.AddArtifacts(chunk.Artifact) - } - // Accumulate content chunks (exclude the TurnEnd control signal). - if chunk.TurnEnd == nil { - turnMu.Lock() - turnChunks = append(turnChunks, chunk) - turnMu.Unlock() - } - outCh <- chunk +// forward delivers chunks to outCh and applies side effects until detach +// or r.in closes. Returns true if it stopped because of detach. +func (r *chunkRouter[Stream, State]) forward() bool { + for { + select { + case chunk, ok := <-r.in: + if !ok { + return false } - }() - - // Wire up onEndTurn: triggers snapshot + sends TurnEnd chunk. - // Writes through respCh to preserve ordering with user chunks. - agentSess.onEndTurn = func(turnCtx context.Context) { - snapshotID := agentSess.maybeSnapshot(turnCtx, SnapshotEventTurnEnd) - respCh <- &SessionFlowStreamChunk[Stream]{ - TurnEnd: &TurnEnd{SnapshotID: snapshotID}, + r.applySideEffects(chunk) + select { + case r.out <- chunk: + case <-r.stopWriting: + return true } + case <-r.stopWriting: + return true } + } +} - result, fnErr := fn(ctx, Responder[Stream](respCh), agentSess) - close(respCh) - wg.Wait() +// responder returns a [Responder] that sends chunks into the router. +func (r *chunkRouter[Stream, State]) responder() Responder[Stream] { + return Responder[Stream](r.in) +} - if fnErr != nil { - return nil, fnErr - } +// send returns the internal chunk channel for producers other than the user +// flow function (e.g. the runtime's emitTurnEnd). +func (r *chunkRouter[Stream, State]) send() chan<- *SessionFlowStreamChunk[Stream] { + return r.in +} - // Final snapshot at invocation end. If skipped (state unchanged - // since last turn-end snapshot), use the last snapshot's ID so - // the output always reflects the latest snapshot. - snapshotID := agentSess.maybeSnapshot(ctx, SnapshotEventInvocationEnd) - if snapshotID == "" && agentSess.lastSnapshot != nil { - snapshotID = agentSess.lastSnapshot.SnapshotID - } +// collectTurnChunks returns and resets accumulated turn chunks. +func (r *chunkRouter[Stream, State]) collectTurnChunks() []*SessionFlowStreamChunk[Stream] { + r.turnMu.Lock() + defer r.turnMu.Unlock() + result := r.turnChunks + r.turnChunks = nil + return result +} - out := &SessionFlowOutput[State]{ - SnapshotID: snapshotID, - } - if result != nil { - out.Message = result.Message - out.Artifacts = result.Artifacts - } +// stopAndWait tells the router to stop writing to out and blocks until it +// has committed. After it returns, it is safe for the framework to close +// out without risking a write-to-closed-channel panic. +func (r *chunkRouter[Stream, State]) stopAndWait() { + close(r.stopWriting) + <-r.writerStopped +} - // Only include full state when client-managed (no store). - if store == nil { - out.State = session.State() - } +// close signals end-of-input and waits for the router to drain. +func (r *chunkRouter[Stream, State]) close() { + close(r.in) + <-r.done +} - return out, nil - }) +// --- detachIntake --- +// +// detachIntake separates eager src reading from runner-paced forwarding, +// and owns the queue and suspend state. +// +// The reader goroutine pulls from the bidi framework's inCh as soon as +// inputs arrive and appends them to an internal queue. This is what makes +// detach detection immediate: the moment an input with +// [SessionFlowInput.Detach] lands in src, the reader sees it without +// waiting for the runner to finish whatever it's processing. +// +// The forwarder goroutine pops the queue and writes to dst, blocking on +// the runner via turnDone so it stays in step with turn pacing. +// +// The runner asks beginTurnEnd at the end of each turn: if suspended +// (detach has landed), the runner skips its turn-end snapshot — the +// pending row already captures the invocation and a single finalize +// will rewrite it with the cumulative state once the queued inputs +// drain. If not suspended, a normal turn-end snapshot is written. +// +// suspend is called once by the detach handler under the same mutex +// that beginTurnEnd reads from, ensuring memory ordering: any +// beginTurnEnd that returns after suspend completes sees suspended=true. + +type detachIntake struct { + src <-chan *SessionFlowInput + dst chan *SessionFlowInput + notify chan struct{} // buffered size 1; wakes forwarder when queue grows + + // turnDone is signaled by beginTurnEnd to release the forwarder so it + // may pop the next input. Initialized with one token so the very + // first turn can start without a preceding turn end. + turnDone chan struct{} + + mu sync.Mutex + suspended bool + queue []*SessionFlowInput + + readDone atomic.Bool + detachCh chan struct{} // signaled by reader when detach observed + + stop chan struct{} + stopOnce sync.Once + done chan struct{} +} - return &SessionFlow[Stream, State]{flow: flow} +func startDetachIntake(src <-chan *SessionFlowInput) *detachIntake { + i := &detachIntake{ + src: src, + dst: make(chan *SessionFlowInput), + notify: make(chan struct{}, 1), + turnDone: make(chan struct{}, 1), + detachCh: make(chan struct{}, 1), + stop: make(chan struct{}), + done: make(chan struct{}), + } + i.turnDone <- struct{}{} // initial credit for the first turn + go i.run() + return i } -// promptMessageKey is the metadata key used to tag prompt-rendered messages -// so they can be excluded from session history after generation. -const promptMessageKey = "_genkit_prompt" +func (i *detachIntake) run() { + defer close(i.done) -// DefineSessionFlowFromPrompt creates a prompt-backed SessionFlow with an -// automatic conversation loop. Each turn renders the prompt, appends -// conversation history, calls GenerateWithRequest, streams chunks to the -// client, and adds the model response to the session. -// -// The prompt is looked up by name from the registry using -// [ai.LookupDataPrompt]. The defaultInput is used for prompt rendering -// unless overridden per invocation via WithInputVariables. -func DefineSessionFlowFromPrompt[State, PromptIn any]( - r api.Registry, - promptName string, - defaultInput PromptIn, - opts ...SessionFlowOption[State], -) *SessionFlow[any, State] { - p := ai.LookupDataPrompt[PromptIn, string](r, promptName) - if p == nil { - panic(fmt.Sprintf("DefineSessionFlowFromPrompt: prompt %q not found", promptName)) - } + forwarderDone := make(chan struct{}) + go func() { + defer close(forwarderDone) + defer close(i.dst) + i.forward() + }() - fn := func(ctx context.Context, resp Responder[any], sess *SessionRunner[State]) (*SessionFlowResult, error) { - if err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { - // Resolve prompt input: session state override > default. - promptInput := defaultInput - if stored := sess.InputVariables(); stored != nil { - typed, ok := base.ConvertTo[PromptIn](stored) - if !ok { - return core.NewError(core.INVALID_ARGUMENT, "input variables type mismatch: got %T, want %T", stored, promptInput) - } - promptInput = typed - } + i.read() + <-forwarderDone +} - // Render the prompt template. - genOpts, err := p.Render(ctx, promptInput) - if err != nil { - return fmt.Errorf("prompt render: %w", err) - } +// signal wakes the forwarder. Non-blocking: the channel is buffered size +// 1, so a pending signal is enough. +func (i *detachIntake) signal() { + select { + case i.notify <- struct{}{}: + default: + } +} - // Tag prompt-rendered messages so we can exclude them from - // session history after generation. - for _, m := range genOpts.Messages { - if m.Metadata == nil { - m.Metadata = make(map[string]any) - } - m.Metadata[promptMessageKey] = true +// read pulls eagerly from src into the internal queue and detects detach +// the moment it lands. When detach is observed, it drains any remaining +// buffered src non-blockingly (so all pre-detach inputs are accounted +// for), signals the detach handler, and exits. +func (i *detachIntake) read() { + defer func() { + i.readDone.Store(true) + i.signal() + }() + + for { + select { + case input, ok := <-i.src: + if !ok { + return + } + if input.Detach { + i.handleDetach(input) + return } + i.enqueue(input) + case <-i.stop: + return + } + } +} - // Append conversation history after the prompt-rendered messages. - genOpts.Messages = append(genOpts.Messages, sess.Messages()...) +func (i *detachIntake) enqueue(input *SessionFlowInput) { + i.mu.Lock() + i.queue = append(i.queue, input) + i.mu.Unlock() + i.signal() +} - // If tool restarts were provided, set the resume option so - // handleResumeOption re-executes the interrupted tools. - if len(input.ToolRestarts) > 0 { - for _, p := range input.ToolRestarts { - if !p.IsToolRequest() { - return core.NewError(core.INVALID_ARGUMENT, "ToolRestarts: part is not a tool request") - } - } - genOpts.Resume = ai.NewResume(input.ToolRestarts, nil) +// handleDetach drains any buffered src inputs into the queue and signals +// the detach handler. The detach handler then calls suspend to halt +// turn-end snapshots while the queued inputs finish processing. +// +// A pure detach signal (no Messages, no ToolRestarts) is dropped rather +// than enqueued: it carries no payload to process, so it would just +// trigger a no-op turn. Callers that want to ride a final input on the +// detach signal can do so by calling +// Send(&SessionFlowInput{Detach: true, Messages: ...}) explicitly. +func (i *detachIntake) handleDetach(first *SessionFlowInput) { + var drained []*SessionFlowInput + if hasInputPayload(first) { + drained = append(drained, first) + } +drainLoop: + for { + select { + case more, ok := <-i.src: + if !ok { + break drainLoop } + drained = append(drained, more) + default: + break drainLoop + } + } - // Call the model with streaming. - modelResp, err := ai.GenerateWithRequest(ctx, r, genOpts, nil, - func(ctx context.Context, chunk *ai.ModelResponseChunk) error { - resp.SendModelChunk(chunk) - return nil - }, - ) - if err != nil { - return fmt.Errorf("generate: %w", err) - } + if len(drained) > 0 { + i.mu.Lock() + i.queue = append(i.queue, drained...) + i.mu.Unlock() + i.signal() + } - // Replace session messages with the full history minus prompt - // messages. This captures intermediate tool call/response - // messages from the tool loop, not just the final response. - if modelResp.Request != nil { - var msgs []*ai.Message - for _, m := range modelResp.History() { - if m.Metadata != nil && m.Metadata[promptMessageKey] == true { - continue - } - msgs = append(msgs, m) - } - sess.SetMessages(msgs) - } else if modelResp.Message != nil { - sess.AddMessages(modelResp.Message) - } + select { + case i.detachCh <- struct{}{}: + case <-i.stop: + } +} - // Stream interrupt parts so the client can detect and - // handle them (e.g. prompt the user for confirmation). - if modelResp.FinishReason == ai.FinishReasonInterrupted { - if parts := modelResp.Interrupts(); len(parts) > 0 { - resp.SendModelChunk(&ai.ModelResponseChunk{ - Role: ai.RoleTool, - Content: parts, - }) - } - } +// hasInputPayload reports whether the input carries data the runner would +// otherwise process. Used to filter pure detach signals out of the +// queue so they don't trigger no-op turns. +func hasInputPayload(in *SessionFlowInput) bool { + return in != nil && (len(in.Messages) > 0 || len(in.ToolRestarts) > 0) +} + +// forward pops the queue and writes to dst at the runner's pace. The +// runner signals turnDone via beginTurnEnd when it's ready for the next +// input; until then the forwarder waits, so it never gets ahead of the +// runner. +func (i *detachIntake) forward() { + for { + // Wait for the previous turn to release us (initial credit lets + // the first turn through immediately). + select { + case <-i.turnDone: + case <-i.stop: + return + } + input := i.awaitInput() + if input == nil { + return // reader done with empty queue, or stop signaled + } + forwarded := *input + forwarded.Detach = false + select { + case i.dst <- &forwarded: + case <-i.stop: + return + } + } +} +// awaitInput blocks until the queue has an input, the reader is done, or +// stop is signaled. Returns the popped input or nil if no further inputs +// will arrive. +func (i *detachIntake) awaitInput() *SessionFlowInput { + for { + i.mu.Lock() + if len(i.queue) > 0 { + input := i.queue[0] + i.queue = i.queue[1:] + i.mu.Unlock() + return input + } + done := i.readDone.Load() + i.mu.Unlock() + if done { + return nil + } + select { + case <-i.notify: + case <-i.stop: return nil - }); err != nil { - return nil, err } - return sess.Result(), nil } +} - return DefineSessionFlow(r, promptName, fn, opts...) +// releaseForward releases the forwarder so it can pop the next input. +// Must be called from beginTurnEnd (and only there) so the forwarder +// stays in step with the runner's turn pacing. +func (i *detachIntake) releaseForward() { + select { + case i.turnDone <- struct{}{}: + default: + } +} + +func (i *detachIntake) out() <-chan *SessionFlowInput { + return i.dst +} + +func (i *detachIntake) detachSignal() <-chan struct{} { + return i.detachCh +} + +// beginTurnEnd is called by [SessionRunner.maybeSnapshot] before writing +// a turn-end snapshot. If the intake has been suspended (detach landed), +// it returns suspended=true and the runner skips the snapshot. +// +// In all cases (including suspended) the forwarder is released so it can +// pop the next queued input — suspension stops snapshot writing, not +// processing. +func (i *detachIntake) beginTurnEnd() (suspended bool) { + i.mu.Lock() + suspended = i.suspended + i.mu.Unlock() + i.releaseForward() + return suspended +} + +// suspend is called once by the detach handler. It flips suspended=true +// under the mutex so subsequent beginTurnEnd calls observe the change +// and skip their turn-end snapshot writes; the queued inputs roll into +// a single finalize rewrite of the pending row instead. +func (i *detachIntake) suspend() { + i.mu.Lock() + i.suspended = true + i.mu.Unlock() +} + +// stopAndWait forces the intake to exit and waits for both reader and +// forwarder goroutines. +func (i *detachIntake) stopAndWait() { + i.stopOnce.Do(func() { close(i.stop) }) + <-i.done } +// --- SessionFlow client API --- + // StreamBidi starts a new session flow invocation with bidirectional streaming. // Use this for multi-turn interactions where you need to send multiple inputs // and receive streaming chunks. For single-turn usage, see Run and RunText. @@ -454,16 +1051,14 @@ func (af *SessionFlow[Stream, State]) StreamBidi( ctx context.Context, opts ...InvocationOption[State], ) (*SessionFlowConnection[Stream, State], error) { - invOpts, err := af.resolveOptions(opts) + init, err := af.resolveOptions(opts) if err != nil { return nil, err } - - conn, err := af.flow.StreamBidi(ctx, invOpts) + conn, err := af.action.StreamBidi(ctx, init) if err != nil { return nil, err } - return &SessionFlowConnection[Stream, State]{conn: conn}, nil } @@ -479,21 +1074,31 @@ func (af *SessionFlow[Stream, State]) Run( if err != nil { return nil, err } - + // If the bidi function fails fast (e.g. resuming from an errored + // snapshot rejects in newSessionFlowRuntime), Send / Close / Receive + // see a closed connection and return generic "action has completed" + // errors. The real fn error is on Output(). Prefer it whenever it's + // non-nil so callers get the meaningful failure. if err := conn.Send(input); err != nil { + if _, outErr := conn.Output(); outErr != nil { + return nil, outErr + } return nil, err } if err := conn.Close(); err != nil { + if _, outErr := conn.Output(); outErr != nil { + return nil, outErr + } return nil, err } - - // Drain stream chunks. for _, err := range conn.Receive() { if err != nil { + if _, outErr := conn.Output(); outErr != nil { + return nil, outErr + } return nil, err } } - return conn.Output() } @@ -512,62 +1117,25 @@ func (af *SessionFlow[Stream, State]) RunText( // resolveOptions applies invocation options and returns the init struct. func (af *SessionFlow[Stream, State]) resolveOptions(opts []InvocationOption[State]) (*SessionFlowInit[State], error) { - invOpts := &invocationOptions[State]{} + cfg := &invocationOptions[State]{} for _, opt := range opts { - if err := opt.applyInvocation(invOpts); err != nil { - return nil, fmt.Errorf("SessionFlow %q: %w", af.flow.Name(), err) + if err := opt.applyInvocation(cfg); err != nil { + return nil, fmt.Errorf("SessionFlow %q: %w", af.action.Name(), err) } } - init := &SessionFlowInit[State]{ - SnapshotID: invOpts.snapshotID, - State: invOpts.state, + SnapshotID: cfg.snapshotID, + State: cfg.state, } - if invOpts.promptInput != nil { + if cfg.promptInput != nil { if init.State == nil { init.State = &SessionState[State]{} } - init.State.InputVariables = invOpts.promptInput + init.State.InputVariables = cfg.promptInput } - return init, nil } -// newSessionFromInit creates a Session from initialization data. -// If resuming from a snapshot, the loaded snapshot is also returned. -func newSessionFromInit[State any]( - ctx context.Context, - init *SessionFlowInit[State], - store SessionStore[State], -) (*Session[State], *SessionSnapshot[State], error) { - s := &Session[State]{store: store} - - var snapshot *SessionSnapshot[State] - if init != nil { - if init.SnapshotID != "" && init.State != nil { - return nil, nil, core.NewError(core.INVALID_ARGUMENT, "snapshot ID and state are mutually exclusive") - } - if init.SnapshotID != "" && store == nil { - return nil, nil, core.NewError(core.FAILED_PRECONDITION, "snapshot ID %q provided but no session store configured", init.SnapshotID) - } - if init.SnapshotID != "" && store != nil { - var err error - snapshot, err = store.GetSnapshot(ctx, init.SnapshotID) - if err != nil { - return nil, nil, core.NewError(core.INTERNAL, "failed to load snapshot %q: %v", init.SnapshotID, err) - } - if snapshot == nil { - return nil, nil, core.NewError(core.NOT_FOUND, "snapshot %q not found", init.SnapshotID) - } - s.state = snapshot.State - } else if init.State != nil { - s.state = *init.State - } - } - - return s, snapshot, nil -} - // --- SessionFlowConnection --- // SessionFlowConnection wraps BidiConnection with session flow-specific functionality. @@ -625,6 +1193,25 @@ func (c *SessionFlowConnection[Stream, State]) SendToolRestarts(parts ...*ai.Par return c.conn.Send(&SessionFlowInput{ToolRestarts: parts}) } +// Detach asks the server to write a pending snapshot, close the +// connection, and continue processing any already-buffered inputs in +// the background. Output() returns the pending snapshot ID; the client +// can later call AbortSnapshot to stop the background work or +// GetSnapshot to observe its progression. The pending snapshot is +// finalized with the cumulative final state once the queued inputs +// are processed. +// +// Streamed chunks emitted after detach are not forwarded over the wire +// (the connection is gone), but their session-level side effects still +// apply: artifacts sent via [Responder.SendArtifact] land in the +// session and end up in the final snapshot's state. +// +// To send a final input as part of the same wire message, use +// Send(&SessionFlowInput{Detach: true, Messages: ...}) directly. +func (c *SessionFlowConnection[Stream, State]) Detach() error { + return c.conn.Send(&SessionFlowInput{Detach: true}) +} + // Close signals that no more inputs will be sent. func (c *SessionFlowConnection[Stream, State]) Close() error { return c.conn.Close() @@ -637,23 +1224,26 @@ func (c *SessionFlowConnection[Stream, State]) Close() error { func (c *SessionFlowConnection[Stream, State]) Receive() iter.Seq2[*SessionFlowStreamChunk[Stream], error] { c.initReceiver() return func(yield func(*SessionFlowStreamChunk[Stream], error) bool) { - for { - chunk, ok := <-c.chunks - if !ok { - if err := c.chunkErr; err != nil { - yield(nil, err) - } - return - } + for chunk := range c.chunks { if !yield(chunk, nil) { return } } + if err := c.chunkErr; err != nil { + yield(nil, err) + } } } // Output returns the final response after the session flow completes. +// +// Unlike the underlying BidiConnection, Output waits for the flow to +// finalize before returning. This is important for detached invocations: +// when the client sends Detach, the flow function returns promptly with a +// pending snapshot ID, and callers need to observe that output rather than +// the context cancellation error. func (c *SessionFlowConnection[Stream, State]) Output() (*SessionFlowOutput[State], error) { + <-c.conn.Done() return c.conn.Output() } @@ -661,3 +1251,132 @@ func (c *SessionFlowConnection[Stream, State]) Output() (*SessionFlowOutput[Stat func (c *SessionFlowConnection[Stream, State]) Done() <-chan struct{} { return c.conn.Done() } + +// --- DefineSessionFlowFromPrompt --- + +// promptMessageKey tags prompt-rendered messages so they can be excluded +// from session history after generation. They're rendered fresh each turn +// from the registered prompt, so persisting them in history would cause +// duplication on resume. +const promptMessageKey = "_genkit_prompt" + +// DefineSessionFlowFromPrompt creates a prompt-backed SessionFlow with an +// automatic conversation loop. Each turn renders the prompt, appends +// conversation history, calls GenerateWithRequest, streams chunks to the +// client, and adds the model response to the session. +// +// The prompt is looked up by name from the registry using +// [ai.LookupDataPrompt]. The defaultInput is used for prompt rendering +// unless overridden per invocation via WithInputVariables. +func DefineSessionFlowFromPrompt[State, PromptIn any]( + r api.Registry, + promptName string, + defaultInput PromptIn, + opts ...SessionFlowOption[State], +) *SessionFlow[any, State] { + p := ai.LookupDataPrompt[PromptIn, string](r, promptName) + if p == nil { + panic(fmt.Sprintf("DefineSessionFlowFromPrompt: prompt %q not found", promptName)) + } + + turn := func(ctx context.Context, resp Responder[any], sess *SessionRunner[State], input *SessionFlowInput) error { + genOpts, err := renderPromptForTurn(ctx, p, sess, defaultInput) + if err != nil { + return err + } + + if len(input.ToolRestarts) > 0 { + for _, part := range input.ToolRestarts { + if !part.IsToolRequest() { + return core.NewError(core.INVALID_ARGUMENT, "ToolRestarts: part is not a tool request") + } + } + genOpts.Resume = ai.NewResume(input.ToolRestarts, nil) + } + + modelResp, err := ai.GenerateWithRequest(ctx, r, genOpts, nil, + func(ctx context.Context, chunk *ai.ModelResponseChunk) error { + resp.SendModelChunk(chunk) + return nil + }, + ) + if err != nil { + return fmt.Errorf("generate: %w", err) + } + + // Replace session messages with the full history minus prompt + // messages. This captures intermediate tool call/response messages + // from the tool loop, not just the final response. + if modelResp.Request != nil { + var msgs []*ai.Message + for _, m := range modelResp.History() { + if m.Metadata[promptMessageKey] == true { + continue + } + msgs = append(msgs, m) + } + sess.SetMessages(msgs) + } else if modelResp.Message != nil { + sess.AddMessages(modelResp.Message) + } + + // Stream interrupt parts so the client can detect and handle them + // (e.g. prompt the user for confirmation). + if modelResp.FinishReason == ai.FinishReasonInterrupted { + if parts := modelResp.Interrupts(); len(parts) > 0 { + resp.SendModelChunk(&ai.ModelResponseChunk{ + Role: ai.RoleTool, + Content: parts, + }) + } + } + return nil + } + + fn := func(ctx context.Context, resp Responder[any], sess *SessionRunner[State]) (*SessionFlowResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + return turn(ctx, resp, sess, input) + }) + if err != nil { + return nil, err + } + return sess.Result(), nil + } + + return DefineSessionFlow(r, promptName, fn, opts...) +} + +// renderPromptForTurn renders the prompt with the active input variables +// (session override > default), tags the prompt-rendered messages so they +// can be excluded from history, and appends conversation history. +func renderPromptForTurn[State, PromptIn any]( + ctx context.Context, + p *ai.DataPrompt[PromptIn, string], + sess *SessionRunner[State], + defaultInput PromptIn, +) (*ai.GenerateActionOptions, error) { + promptInput := defaultInput + if stored := sess.InputVariables(); stored != nil { + typed, ok := base.ConvertTo[PromptIn](stored) + if !ok { + return nil, core.NewError(core.INVALID_ARGUMENT, + "input variables type mismatch: got %T, want %T", stored, promptInput) + } + promptInput = typed + } + + genOpts, err := p.Render(ctx, promptInput) + if err != nil { + return nil, fmt.Errorf("prompt render: %w", err) + } + + for _, m := range genOpts.Messages { + if m.Metadata == nil { + m.Metadata = make(map[string]any) + } + m.Metadata[promptMessageKey] = true + } + + genOpts.Messages = append(genOpts.Messages, sess.Messages()...) + return genOpts, nil +} diff --git a/go/ai/exp/session_flow_test.go b/go/ai/exp/session_flow_test.go index fb80555fb8..26bf20d842 100644 --- a/go/ai/exp/session_flow_test.go +++ b/go/ai/exp/session_flow_test.go @@ -18,11 +18,15 @@ package exp import ( "context" + "errors" "fmt" "strings" "testing" + "time" "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/internal/registry" ) @@ -626,100 +630,133 @@ func TestSessionFlow_SetMessages(t *testing.T) { } } -func TestSessionFlow_SnapshotIDInMessageMetadata(t *testing.T) { - ctx := context.Background() - reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() - - af := DefineSessionFlow(reg, "metadataFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { - sess.AddMessages(ai.NewModelTextMessage("reply")) - return nil +func TestInMemorySessionStore(t *testing.T) { + t.Run("GetMissing", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + snap, err := store.GetSnapshot(context.Background(), "nonexistent") + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap != nil { + t.Errorf("expected nil, got %v", snap) + } + }) + + t.Run("SaveWithFixedID", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + saved, err := store.SaveSnapshot(context.Background(), "snap-1", + func(existing *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + if existing != nil { + t.Errorf("expected nil existing on first save, got %+v", existing) + } + return &SessionSnapshot[testState]{ + Status: SnapshotStatusComplete, + State: SessionState[testState]{Custom: testState{Counter: 1}}, + }, nil }) - if err != nil { - return nil, err - } - msgs := sess.Messages() - return &SessionFlowResult{Message: msgs[len(msgs)-1]}, nil - }, - WithSessionStore(store), - ) - - conn, err := af.StreamBidi(ctx) - if err != nil { - t.Fatalf("StreamBidi failed: %v", err) - } - - conn.SendText("hello") - for chunk, err := range conn.Receive() { if err != nil { - t.Fatalf("Receive error: %v", err) + t.Fatalf("SaveSnapshot failed: %v", err) } - if chunk.TurnEnd != nil { - break + if saved.SnapshotID != "snap-1" { + t.Errorf("saved SnapshotID = %q, want %q", saved.SnapshotID, "snap-1") } - } - conn.Close() - - response, err := conn.Output() - if err != nil { - t.Fatalf("Output failed: %v", err) - } - - // The last model message should have snapshotId in its metadata. - if response.Message == nil { - t.Fatal("expected Message in response") - } - if response.Message.Metadata == nil { - t.Fatal("expected metadata on last message") - } - if _, ok := response.Message.Metadata["snapshotId"]; !ok { - t.Error("expected snapshotId in last message metadata") - } -} - -func TestInMemorySessionStore(t *testing.T) { - ctx := context.Background() - store := NewInMemorySessionStore[testState]() - - // Get non-existent. - snap, err := store.GetSnapshot(ctx, "nonexistent") - if err != nil { - t.Fatalf("GetSnapshot failed: %v", err) - } - if snap != nil { - t.Errorf("expected nil, got %v", snap) - } + if saved.CreatedAt.IsZero() || saved.UpdatedAt.IsZero() { + t.Errorf("expected CreatedAt/UpdatedAt stamped, got created=%v updated=%v", + saved.CreatedAt, saved.UpdatedAt) + } + }) + + t.Run("GetReturnsCopy", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + if _, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{ + Status: SnapshotStatusComplete, + State: SessionState[testState]{Custom: testState{Counter: 1}}, + }, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + retrieved, _ := store.GetSnapshot(context.Background(), "snap-1") + retrieved.State.Custom.Counter = 999 + retrieved2, _ := store.GetSnapshot(context.Background(), "snap-1") + if retrieved2.State.Custom.Counter != 1 { + t.Errorf("expected counter=1 (isolation), got %d", retrieved2.State.Custom.Counter) + } + }) - // Save and retrieve. - snapshot := &SessionSnapshot[testState]{ - SnapshotID: "snap-1", - State: SessionState[testState]{ - Custom: testState{Counter: 1}, - }, - } - if err := store.SaveSnapshot(ctx, snapshot); err != nil { - t.Fatalf("SaveSnapshot failed: %v", err) - } + t.Run("DefaultsEmptyStatusToComplete", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + saved, err := store.SaveSnapshot(context.Background(), "", + func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{}, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + if saved.SnapshotID == "" { + t.Error("expected store to generate SnapshotID") + } + if saved.Status != SnapshotStatusComplete { + t.Errorf("expected Status=complete by default, got %q", saved.Status) + } + }) - retrieved, err := store.GetSnapshot(ctx, "snap-1") - if err != nil { - t.Fatalf("GetSnapshot failed: %v", err) - } - if retrieved == nil { - t.Fatal("expected snapshot") - } - if retrieved.State.Custom.Counter != 1 { - t.Errorf("expected counter=1, got %d", retrieved.State.Custom.Counter) - } + t.Run("NoopFnSkipsWrite", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + if _, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{Status: SnapshotStatusComplete}, nil + }); err != nil { + t.Fatalf("seed: %v", err) + } + before, _ := store.GetSnapshot(context.Background(), "snap-1") + noop, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return nil, nil + }) + if err != nil { + t.Fatalf("noop SaveSnapshot: %v", err) + } + if noop != nil { + t.Errorf("expected nil return on noop, got %+v", noop) + } + after, _ := store.GetSnapshot(context.Background(), "snap-1") + if before.UpdatedAt != after.UpdatedAt { + t.Errorf("noop should not bump UpdatedAt: before=%v after=%v", before.UpdatedAt, after.UpdatedAt) + } + }) - // Verify isolation. - snapshot.State.Custom.Counter = 999 - retrieved2, _ := store.GetSnapshot(ctx, "snap-1") - if retrieved2.State.Custom.Counter != 1 { - t.Errorf("expected counter=1 (isolation), got %d", retrieved2.State.Custom.Counter) - } + t.Run("PreservesCreatedAtOnUpdate", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + saved, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{Status: SnapshotStatusComplete}, nil + }) + if err != nil { + t.Fatalf("seed: %v", err) + } + time.Sleep(time.Millisecond) // ensure measurable UpdatedAt delta + updated, err := store.SaveSnapshot(context.Background(), "snap-1", + func(existing *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + if existing == nil { + t.Fatal("expected non-nil existing on update") + } + return &SessionSnapshot[testState]{ + Status: SnapshotStatusComplete, + State: SessionState[testState]{Custom: testState{Counter: 2}}, + }, nil + }) + if err != nil { + t.Fatalf("update: %v", err) + } + if !updated.CreatedAt.Equal(saved.CreatedAt) { + t.Errorf("CreatedAt not preserved: before=%v after=%v", saved.CreatedAt, updated.CreatedAt) + } + if !updated.UpdatedAt.After(saved.UpdatedAt) { + t.Errorf("UpdatedAt did not advance: before=%v after=%v", saved.UpdatedAt, updated.UpdatedAt) + } + }) } func TestSessionFlow_TurnSpanOutput(t *testing.T) { @@ -1710,3 +1747,1356 @@ func TestSessionFlow_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) t.Error("expected parent ID (turn-end snapshot)") } } + +// --- Detach, transform, and getSnapshot tests --- + +// waitForSnapshot polls the store for a snapshot matching the predicate, +// failing the test if it doesn't show up within the timeout. +func waitForSnapshot[State any]( + t *testing.T, + store SessionStore[State], + id string, + timeout time.Duration, + predicate func(*SessionSnapshot[State]) bool, +) *SessionSnapshot[State] { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + snap, err := store.GetSnapshot(context.Background(), id) + if err != nil { + t.Fatalf("GetSnapshot(%q): %v", id, err) + } + if snap != nil && predicate(snap) { + return snap + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("snapshot %q did not satisfy predicate within %s", id, timeout) + return nil +} + +func TestSessionFlow_TurnEnd_CarriesSnapshotID(t *testing.T) { + // Sanity: each TurnEnd chunk carries the snapshot ID of the turn-end + // snapshot, and the snapshots themselves are persisted. + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineSessionFlow(reg, "turnEndSnapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("ok")) + return nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + + var observed []TurnEnd + for turn := 0; turn < 3; turn++ { + if err := conn.SendText(fmt.Sprintf("turn %d", turn)); err != nil { + t.Fatalf("SendText: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive: %v", err) + } + if chunk.TurnEnd != nil { + observed = append(observed, *chunk.TurnEnd) + break + } + } + } + conn.Close() + if _, err := conn.Output(); err != nil { + t.Fatalf("Output: %v", err) + } + + if got := len(observed); got != 3 { + t.Fatalf("observed %d TurnEnd chunks, want 3", got) + } + for i, te := range observed { + if te.SnapshotID == "" { + t.Errorf("TurnEnd[%d].SnapshotID is empty", i) + continue + } + snap, err := store.GetSnapshot(context.Background(), te.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap == nil { + t.Errorf("turn %d: snapshot %q not in store", i, te.SnapshotID) + } + } +} + +func TestSessionFlow_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { + // Detach lands while turn 0 (input A) is mid-fn and an extra turn + // (the detach input D itself) is waiting. The pending snapshot must: + // - Be written with empty state and no parent (A was suspended, so + // no turn-end snapshot landed before pending). + // - NOT write a separate turn-end snapshot for A or D (suspended). + // After release, the finalized snapshot has both A's and D's effects. + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + entered := make(chan struct{}, 4) + release := make(chan struct{}) + + af := DefineSessionFlow(reg, "detachInFlight", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + entered <- struct{}{} + <-release + sess.AddMessages(ai.NewModelTextMessage("reply-" + input.Messages[0].Text())) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + + // Drain stream chunks in the background. + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + + // Send A and wait for it to enter fn (so it's in-flight when detach + // arrives). + if err := conn.SendText("A"); err != nil { + t.Fatalf("SendText A: %v", err) + } + select { + case <-entered: + case <-time.After(2 * time.Second): + t.Fatal("A did not enter fn") + } + + // Send D, then Detach. The eager intake reader sees D queued and the + // detach signal immediately, even though the runner is blocked on A. + if err := conn.SendText("D"); err != nil { + t.Fatalf("SendText D: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.SnapshotID == "" { + t.Fatal("expected pending snapshot ID") + } + + pending, err := store.GetSnapshot(context.Background(), out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot pending: %v", err) + } + if pending.Status != SnapshotStatusPending { + t.Errorf("pending snapshot status = %q, want pending", pending.Status) + } + if got := len(pending.State.Messages); got != 0 { + t.Errorf("pending state messages = %d, want 0 (live state not yet committed)", got) + } + + // No separate turn-end snapshot for A should have been written. + // (Walk the parent chain — pending should have no parent in this + // invocation since A was the first turn and got suspended.) + if pending.ParentID != "" { + t.Errorf("pending ParentID = %q, want empty (A was suspended)", pending.ParentID) + } + + close(release) + + final := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusComplete + }) + if final.State.Custom.Counter != 2 { + t.Errorf("final counter = %d, want 2 (A + D both processed)", final.State.Custom.Counter) + } + if got := len(final.State.Messages); got != 4 { + // 2 user (A, D) + 2 model replies = 4. + t.Errorf("final messages = %d, want 4", got) + } +} + +func TestSessionFlow_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { + // Run two normal turns first, then detach during a third (in-flight) + // turn. The pending snapshot must chain off the second turn's snapshot. + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + enter := make(chan struct{}, 4) + release := make(chan struct{}, 4) + + af := DefineSessionFlow(reg, "detachChainParent", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + enter <- struct{}{} + <-release + sess.AddMessages(ai.NewModelTextMessage("ok")) + return nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + + // Background drainer. + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + + // Run two normal turns. + for i := 0; i < 2; i++ { + release <- struct{}{} // pre-load release so this turn's fn doesn't block + if err := conn.SendText(fmt.Sprintf("sync-%d", i)); err != nil { + t.Fatalf("SendText: %v", err) + } + <-enter + } + // Brief settle so the second turn-end snapshot lands before detach. + time.Sleep(20 * time.Millisecond) + // Drain enter signal if buffered. + for len(enter) > 0 { + <-enter + } + + // Now start a third turn but DON'T release it — the third turn is + // in-flight when detach lands. + if err := conn.SendText("inflight"); err != nil { + t.Fatalf("SendText inflight: %v", err) + } + <-enter // third turn entered fn + + // Send the queued input and detach. + if err := conn.SendText("detach-msg"); err != nil { + t.Fatalf("SendText detach-msg: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + pending, err := store.GetSnapshot(context.Background(), out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if pending.ParentID == "" { + t.Error("pending ParentID empty; expected parent = last sync turn snapshot") + } + + // Release remaining turns and let finalize run. + close(release) + waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusComplete + }) +} + +func TestSessionFlow_Detach_RequiresStore(t *testing.T) { + reg := newTestRegistry(t) + + af := DefineSessionFlow(reg, "detachNoStore", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + return nil + }) + }, + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach send: %v", err) + } + conn.Close() + + _, err = conn.Output() + if err == nil { + t.Fatal("expected error when detaching without a session store") + } + if !strings.Contains(err.Error(), "detach requires a session store") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestSessionFlow_Detach_PendingThenComplete(t *testing.T) { + // Client detaches mid-flow; flow finishes naturally; pending snapshot + // flips to status=complete with the full session state. + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + release := make(chan struct{}) + entered := make(chan struct{}) + + af := DefineSessionFlow(reg, "detachComplete", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + select { + case entered <- struct{}{}: + case <-ctx.Done(): + } + <-release + sess.AddMessages(ai.NewModelTextMessage("finished")) + sess.UpdateCustom(func(s testState) testState { + s.Counter = 42 + return s + }) + return nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + + // Drain chunks so the responder isn't blocked. + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + + if err := conn.SendText("go"); err != nil { + t.Fatalf("SendText: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + + select { + case <-entered: + case <-time.After(2 * time.Second): + t.Fatal("flow did not enter work phase") + } + + // Output returns the pending snapshot ID immediately; the snapshot + // itself should be in the store with status=pending. + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.SnapshotID == "" { + t.Fatal("expected snapshot ID after detach") + } + + pending, err := store.GetSnapshot(context.Background(), out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot pending: %v", err) + } + if pending == nil { + t.Fatal("pending snapshot not written") + } + if pending.Status != SnapshotStatusPending { + t.Errorf("expected status=%q, got %q", SnapshotStatusPending, pending.Status) + } + if got := len(pending.State.Messages); got != 0 { + t.Errorf("pending snapshot should not carry message history, got %d messages", got) + } + + // Release; finalizer rewrites the snapshot with the terminal state. + close(release) + + finalSnap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusComplete + }) + if finalSnap.State.Custom.Counter != 42 { + t.Errorf("expected counter=42 in final snapshot, got %d", finalSnap.State.Custom.Counter) + } + if got := len(finalSnap.State.Messages); got < 2 { + t.Errorf("expected at least 2 messages in final snapshot, got %d", got) + } +} + +func TestSessionFlow_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { + // SendArtifact must behave the same way regardless of whether detach + // has landed: the artifact is added to the session and shows up in + // the finalized snapshot's state. The wire forward is the only thing + // detach suppresses, so flow authors don't need to branch on detach. + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + detached := make(chan struct{}) + release := make(chan struct{}) + + af := DefineSessionFlow(reg, "detachArtifact", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + resp.SendArtifact(&Artifact{ + Name: "before.txt", + Parts: []*ai.Part{ai.NewTextPart("pre-detach")}, + }) + select { + case <-detached: + case <-ctx.Done(): + return ctx.Err() + } + resp.SendArtifact(&Artifact{ + Name: "after.txt", + Parts: []*ai.Part{ai.NewTextPart("post-detach")}, + }) + <-release + return nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + + if err := conn.SendText("go"); err != nil { + t.Fatalf("SendText: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.SnapshotID == "" { + t.Fatal("expected pending snapshot ID") + } + + close(detached) + close(release) + + final := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusComplete + }) + + names := make(map[string]bool, len(final.State.Artifacts)) + for _, a := range final.State.Artifacts { + names[a.Name] = true + } + if !names["before.txt"] { + t.Errorf("pre-detach artifact missing from final snapshot: %v", final.State.Artifacts) + } + if !names["after.txt"] { + t.Errorf("post-detach artifact missing from final snapshot: %v", final.State.Artifacts) + } +} + +func TestSessionFlow_Detach_FlowErrorsBecomesError(t *testing.T) { + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + release := make(chan struct{}) + entered := make(chan struct{}) + boom := errors.New("kaboom") + + af := DefineSessionFlow(reg, "detachErr", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + select { + case entered <- struct{}{}: + case <-time.After(time.Second): + } + <-release + return boom + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + + if err := conn.SendText("go"); err != nil { + t.Fatalf("SendText: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + <-entered + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.SnapshotID == "" { + t.Fatal("expected pending snapshot ID") + } + + close(release) + + snap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusError + }) + if !strings.Contains(snap.Error, "kaboom") { + t.Errorf("expected snapshot.Error to contain %q, got %q", "kaboom", snap.Error) + } + + // Resuming from an errored detached snapshot is rejected. + _, err = af.RunText(context.Background(), "retry", WithSnapshotID[testState](out.SnapshotID)) + if err == nil { + t.Fatal("expected error when resuming from errored snapshot") + } + if !strings.Contains(err.Error(), "kaboom") { + t.Errorf("unexpected resume error: %v", err) + } +} + +func TestSessionFlow_Detach_AbortSnapshotStopsFlow(t *testing.T) { + // Client detaches, then calls AbortSnapshot. The store's status + // subscriber notifies the runtime, which cancels the work context, and + // the finalizer rewrites the snapshot with status=canceled. + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + entered := make(chan struct{}) + + af := DefineSessionFlow(reg, "detachAbort", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + select { + case entered <- struct{}{}: + case <-time.After(time.Second): + } + <-ctx.Done() + return ctx.Err() + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + + if err := conn.SendText("go"); err != nil { + t.Fatalf("SendText: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + <-entered + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.SnapshotID == "" { + t.Fatal("expected pending snapshot ID") + } + + // Abort via the store. The local caller already has the store + // reference from WithSessionStore. + meta, err := store.AbortSnapshot(context.Background(), out.SnapshotID) + if err != nil { + t.Fatalf("AbortSnapshot: %v", err) + } + if meta.Status != SnapshotStatusCanceled { + t.Errorf("AbortSnapshot status = %q, want canceled", meta.Status) + } + + // The subscriber wakes the runtime, cancels work, and the finalizer + // rewrites the snapshot with the canceled status. + finalSnap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusCanceled && s.UpdatedAt.After(s.CreatedAt) + }) + if finalSnap.State.Custom.Counter != 0 { + // The flow only blocked on ctx — no state mutation expected. + t.Errorf("unexpected counter value in canceled snapshot: %d", finalSnap.State.Custom.Counter) + } +} + +func TestSessionFlow_Detach_NormalCompletionStillEmitsTurnEnd(t *testing.T) { + // Sanity: a non-detached invocation against a store-backed flow still + // behaves like a synchronous flow (turn-end snapshots, no pending row). + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineSessionFlow(reg, "syncStillWorks", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("ok")) + return nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText: %v", err) + } + + var turnEndID string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive: %v", err) + } + if chunk.TurnEnd != nil { + turnEndID = chunk.TurnEnd.SnapshotID + break + } + } + if turnEndID == "" { + t.Fatal("expected snapshot ID on TurnEnd chunk") + } + conn.Close() + if _, err := conn.Output(); err != nil { + t.Fatalf("Output: %v", err) + } + + snap, err := store.GetSnapshot(context.Background(), turnEndID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap.Status != SnapshotStatusComplete { + t.Errorf("turn-end snapshot status = %q, want complete", snap.Status) + } + if snap.Event != SnapshotEventTurnEnd { + t.Errorf("turn-end snapshot event = %q, want %q", snap.Event, SnapshotEventTurnEnd) + } +} + +func TestSessionFlow_Detach_ClientDisconnectBeforeDetachCancels(t *testing.T) { + // Without detach, a client cancel still cancels the work — this is + // the regression guard for "until detach=true is called, this is a + // normal HTTP/WS connection that cancels on close." + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + entered := make(chan struct{}) + exited := make(chan error, 1) + + af := DefineSessionFlow(reg, "syncCancel", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + select { + case entered <- struct{}{}: + case <-ctx.Done(): + } + <-ctx.Done() + return ctx.Err() + }) + exited <- err + return nil, err + }, + WithSessionStore(store), + ) + + ctx, cancel := context.WithCancel(context.Background()) + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + + if err := conn.SendText("go"); err != nil { + t.Fatalf("SendText: %v", err) + } + <-entered + cancel() + + select { + case fnErr := <-exited: + if fnErr == nil { + t.Error("expected fn to exit with ctx error after client cancel") + } + case <-time.After(2 * time.Second): + t.Fatal("fn did not exit after client cancel") + } +} + +func TestSessionFlow_ResumeFromErrorSnapshot_Rejected(t *testing.T) { + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + erroredID := "errored-456" + if _, err := store.SaveSnapshot(context.Background(), erroredID, + func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{ + Event: SnapshotEventInvocationEnd, + Status: SnapshotStatusError, + Error: "underlying failure", + State: SessionState[testState]{}, + }, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + + af := DefineSessionFlow(reg, "resumeErrored", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, nil + }, + WithSessionStore(store), + ) + + _, err := af.RunText(context.Background(), "hi", WithSnapshotID[testState](erroredID)) + if err == nil { + t.Fatal("expected error when resuming from errored snapshot") + } + if !strings.Contains(err.Error(), "underlying failure") { + t.Errorf("expected error to surface underlying failure, got: %v", err) + } +} + +func TestSessionFlow_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + // Transform that scrubs a specific word from all messages. + transform := func(_ context.Context, s SessionState[testState]) SessionState[testState] { + for _, msg := range s.Messages { + for _, p := range msg.Content { + if p.Text != "" { + p.Text = strings.ReplaceAll(p.Text, "secret", "[REDACTED]") + } + } + } + return s + } + + af := DefineSessionFlow(reg, "transformedFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("the secret is out")) + return nil + }) + }, + WithSessionStore(store), + WithStateTransform[testState](transform), + ) + + ctx := context.Background() + out, err := af.RunText(ctx, "tell me the secret") + if err != nil { + t.Fatalf("RunText: %v", err) + } + + // Transform is action-layer behavior: invoke the registered action + // directly the way a non-Go client would. + action := core.ResolveActionFor[*GetSnapshotRequest, *GetSnapshotResponse[testState], struct{}, struct{}]( + reg, api.ActionTypeAgentSnapshot, "transformedFlow") + if action == nil { + t.Fatal("getSnapshot action not registered") + } + resp, err := action.Run(ctx, &GetSnapshotRequest{SnapshotID: out.SnapshotID}, nil) + if err != nil { + t.Fatalf("getSnapshot action: %v", err) + } + if resp.SnapshotID != out.SnapshotID { + t.Errorf("SnapshotID mismatch: got %q want %q", resp.SnapshotID, out.SnapshotID) + } + if resp.Status != SnapshotStatusComplete { + t.Errorf("expected status=complete, got %q", resp.Status) + } + if resp.State == nil { + t.Fatal("expected state in response") + } + // Both messages should be redacted: user message (from input) and model reply. + for i, msg := range resp.State.Messages { + for _, p := range msg.Content { + if strings.Contains(p.Text, "secret") { + t.Errorf("message %d still contains 'secret': %q", i, p.Text) + } + } + } + + // The stored snapshot must remain untransformed so the flow can be + // resumed faithfully. + stored, err := store.GetSnapshot(ctx, out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot direct: %v", err) + } + foundRaw := false + for _, msg := range stored.State.Messages { + for _, p := range msg.Content { + if strings.Contains(p.Text, "secret") { + foundRaw = true + } + } + } + if !foundRaw { + t.Error("expected stored snapshot to retain the original 'secret' text") + } +} + +func TestInMemorySessionStore_GetSnapshot_NotFound(t *testing.T) { + store := NewInMemorySessionStore[testState]() + + snap, err := store.GetSnapshot(context.Background(), "nope") + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap != nil { + t.Errorf("expected nil for missing snapshot, got %+v", snap) + } +} + +func TestSessionFlow_GetSnapshotAction_NoStore(t *testing.T) { + reg := newTestRegistry(t) + + DefineSessionFlow(reg, "noStoreFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, nil + }, + ) + + // Action remains registered even without a store; it returns + // FAILED_PRECONDITION when invoked. + action := core.ResolveActionFor[*GetSnapshotRequest, *GetSnapshotResponse[testState], struct{}, struct{}]( + reg, api.ActionTypeAgentSnapshot, "noStoreFlow") + if action == nil { + t.Fatal("getSnapshot action should be registered even without a store") + } + _, err := action.Run(context.Background(), &GetSnapshotRequest{SnapshotID: "any"}, nil) + if err == nil { + t.Fatal("expected error when store is not configured") + } + if !strings.Contains(err.Error(), "no session store configured") { + t.Errorf("unexpected error: %v", err) + } +} + +// minimalStore is a SessionStore that does NOT implement SnapshotAborter. +// Used to verify the abort action stays unregistered for stores that +// lack the capability. +type minimalStore[State any] struct{} + +func (minimalStore[State]) GetSnapshot(context.Context, string) (*SessionSnapshot[State], error) { + return nil, nil +} +func (minimalStore[State]) SaveSnapshot( + context.Context, string, + func(*SessionSnapshot[State]) (*SessionSnapshot[State], error), +) (*SessionSnapshot[State], error) { + return nil, nil +} + +func TestSessionFlow_AgentMetadata(t *testing.T) { + // Verify the metadata["agent"] payload on the flow's action descriptor + // correctly reports stateManagement and abortable for each combination + // of store capabilities. + noopFn := func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, nil + } + + cases := []struct { + name string + define func(reg api.Registry, flowName string) + wantMgmt AgentMetadataStateManagement + wantAbortab bool + }{ + { + name: "no store → client-managed, not abortable", + define: func(reg api.Registry, flowName string) { + DefineSessionFlow(reg, flowName, noopFn) + }, + wantMgmt: AgentMetadataStateManagementClient, + wantAbortab: false, + }, + { + name: "store missing abort capabilities → server-managed, not abortable", + define: func(reg api.Registry, flowName string) { + DefineSessionFlow(reg, flowName, noopFn, + WithSessionStore[testState](minimalStore[testState]{})) + }, + wantMgmt: AgentMetadataStateManagementServer, + wantAbortab: false, + }, + { + name: "store with full capabilities → server-managed, abortable", + define: func(reg api.Registry, flowName string) { + DefineSessionFlow(reg, flowName, noopFn, + WithSessionStore(NewInMemorySessionStore[testState]())) + }, + wantMgmt: AgentMetadataStateManagementServer, + wantAbortab: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + reg := newTestRegistry(t) + flowName := "metaFlow" + tc.define(reg, flowName) + + act := core.ResolveActionFor[*SessionFlowInit[testState], *SessionFlowOutput[testState], *SessionFlowStreamChunk[testStatus], *SessionFlowInput]( + reg, api.ActionTypeFlow, flowName) + if act == nil { + t.Fatal("flow action not registered") + } + desc := act.Desc() + raw, ok := desc.Metadata["agent"] + if !ok { + t.Fatalf("metadata[\"agent\"] missing; got metadata = %+v", desc.Metadata) + } + meta, ok := raw.(AgentMetadata) + if !ok { + t.Fatalf("metadata[\"agent\"] type = %T, want AgentMetadata", raw) + } + if meta.StateManagement != tc.wantMgmt { + t.Errorf("stateManagement = %q, want %q", meta.StateManagement, tc.wantMgmt) + } + if meta.Abortable != tc.wantAbortab { + t.Errorf("abortable = %v, want %v", meta.Abortable, tc.wantAbortab) + } + }) + } +} + +func TestSessionFlow_AbortAction_GatedOnCapabilities(t *testing.T) { + // Verify the abort companion action is only registered when the + // store implements SnapshotAborter. The getSnapshot action is + // registered regardless. + t.Run("aborter capability → both registered", func(t *testing.T) { + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() // implements SnapshotAborter + DefineSessionFlow(reg, "fullCaps", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, nil + }, + WithSessionStore(store), + ) + getAction := core.ResolveActionFor[*GetSnapshotRequest, *GetSnapshotResponse[testState], struct{}, struct{}]( + reg, api.ActionTypeAgentSnapshot, "fullCaps") + if getAction == nil { + t.Error("getSnapshot action should be registered") + } + abortAction := core.ResolveActionFor[*AbortSnapshotRequest, *AbortSnapshotResponse, struct{}, struct{}]( + reg, api.ActionTypeAgentAbort, "fullCaps") + if abortAction == nil { + t.Error("abortSnapshot action should be registered when store implements SnapshotAborter") + } + }) + + t.Run("no aborter capability → abort not registered", func(t *testing.T) { + reg := newTestRegistry(t) + DefineSessionFlow(reg, "minCaps", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, nil + }, + WithSessionStore[testState](minimalStore[testState]{}), + ) + getAction := core.ResolveActionFor[*GetSnapshotRequest, *GetSnapshotResponse[testState], struct{}, struct{}]( + reg, api.ActionTypeAgentSnapshot, "minCaps") + if getAction == nil { + t.Error("getSnapshot action should be registered even when store lacks SnapshotAborter") + } + abortAction := core.ResolveActionFor[*AbortSnapshotRequest, *AbortSnapshotResponse, struct{}, struct{}]( + reg, api.ActionTypeAgentAbort, "minCaps") + if abortAction != nil { + t.Error("abortSnapshot action should NOT be registered when store lacks SnapshotAborter") + } + }) +} + +func TestSessionFlow_StateTransform_ClientManagedState(t *testing.T) { + reg := newTestRegistry(t) + + // Client-managed state: transform should be applied to SessionFlowOutput.State. + transform := func(_ context.Context, s SessionState[testState]) SessionState[testState] { + // Zero out the counter to demonstrate the transform is applied. + s.Custom.Counter = -1 + return s + } + + af := DefineSessionFlow(reg, "clientXformFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess.UpdateCustom(func(s testState) testState { + s.Counter = 7 + return s + }) + return nil + }) + }, + WithStateTransform[testState](transform), + ) + + out, err := af.RunText(context.Background(), "go") + if err != nil { + t.Fatalf("RunText: %v", err) + } + if out.State == nil { + t.Fatal("expected client-managed state in output") + } + if out.State.Custom.Counter != -1 { + t.Errorf("expected transformed counter=-1, got %d", out.State.Custom.Counter) + } +} + +func TestSessionFlow_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { + // End-to-end: run a flow that the client detaches from, let it + // finalize, then resume from its snapshot as if reconnecting later. + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineSessionFlow(reg, "resumeDetachedFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSessionStore(store), + ) + + ctx := context.Background() + + // First invocation: detach to write a pending snapshot, then wait + // for finalize. + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + if err := conn.SendText("turn 1"); err != nil { + t.Fatalf("SendText: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + first, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + finalSnap := waitForSnapshot(t, store, first.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusComplete + }) + if finalSnap.State.Custom.Counter != 1 { + t.Fatalf("expected counter=1 in finalized snapshot, got %d", finalSnap.State.Custom.Counter) + } + + // Resume from the finalized snapshot. + second, err := af.RunText(ctx, "turn 2", WithSnapshotID[testState](first.SnapshotID)) + if err != nil { + t.Fatalf("resume RunText: %v", err) + } + + snap, err := store.GetSnapshot(ctx, second.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap.State.Custom.Counter != 2 { + t.Errorf("expected counter=2 after resume, got %d", snap.State.Custom.Counter) + } +} + +func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { + ctx := context.Background() + store := NewInMemorySessionStore[testState]() + + // Abort on missing snapshot returns nil metadata, no error. + if meta, err := store.AbortSnapshot(ctx, "nope"); err != nil || meta != nil { + t.Fatalf("AbortSnapshot(missing) = %+v, %v; want nil, nil", meta, err) + } + + // Pending → canceled, UpdatedAt advances. + pending, err := store.SaveSnapshot(ctx, "snap-cas", + func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{ + Event: SnapshotEventDetach, + Status: SnapshotStatusPending, + }, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + time.Sleep(time.Millisecond) // ensure measurable UpdatedAt delta + meta, err := store.AbortSnapshot(ctx, "snap-cas") + if err != nil { + t.Fatalf("AbortSnapshot: %v", err) + } + if meta.Status != SnapshotStatusCanceled { + t.Errorf("status after first abort = %q, want canceled", meta.Status) + } + if !meta.UpdatedAt.After(pending.UpdatedAt) { + t.Errorf("UpdatedAt did not advance: %v vs %v", meta.UpdatedAt, pending.UpdatedAt) + } + + // Idempotent: second abort returns canceled, no error, no further mutation. + firstUpdate := meta.UpdatedAt + meta2, err := store.AbortSnapshot(ctx, "snap-cas") + if err != nil { + t.Fatalf("AbortSnapshot (second): %v", err) + } + if meta2.Status != SnapshotStatusCanceled { + t.Errorf("status after second abort = %q, want canceled", meta2.Status) + } + if !meta2.UpdatedAt.Equal(firstUpdate) { + t.Errorf("UpdatedAt advanced on idempotent abort: %v vs %v", meta2.UpdatedAt, firstUpdate) + } + + // Abort on terminal status is a no-op that returns the existing status. + if _, err := store.SaveSnapshot(ctx, "snap-complete", + func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{ + Event: SnapshotEventTurnEnd, + Status: SnapshotStatusComplete, + }, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + meta3, err := store.AbortSnapshot(ctx, "snap-complete") + if err != nil { + t.Fatalf("AbortSnapshot on complete: %v", err) + } + if meta3.Status != SnapshotStatusComplete { + t.Errorf("abort on complete returned status=%q, want complete", meta3.Status) + } +} + +func TestSessionFlow_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { + // An abort that lands while fn is still running but does not actually + // stop fn (because fn does not observe ctx) must still result in + // status=canceled — the finalizer must not clobber canceled with + // complete. The subscriber observes the status flip and the finalizer + // reads the resulting flag. + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + fnRelease := make(chan struct{}) + entered := make(chan struct{}) + + af := DefineSessionFlow(reg, "raceFinalize", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + select { + case entered <- struct{}{}: + case <-time.After(time.Second): + } + <-fnRelease + // Return cleanly without observing ctx. Without the + // subscriber/recheck, this would land status=complete and + // clobber the abort. + return nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + + if err := conn.SendText("go"); err != nil { + t.Fatalf("SendText: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + <-entered + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + + // Externally abort before releasing fn. + if _, err := store.AbortSnapshot(context.Background(), out.SnapshotID); err != nil { + t.Fatalf("AbortSnapshot: %v", err) + } + + close(fnRelease) + + finalSnap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusCanceled || s.Status == SnapshotStatusComplete + }) + if finalSnap.Status != SnapshotStatusCanceled { + t.Errorf("finalize clobbered canceled with %q", finalSnap.Status) + } +} + +func TestInMemorySessionStore_OnSnapshotStatusChange(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + store := NewInMemorySessionStore[testState]() + + // Subscribe to a missing snapshot: channel returns immediately closed + // without yielding a value. + missing := store.OnSnapshotStatusChange(ctx, "nope") + if _, ok := <-missing; ok { + t.Errorf("expected channel for missing snapshot to be closed without a value") + } + + // Persist a pending snapshot so subsequent subscribers get an initial + // value plus updates on each status flip. + if _, err := store.SaveSnapshot(ctx, "snap-sub", + func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{ + Event: SnapshotEventDetach, + Status: SnapshotStatusPending, + }, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + + subCtx, subCancel := context.WithCancel(ctx) + defer subCancel() + statusCh := store.OnSnapshotStatusChange(subCtx, "snap-sub") + + // Initial value reflects current status. + select { + case status, ok := <-statusCh: + if !ok { + t.Fatal("channel closed before initial status") + } + if status != SnapshotStatusPending { + t.Errorf("initial status = %q, want pending", status) + } + case <-time.After(time.Second): + t.Fatal("did not receive initial status") + } + + // Abort flips status; subscriber observes canceled. + if _, err := store.AbortSnapshot(ctx, "snap-sub"); err != nil { + t.Fatalf("AbortSnapshot: %v", err) + } + select { + case status, ok := <-statusCh: + if !ok { + t.Fatal("channel closed before abort notification") + } + if status != SnapshotStatusCanceled { + t.Errorf("status notification = %q, want canceled", status) + } + case <-time.After(time.Second): + t.Fatal("did not receive abort notification") + } + + // Cancelling the subscription closes the channel. + subCancel() + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + _, ok := <-statusCh + if !ok { + return + } + } + t.Fatal("channel did not close after subscription ctx cancel") +} + +func TestSessionFlow_AbortSnapshot_NoOpOnTerminal(t *testing.T) { + // Calling AbortSnapshot on an already-terminal snapshot is a no-op + // that returns the existing status. + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineSessionFlow(reg, "abortNoop", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + return nil + }) + }, + WithSessionStore(store), + ) + + ctx := context.Background() + out, err := af.RunText(ctx, "hi") + if err != nil { + t.Fatalf("RunText: %v", err) + } + + resp, err := store.AbortSnapshot(ctx, out.SnapshotID) + if err != nil { + t.Fatalf("AbortSnapshot: %v", err) + } + if resp.Status != SnapshotStatusComplete { + t.Errorf("expected status=%q (existing terminal), got %q", SnapshotStatusComplete, resp.Status) + } + + // Confirm the store snapshot was not flipped. + snap, err := store.GetSnapshot(ctx, out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap.Status != SnapshotStatusComplete { + t.Errorf("snapshot status = %q after abort-on-terminal, want complete", snap.Status) + } +} diff --git a/go/core/action.go b/go/core/action.go index f0cc1a6d63..52fcf8307a 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -582,8 +582,21 @@ func (c *BidiConnection[StreamIn, StreamOut, Out]) Receive() iter.Seq2[StreamOut } // Output returns the final output after the action completes. -// Blocks until done or context cancelled. +// Blocks until done or context cancelled. If the action has finished, its +// actual output is returned even when the context was cancelled concurrently. func (c *BidiConnection[StreamIn, StreamOut, Out]) Output() (Out, error) { + // Fast path: if the action has already finished, return its output + // rather than racing with ctx.Done. This matters for callers that + // observe a completed action just after cancelling ctx (e.g., session + // flows backgrounded on client disconnect). + select { + case <-c.doneCh: + c.mu.Lock() + defer c.mu.Unlock() + return c.output, c.err + default: + } + select { case <-c.doneCh: c.mu.Lock() diff --git a/go/core/api/action.go b/go/core/api/action.go index 0818c527e6..14020d3a5b 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -62,6 +62,8 @@ const ( ActionTypeToolV2 ActionType = "tool.v2" ActionTypeUtil ActionType = "util" ActionTypeCustom ActionType = "custom" + ActionTypeAgentSnapshot ActionType = "agent-snapshot" + ActionTypeAgentAbort ActionType = "agent-abort" ActionTypeCheckOperation ActionType = "check-operation" ActionTypeCancelOperation ActionType = "cancel-operation" ActionTypeSessionFlow ActionType = "session-flow" diff --git a/go/core/flow.go b/go/core/flow.go index e10fb34344..35d7c13766 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -189,3 +189,17 @@ func FlowNameFromContext(ctx context.Context) string { } return "" } + +// WithFlowContext attaches flow-context metadata to ctx so that [Run] and +// [FlowNameFromContext] work from within. Use it when wiring a custom +// flow-like action (e.g. via [NewBidiAction] / [DefineBidiAction]) that +// should behave like a flow from the user's perspective — letting them +// call [Run] for sub-step tracking and see the flow name in spans — +// without going through [NewBidiFlow] / [DefineBidiFlow]. +// +// The Define*Flow constructors call this internally; direct callers +// only need it when bypassing those constructors to set custom +// [ActionOptions]. +func WithFlowContext(ctx context.Context, flowName string) context.Context { + return flowContextKey.NewContext(ctx, &flowContext{flowName: flowName}) +} diff --git a/go/core/schemas.config b/go/core/schemas.config index 873e86098c..6c627f02a6 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1203,6 +1203,16 @@ SessionFlowInput doc SessionFlowInput is the input sent to an session flow during a conversation turn. . +SessionFlowInput.detach doc +Detach signals that the client wishes to disconnect after this input is +accepted. The server writes a single pending snapshot (with empty +state), returns [SessionFlowOutput] with that snapshot ID, and +continues processing any already-buffered inputs in a background +context. The pending snapshot is finalized with the cumulative final +state once all queued inputs are processed (or the snapshot is +cancelled via cancelSnapshot). +. + SessionFlowInput.messages type []*ai.Message SessionFlowInput.messages doc Messages contains the user's input for this turn. @@ -1341,7 +1351,7 @@ snapshot was persisted. TurnEnd.snapshotId doc SnapshotID is the ID of the snapshot persisted at the end of this turn. Empty if no snapshot was created (callback returned false or no store -configured). +configured, or snapshots were suspended after detach). . # ---------------------------------------------------------------------------- @@ -1394,6 +1404,72 @@ SnapshotEventInvocationEnd doc InvocationEnd indicates the snapshot was triggered at the end of the invocation. . +SnapshotEventDetach doc +Detach indicates the snapshot was created when the client detached the +invocation and the flow continues in the background. The snapshot is +initially written with [SnapshotStatusPending] and rewritten with a +terminal status once the background work finishes. +. + +# ---------------------------------------------------------------------------- +# Snapshot lifecycle types (hand-written in go/ai/exp/session.go) +# ---------------------------------------------------------------------------- + +# SnapshotStatus enum and the persisted snapshot envelope are hand-written +# alongside the Session type. The companion-action request/response types +# are also hand-written because they reference SessionState[State] with a +# Go type parameter, which the generator does not support. +SnapshotStatus omit +SessionSnapshot omit +SnapshotMetadata omit +GetSnapshotRequest omit +GetSnapshotResponse omit +AbortSnapshotRequest omit +AbortSnapshotResponse omit + +# ---------------------------------------------------------------------------- +# AgentMetadata +# ---------------------------------------------------------------------------- + +AgentMetadata pkg ai/exp + +AgentMetadata doc +AgentMetadata is the value placed under metadata["agent"] on a session +flow's action descriptor. It exposes capability information so the Dev +UI and other reflective callers can render the right surface (e.g. +hide the Abort button when the configured store doesn't support it) +without round-tripping through the reflection API. +. + +AgentMetadata.stateManagement doc +StateManagement reports who owns session state. +. + +AgentMetadata.abortable doc +Abortable reports whether the agent's invocations can be aborted +(true when the store implements [SnapshotAborter]). +. + +AgentMetadataStateManagement pkg ai/exp + +AgentMetadataStateManagement doc +AgentMetadataStateManagement enumerates who owns session state for an +agent: "server" (a [SessionStore] is configured and snapshots are +persisted server-side) or "client" (no store; state flows through +invocation init / output). +. + +AgentMetadataStateManagementServer doc +AgentMetadataStateManagementServer indicates the agent is wired with +a [SessionStore] and persists snapshots server-side. +. + +AgentMetadataStateManagementClient doc +AgentMetadataStateManagementClient indicates the agent has no store; +session state is client-managed and round-trips through invocation +init and output. +. + # ============================================================================ # REFLECTION V2 TYPES (generated into genkit package) # ============================================================================ diff --git a/go/genkit/gen.go b/go/genkit/gen.go index 5cd27b27cb..86b5f2d05b 100644 --- a/go/genkit/gen.go +++ b/go/genkit/gen.go @@ -18,7 +18,9 @@ package genkit -import "encoding/json" +import ( + "encoding/json" +) // ReflectionCancelActionParams is the payload for the "cancelAction" request // sent by the CLI manager to cancel a running action. diff --git a/go/genkit/reflection_test.go b/go/genkit/reflection_test.go index 9f7aafc3bf..8e1944574a 100644 --- a/go/genkit/reflection_test.go +++ b/go/genkit/reflection_test.go @@ -302,22 +302,24 @@ func TestEarlyTraceIDTransmission(t *testing.T) { tc := tracing.NewTestOnlyTelemetryClient() tracing.WriteTelemetryImmediate(tc) - actionStarted := make(chan struct{}) - actionCanProceed := make(chan struct{}) - - // Action that waits for permission to complete - this lets us check headers while it's running - core.DefineAction(g.reg, "test/slow", api.ActionTypeCustom, nil, nil, - func(ctx context.Context, input any) (any, error) { - close(actionStarted) // Signal we've started - <-actionCanProceed // Wait for test to say we can finish - return "completed", nil - }) - s := &reflectionServer{Server: &http.Server{}, activeActions: newActiveActionsMap()} ts := httptest.NewServer(serveMux(g, s)) defer ts.Close() t.Run("headers arrive before body completes", func(t *testing.T) { + // Subtest-local channels and action so the server goroutine for + // this subtest doesn't read variables that the next subtest is + // about to reassign. The previous shared-state setup raced under + // -race because t.Run only synchronizes with the subtest + // goroutine, not with the httptest server's request goroutine. + actionStarted := make(chan struct{}) + actionCanProceed := make(chan struct{}) + core.DefineAction(g.reg, "test/slow", api.ActionTypeCustom, nil, nil, + func(ctx context.Context, input any) (any, error) { + close(actionStarted) + <-actionCanProceed + return "completed", nil + }) // Channel to receive headers as soon as they arrive type headerResult struct { traceID string @@ -375,11 +377,10 @@ func TestEarlyTraceIDTransmission(t *testing.T) { // Backwards compatability t.Run("trace ID in headers matches body", func(t *testing.T) { - // Reset channels for this subtest - actionStarted = make(chan struct{}) - actionCanProceed = make(chan struct{}) - - // Re-register action for this subtest + // Subtest-local channels and action; see the comment on the + // previous subtest. + actionStarted := make(chan struct{}) + actionCanProceed := make(chan struct{}) core.DefineAction(g.reg, "test/slow2", api.ActionTypeCustom, nil, nil, func(ctx context.Context, input any) (any, error) { close(actionStarted) From ab9de2e8e0145d103f3eed5193a3b27ebcccca95 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 11 May 2026 18:35:50 -0500 Subject: [PATCH 062/141] refactor(go/exp): rename session flow to agent and unify definition paths (#5234) --- genkit-tools/common/src/types/agent.ts | 60 +- genkit-tools/genkit-schema.json | 173 ++- go/ai/exp/{session_flow.go => agent.go} | 1009 +++++++++-------- .../{session_flow_test.go => agent_test.go} | 538 +++++---- go/ai/exp/gen.go | 51 +- go/ai/exp/option.go | 78 +- go/ai/exp/session.go | 56 +- go/ai/option.go | 10 + go/core/action.go | 5 + go/core/api/action.go | 2 +- go/core/schemas.config | 125 +- go/genkit/genkit.go | 190 ++-- go/samples/custom-agent/main.go | 12 +- go/samples/prompt-agent/main.go | 18 +- .../genkit/src/genkit/_core/_typing.py | 89 +- 15 files changed, 1250 insertions(+), 1166 deletions(-) rename go/ai/exp/{session_flow.go => agent.go} (65%) rename go/ai/exp/{session_flow_test.go => agent_test.go} (83%) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 999cccadcd..5db3c7b774 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -79,19 +79,17 @@ export const SessionStateSchema = z.object({ custom: z.any().optional(), /** Named collections of parts produced during the conversation. */ artifacts: z.array(ArtifactSchema).optional(), - /** Input used for session flows that require input variables. */ - inputVariables: z.any().optional(), }); export type SessionState = z.infer; /** - * Zod schema for session flow input (per-turn). + * Zod schema for agent input (per-turn). */ -export const SessionFlowInputSchema = z.object({ +export const AgentInputSchema = z.object({ /** * Detach signals that the client wishes to disconnect after this input is * accepted. The server writes a single pending snapshot (with empty - * state), returns SessionFlowOutput with that snapshot ID, and continues + * state), returns AgentOutput with that snapshot ID, and continues * processing any already-buffered inputs in a background context. * Streamed chunks emitted after detach are not forwarded over the wire; * only the final cumulative state is captured when the snapshot is @@ -103,34 +101,34 @@ export const SessionFlowInputSchema = z.object({ /** Tool request parts to re-execute interrupted tools. */ toolRestarts: z.array(PartSchema).optional(), }); -export type SessionFlowInput = z.infer; +export type AgentInput = z.infer; /** - * Zod schema for session flow initialization. + * Zod schema for agent initialization. */ -export const SessionFlowInitSchema = z.object({ +export const AgentInitSchema = z.object({ /** Loads state from a persisted snapshot. Mutually exclusive with state. */ snapshotId: z.string().optional(), /** Direct state for the invocation. Mutually exclusive with snapshotId. */ state: SessionStateSchema.optional(), }); -export type SessionFlowInit = z.infer; +export type AgentInit = z.infer; /** - * Zod schema for session flow result. + * Zod schema for agent result. */ -export const SessionFlowResultSchema = z.object({ +export const AgentResultSchema = z.object({ /** Last model response message from the conversation. */ message: MessageSchema.optional(), /** Artifacts produced during the session. */ artifacts: z.array(ArtifactSchema).optional(), }); -export type SessionFlowResult = z.infer; +export type AgentResult = z.infer; /** - * Zod schema for session flow output. + * Zod schema for agent output. */ -export const SessionFlowOutputSchema = z.object({ +export const AgentOutputSchema = z.object({ /** ID of the snapshot created at the end of this invocation. */ snapshotId: z.string().optional(), /** Final conversation state (only when client-managed). */ @@ -140,10 +138,10 @@ export const SessionFlowOutputSchema = z.object({ /** Artifacts produced during the session. */ artifacts: z.array(ArtifactSchema).optional(), }); -export type SessionFlowOutput = z.infer; +export type AgentOutput = z.infer; /** - * Zod schema for the turn-end signal emitted by a session flow. + * Zod schema for the turn-end signal emitted by an agent. * * A TurnEnd value is emitted exactly once per turn, regardless of whether a * snapshot was persisted. Grouping all turn-end signals here lets callers @@ -161,9 +159,9 @@ export const TurnEndSchema = z.object({ export type TurnEnd = z.infer; /** - * Zod schema for session flow stream chunk. + * Zod schema for agent stream chunk. */ -export const SessionFlowStreamChunkSchema = z.object({ +export const AgentStreamChunkSchema = z.object({ /** Generation tokens from the model. */ modelChunk: ModelResponseChunkSchema.optional(), /** User-defined structured status information. */ @@ -171,15 +169,13 @@ export const SessionFlowStreamChunkSchema = z.object({ /** A newly produced artifact. */ artifact: ArtifactSchema.optional(), /** - * Non-null when the session flow has finished processing the current - * input. Groups all turn-end signals; the client should stop iterating and - * may send the next input. + * Non-null when the agent has finished processing the current input. + * Groups all turn-end signals; the client should stop iterating and may + * send the next input. */ turnEnd: TurnEndSchema.optional(), }); -export type SessionFlowStreamChunk = z.infer< - typeof SessionFlowStreamChunkSchema ->; +export type AgentStreamChunk = z.infer; /** * Zod schema for the metadata projection of a session snapshot. It exists @@ -219,9 +215,9 @@ export const SessionSnapshotSchema = SnapshotMetadataSchema.extend({ export type SessionSnapshot = z.infer; /** - * Zod schema for the input of a session flow's `getSnapshot` companion - * action. The action is registered at `{flowName}/getSnapshot` when the - * flow is defined. + * Zod schema for the input of an agent's `getSnapshot` companion action. + * The action is registered at `{agentName}/getSnapshot` when the agent + * is defined. */ export const GetSnapshotRequestSchema = z.object({ /** Identifies the snapshot to fetch. */ @@ -232,7 +228,7 @@ export type GetSnapshotRequest = z.infer; /** * Zod schema for the output of the `getSnapshot` companion action. It is a * client-facing view of the stored snapshot: identifying metadata plus the - * session state, with `WithSnapshotTransform` applied if configured. + * session state, with `WithStateTransform` applied if configured. */ export const GetSnapshotResponseSchema = z.object({ /** Echoes the requested snapshot ID. */ @@ -275,9 +271,7 @@ export const AbortSnapshotResponseSchema = z.object({ */ status: SnapshotStatusSchema.optional(), }); -export type AbortSnapshotResponse = z.infer< - typeof AbortSnapshotResponseSchema ->; +export type AbortSnapshotResponse = z.infer; /** * Who owns session state for an agent. @@ -294,8 +288,8 @@ export type AgentMetadataStateManagement = z.infer< /** * Zod schema for the agent capability metadata placed under - * `metadata.agent` on a session flow's action descriptor. Lets the Dev - * UI and other reflective callers render the right surface (e.g. hide + * `metadata.agent` on an agent's action descriptor. Lets the Dev UI + * and other reflective callers render the right surface (e.g. hide * the Abort button when the configured store doesn't support it) * without round-tripping through the reflection API. */ diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 607237e3a6..69da18df17 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -28,6 +28,39 @@ ], "additionalProperties": false }, + "AgentInit": { + "type": "object", + "properties": { + "snapshotId": { + "type": "string" + }, + "state": { + "$ref": "#/$defs/SessionState" + } + }, + "additionalProperties": false + }, + "AgentInput": { + "type": "object", + "properties": { + "detach": { + "type": "boolean" + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/$defs/Message" + } + }, + "toolRestarts": { + "type": "array", + "items": { + "$ref": "#/$defs/Part" + } + } + }, + "additionalProperties": false + }, "AgentMetadata": { "type": "object", "properties": { @@ -51,6 +84,58 @@ "client" ] }, + "AgentOutput": { + "type": "object", + "properties": { + "snapshotId": { + "type": "string" + }, + "state": { + "$ref": "#/$defs/SessionState" + }, + "message": { + "$ref": "#/$defs/Message" + }, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/$defs/Artifact" + } + } + }, + "additionalProperties": false + }, + "AgentResult": { + "type": "object", + "properties": { + "message": { + "$ref": "#/$defs/Message" + }, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/$defs/Artifact" + } + } + }, + "additionalProperties": false + }, + "AgentStreamChunk": { + "type": "object", + "properties": { + "modelChunk": { + "$ref": "#/$defs/ModelResponseChunk" + }, + "status": {}, + "artifact": { + "$ref": "#/$defs/Artifact" + }, + "turnEnd": { + "$ref": "#/$defs/TurnEnd" + } + }, + "additionalProperties": false + }, "Artifact": { "type": "object", "properties": { @@ -112,91 +197,6 @@ ], "additionalProperties": false }, - "SessionFlowInit": { - "type": "object", - "properties": { - "snapshotId": { - "type": "string" - }, - "state": { - "$ref": "#/$defs/SessionState" - } - }, - "additionalProperties": false - }, - "SessionFlowInput": { - "type": "object", - "properties": { - "detach": { - "type": "boolean" - }, - "messages": { - "type": "array", - "items": { - "$ref": "#/$defs/Message" - } - }, - "toolRestarts": { - "type": "array", - "items": { - "$ref": "#/$defs/Part" - } - } - }, - "additionalProperties": false - }, - "SessionFlowOutput": { - "type": "object", - "properties": { - "snapshotId": { - "type": "string" - }, - "state": { - "$ref": "#/$defs/SessionState" - }, - "message": { - "$ref": "#/$defs/Message" - }, - "artifacts": { - "type": "array", - "items": { - "$ref": "#/$defs/Artifact" - } - } - }, - "additionalProperties": false - }, - "SessionFlowResult": { - "type": "object", - "properties": { - "message": { - "$ref": "#/$defs/Message" - }, - "artifacts": { - "type": "array", - "items": { - "$ref": "#/$defs/Artifact" - } - } - }, - "additionalProperties": false - }, - "SessionFlowStreamChunk": { - "type": "object", - "properties": { - "modelChunk": { - "$ref": "#/$defs/ModelResponseChunk" - }, - "status": {}, - "artifact": { - "$ref": "#/$defs/Artifact" - }, - "turnEnd": { - "$ref": "#/$defs/TurnEnd" - } - }, - "additionalProperties": false - }, "SessionSnapshot": { "type": "object", "properties": { @@ -248,8 +248,7 @@ "items": { "$ref": "#/$defs/Artifact" } - }, - "inputVariables": {} + } }, "additionalProperties": false }, diff --git a/go/ai/exp/session_flow.go b/go/ai/exp/agent.go similarity index 65% rename from go/ai/exp/session_flow.go rename to go/ai/exp/agent.go index 5f3136b7f3..2616ac258b 100644 --- a/go/ai/exp/session_flow.go +++ b/go/ai/exp/agent.go @@ -24,6 +24,7 @@ import ( "context" "fmt" "iter" + "runtime/debug" "sync" "sync/atomic" @@ -32,39 +33,295 @@ import ( "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" - "github.com/firebase/genkit/go/internal/base" ) -// --- SessionFlow --- +// --- AgentSession --- -// SessionFlowFunc is the function signature for session flows. +// AgentSession extends Session with agent-runtime functionality: +// turn management, snapshot persistence, and input channel handling. +type AgentSession[State any] struct { + *Session[State] + + // InputCh is the channel that delivers per-turn inputs from the client. + // It is consumed automatically by [AgentSession.Run], but is exposed + // for advanced use cases that need direct access to the input stream + // (e.g., custom turn loops or fan-out patterns). + InputCh <-chan *AgentInput + // TurnIndex is the zero-based index of the current conversation turn. + // It is incremented automatically by [AgentSession.Run], but is exposed + // for advanced use cases that need to track or manipulate turn ordering + // directly. + TurnIndex int + + snapshotCallback SnapshotCallback[State] + onEndTurn func(ctx context.Context) + lastSnapshot *SessionSnapshot[State] + lastSnapshotVersion uint64 + collectTurnOutput func() any + + // intake is the source of truth for in-flight tracking, queue state, + // and suspended state. The session consults it via beginTurnEnd (in + // maybeSnapshot) so per-turn snapshot writes and detach captures + // cannot race over the same input. + intake *detachIntake +} + +// parentSnapshotID returns the ID of the most recent snapshot in this +// invocation (used to chain new snapshots via ParentID), or "" if no +// snapshot has been written yet. +func (s *AgentSession[State]) parentSnapshotID() string { + if s.lastSnapshot == nil { + return "" + } + return s.lastSnapshot.SnapshotID +} + +// Run loops over the input channel, calling fn for each turn. Each turn is +// wrapped in a trace span for observability. Input messages are automatically +// added to the session before fn is called. After fn returns successfully, a +// TurnEnd chunk is sent and a snapshot check is triggered. +func (s *AgentSession[State]) Run(ctx context.Context, fn func(ctx context.Context, input *AgentInput) error) error { + for input := range s.InputCh { + spanMeta := &tracing.SpanMetadata{ + Name: fmt.Sprintf("agent/turn/%d", s.TurnIndex), + Type: "flowStep", + Subtype: "flowStep", + } + _, err := tracing.RunInNewSpan(ctx, spanMeta, input, + func(ctx context.Context, input *AgentInput) (any, error) { + s.AddMessages(input.Messages...) + if err := fn(ctx, input); err != nil { + return nil, err + } + s.onEndTurn(ctx) + s.TurnIndex++ + if s.collectTurnOutput != nil { + return s.collectTurnOutput(), nil + } + return nil, nil + }, + ) + if err != nil { + return err + } + } + return nil +} + +// Result returns an [AgentResult] populated from the current session state: +// the last message in the conversation history and all artifacts. +// It is a convenience for custom agents that don't need to construct the +// result manually. +func (s *AgentSession[State]) Result() *AgentResult { + s.mu.RLock() + defer s.mu.RUnlock() + + result := &AgentResult{} + if msgs := s.state.Messages; len(msgs) > 0 { + result.Message = msgs[len(msgs)-1] + } + if len(s.state.Artifacts) > 0 { + arts := make([]*Artifact, len(s.state.Artifacts)) + copy(arts, s.state.Artifacts) + result.Artifacts = arts + } + return result +} + +// maybeSnapshot creates a snapshot if conditions are met (store configured, +// callback approves, state changed, detach has not suspended snapshots). +// Returns the snapshot ID or empty string. +// +// For turn-end events, the session asks the intake whether snapshots +// have been suspended (i.e. detach has landed). If so, the session skips +// the turn-end snapshot — the pending row already captures the +// invocation and a single finalize rewrite will record the cumulative +// state once the queued inputs drain. +func (s *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { + if event == SnapshotEventTurnEnd && s.intake != nil { + if suspended := s.intake.beginTurnEnd(); suspended { + return "" + } + } + + if s.store == nil { + return "" + } + + s.mu.RLock() + currentVersion := s.version + currentState := s.copyStateLocked() + s.mu.RUnlock() + + // Skip if state hasn't changed since the last snapshot. This avoids + // redundant snapshots, e.g. the invocation-end snapshot after a + // single-turn Run where the turn-end snapshot already captured the + // same state. + if s.lastSnapshot != nil && currentVersion == s.lastSnapshotVersion { + return "" + } + + if s.snapshotCallback != nil { + var prevState *SessionState[State] + if s.lastSnapshot != nil { + prevState = &s.lastSnapshot.State + } + if !s.snapshotCallback(ctx, &SnapshotContext[State]{ + State: ¤tState, + PrevState: prevState, + TurnIndex: s.TurnIndex, + Event: event, + }) { + return "" + } + } + + parentID := s.parentSnapshotID() + + saved, err := s.store.SaveSnapshot(ctx, "", + func(_ *SessionSnapshot[State]) (*SessionSnapshot[State], error) { + return &SessionSnapshot[State]{ + ParentID: parentID, + Event: event, + Status: SnapshotStatusComplete, + State: currentState, + }, nil + }) + if err != nil { + // Snapshot persistence is best-effort: a store failure must not + // kill the in-flight turn. Surface enough context in the log + // that the failure is diagnosable without the caller having to + // thread the error back up. + logger.FromContext(ctx).Error("agent: failed to save snapshot", + "parentId", parentID, + "event", event, + "err", err) + return "" + } + + s.lastSnapshot = saved + s.lastSnapshotVersion = currentVersion + return saved.SnapshotID +} + +// --- Responder --- + +// Responder is the output channel for an agent. Artifacts sent through +// it are automatically added to the session before being forwarded to the +// client. +type Responder[Stream any] chan<- *AgentStreamChunk[Stream] + +// SendModelChunk sends a generation chunk (token-level streaming). +func (r Responder[Stream]) SendModelChunk(chunk *ai.ModelResponseChunk) { + r <- &AgentStreamChunk[Stream]{ModelChunk: chunk} +} + +// SendStatus sends a user-defined status update. +func (r Responder[Stream]) SendStatus(status Stream) { + r <- &AgentStreamChunk[Stream]{Status: status} +} + +// SendArtifact sends an artifact to the stream and adds it to the session. +// If an artifact with the same name already exists in the session, it is +// replaced. The session-level side effect happens whether or not detach +// has landed; only the wire forward to the client is suppressed +// post-detach, when there is no longer a client to receive it. +func (r Responder[Stream]) SendArtifact(artifact *Artifact) { + r <- &AgentStreamChunk[Stream]{Artifact: artifact} +} + +// --- Agent --- + +// AgentFunc is the function signature for custom agents. // Type parameters: // - Stream: Type for status updates sent via the responder // - State: Type for user-defined state in snapshots -type SessionFlowFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *SessionRunner[State]) (*SessionFlowResult, error) +type AgentFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *AgentSession[State]) (*AgentResult, error) -// SessionFlow is a bidirectional streaming flow with automatic snapshot management. -type SessionFlow[Stream, State any] struct { - action *core.Action[*SessionFlowInit[State], *SessionFlowOutput[State], *SessionFlowStreamChunk[Stream], *SessionFlowInput] +// Agent is a bidirectional streaming agent with automatic snapshot management. +type Agent[Stream, State any] struct { + action *core.Action[*AgentInit[State], *AgentOutput[State], *AgentStreamChunk[Stream], *AgentInput] } -// DefineSessionFlow creates an SessionFlow with automatic snapshot -// management and registers it. The underlying action is created via -// [core.DefineBidiAction] (rather than [core.DefineBidiFlow]) so the +// DefineAgent defines an agent that wraps a prompt defined inline from the +// given options, and registers both under name. Each turn renders the prompt, +// appends conversation history, calls the model with streaming, and updates +// session state. +// +// opts is a mixed list of [github.com/firebase/genkit/go/ai.PromptOption] +// values (which configure the prompt) and [AgentOption] values (which +// configure the agent itself, e.g., [WithSessionStore]). +// +// State is phantom in the variadic, so it cannot be inferred. Specify [any] +// when no typed Custom state is needed; specify [Foo] when a +// [SessionStore[Foo]] is provided. A mismatch panics at definition time with +// a clear message. +// +// For an agent backed by an existing prompt, use [DefinePromptAgent]. For +// full control over the per-turn loop, use [DefineCustomAgent]. +func DefineAgent[State any]( + r api.Registry, + name string, + opts ...AgentDefineOption[State], +) *Agent[any, State] { + var promptOpts []ai.PromptOption + var agentOpts []AgentOption[State] + for _, opt := range opts { + if ao, ok := opt.(AgentOption[State]); ok { + agentOpts = append(agentOpts, ao) + continue + } + if po, ok := opt.(ai.PromptOption); ok { + promptOpts = append(promptOpts, po) + continue + } + panic(fmt.Sprintf("DefineAgent %q: option of type %T does not match agent State %T (likely a typed AgentOption with a different State than the one declared on DefineAgent)", name, opt, *new(State))) + } + + prompt := ai.DefinePrompt(r, name, promptOpts...) + return DefineCustomAgent(r, name, agentLoop[State](r, prompt, nil), agentOpts...) +} + +// DefinePromptAgent defines an agent backed by a prompt already registered +// with the registry (via [ai.DefinePrompt] or loaded from a .prompt file). +// The agent is registered under the same name as the prompt, sharing its +// namespace. +// +// defaultInput is used to render the prompt on every turn. PromptIn is +// captured for compile-time type checking on defaultInput; it is not +// propagated through the [Agent] type. +// +// For an agent that defines its prompt inline, use [DefineAgent]. For full +// control over the per-turn loop, use [DefineCustomAgent]. +func DefinePromptAgent[State, PromptIn any]( + r api.Registry, + promptName string, + defaultInput PromptIn, + opts ...AgentOption[State], +) *Agent[any, State] { + prompt := ai.LookupPrompt(r, promptName) + if prompt == nil { + panic(fmt.Sprintf("DefinePromptAgent: prompt %q not found", promptName)) + } + return DefineCustomAgent(r, promptName, agentLoop[State](r, prompt, defaultInput), opts...) +} + +// DefineCustomAgent defines an agent with full control over the conversation +// loop and registers it with the registry. The underlying action is created +// via [core.DefineBidiAction] (rather than [core.DefineBidiFlow]) so the // agent capability metadata can be set at construction time — actions -// must be immutable once registered. The flow-context wrapping that -// makes [core.Run] work inside fn is preserved via -// [core.WithFlowContext]. -func DefineSessionFlow[Stream, State any]( +// must be immutable once registered. The flow-context wrapping that makes +// [core.Run] work inside fn is preserved via [core.WithFlowContext]. +func DefineCustomAgent[Stream, State any]( r api.Registry, name string, - fn SessionFlowFunc[Stream, State], - opts ...SessionFlowOption[State], -) *SessionFlow[Stream, State] { - cfg := &sessionFlowOptions[State]{} + fn AgentFunc[Stream, State], + opts ...AgentOption[State], +) *Agent[Stream, State] { + cfg := &agentOptions[State]{} for _, opt := range opts { - if err := opt.applySessionFlow(cfg); err != nil { - panic(fmt.Errorf("DefineSessionFlow %q: %w", name, err)) + if err := opt.applyAgent(cfg); err != nil { + panic(fmt.Errorf("DefineCustomAgent %q: %w", name, err)) } } @@ -74,12 +331,12 @@ func DefineSessionFlow[Stream, State any]( }, func( ctx context.Context, - in *SessionFlowInit[State], - inCh <-chan *SessionFlowInput, - outCh chan<- *SessionFlowStreamChunk[Stream], - ) (*SessionFlowOutput[State], error) { + in *AgentInit[State], + inCh <-chan *AgentInput, + outCh chan<- *AgentStreamChunk[Stream], + ) (*AgentOutput[State], error) { ctx = core.WithFlowContext(ctx, name) - rt, err := newSessionFlowRuntime(ctx, name, cfg, in, inCh, outCh) + rt, err := newAgentRuntime(ctx, name, cfg, in, inCh, outCh) if err != nil { return nil, err } @@ -88,11 +345,11 @@ func DefineSessionFlow[Stream, State any]( registerSnapshotActions(r, name, cfg.store, cfg.transform) - return &SessionFlow[Stream, State]{action: action} + return &Agent[Stream, State]{action: action} } // agentMetadataFor derives the [AgentMetadata] value attached to the -// session flow's action descriptor under the "agent" key. [AgentMetadata] +// agent's action descriptor under the "agent" key. [AgentMetadata] // itself is generated from agent.ts; this constructor is hand-written // because it inspects the configured store's optional capabilities. func agentMetadataFor[State any](store SessionStore[State]) AgentMetadata { @@ -108,18 +365,18 @@ func agentMetadataFor[State any](store SessionStore[State]) AgentMetadata { } } -// --- sessionFlowRuntime --- +// --- agentRuntime --- -// sessionFlowRuntime owns the per-invocation wiring of a session flow: +// agentRuntime owns the per-invocation wiring of an agent: // session, runner, output router, input intake, and the goroutine that runs -// the user fn. Its methods implement the three terminal paths the flow can +// the user fn. Its methods implement the three terminal paths the agent can // take: detach, fn-completion, and client-cancel. -type sessionFlowRuntime[Stream, State any] struct { +type agentRuntime[Stream, State any] struct { name string - cfg *sessionFlowOptions[State] + cfg *agentOptions[State] session *Session[State] - runner *SessionRunner[State] + sess *AgentSession[State] router *chunkRouter[Stream, State] intake *detachIntake @@ -129,24 +386,24 @@ type sessionFlowRuntime[Stream, State any] struct { // fnDoneResult carries the user fn's return values across the goroutine // boundary that runs it. A named type keeps the channel signatures readable. type fnDoneResult[State any] struct { - result *SessionFlowResult + result *AgentResult err error } -func newSessionFlowRuntime[Stream, State any]( +func newAgentRuntime[Stream, State any]( ctx context.Context, name string, - cfg *sessionFlowOptions[State], - in *SessionFlowInit[State], - inCh <-chan *SessionFlowInput, - outCh chan<- *SessionFlowStreamChunk[Stream], -) (*sessionFlowRuntime[Stream, State], error) { + cfg *agentOptions[State], + in *AgentInit[State], + inCh <-chan *AgentInput, + outCh chan<- *AgentStreamChunk[Stream], +) (*agentRuntime[Stream, State], error) { session, parent, err := loadSession(ctx, in, cfg.store) if err != nil { return nil, err } - rt := &sessionFlowRuntime[Stream, State]{ + rt := &agentRuntime[Stream, State]{ name: name, cfg: cfg, session: session, @@ -155,38 +412,38 @@ func newSessionFlowRuntime[Stream, State any]( fnDone: make(chan fnDoneResult[State], 1), } - rt.runner = &SessionRunner[State]{ + rt.sess = &AgentSession[State]{ Session: session, InputCh: rt.intake.out(), snapshotCallback: cfg.callback, lastSnapshot: parent, intake: rt.intake, } - rt.runner.collectTurnOutput = func() any { return rt.router.collectTurnChunks() } - rt.runner.onEndTurn = rt.emitTurnEnd + rt.sess.collectTurnOutput = func() any { return rt.router.collectTurnChunks() } + rt.sess.onEndTurn = rt.emitTurnEnd return rt, nil } -// emitTurnEnd is called by the runner after each successful turn. It writes +// emitTurnEnd is called by the session after each successful turn. It writes // a turn-end snapshot (if applicable) and forwards the resulting [TurnEnd] // chunk through the router so clients see it on the output stream. -func (rt *sessionFlowRuntime[Stream, State]) emitTurnEnd(ctx context.Context) { - snapshotID := rt.runner.maybeSnapshot(ctx, SnapshotEventTurnEnd) - rt.router.send() <- &SessionFlowStreamChunk[Stream]{TurnEnd: &TurnEnd{ +func (rt *agentRuntime[Stream, State]) emitTurnEnd(ctx context.Context) { + snapshotID := rt.sess.maybeSnapshot(ctx, SnapshotEventTurnEnd) + rt.router.send() <- &AgentStreamChunk[Stream]{TurnEnd: &TurnEnd{ SnapshotID: snapshotID, }} } -// run drives the user fn to completion and returns the flow output. +// run drives the user fn to completion and returns the agent output. // // workCtx carries the session and is decoupled from clientCtx: pre-detach a // watcher mirrors clientCtx so a disconnect cancels the work; on detach the // watcher exits and the finalizer goroutine owns workCtx until fn returns. -func (rt *sessionFlowRuntime[Stream, State]) run( +func (rt *agentRuntime[Stream, State]) run( clientCtx context.Context, - fn SessionFlowFunc[Stream, State], -) (*SessionFlowOutput[State], error) { + fn AgentFunc[Stream, State], +) (*AgentOutput[State], error) { workCtx, cancelWork := context.WithCancel(context.WithoutCancel(clientCtx)) workCtx = NewSessionContext(workCtx, rt.session) @@ -204,8 +461,23 @@ func (rt *sessionFlowRuntime[Stream, State]) run( }() go func() { - result, err := fn(workCtx, rt.router.responder(), rt.runner) - rt.fnDone <- fnDoneResult[State]{result: result, err: err} + // Run fn under deferred panic recovery so a panic surfaces as + // an error rather than crashing the process or leaking the + // fnDone channel. + var ( + result *AgentResult + fnErr error + ) + func() { + defer func() { + if r := recover(); r != nil { + logger.FromContext(workCtx).Error("agent fn panicked", "panic", r, "stack", string(debug.Stack())) + fnErr = core.NewError(core.INTERNAL, "agent fn panicked: %v", r) + } + }() + result, fnErr = fn(workCtx, rt.router.responder(), rt.sess) + }() + rt.fnDone <- fnDoneResult[State]{result: result, err: fnErr} }() select { @@ -233,23 +505,30 @@ func (rt *sessionFlowRuntime[Stream, State]) run( // pending snapshot) and a [SnapshotAborter] (which bundles both abort // triggering and status-change subscription so the runtime can react to // the abort without polling). -func (rt *sessionFlowRuntime[Stream, State]) checkDetachCapabilities() error { +func (rt *agentRuntime[Stream, State]) checkDetachCapabilities() error { if rt.cfg.store == nil { return core.NewError(core.FAILED_PRECONDITION, - "session flow %q: detach requires a session store", rt.name) + "agent %q: detach requires a session store", rt.name) } if _, ok := rt.cfg.store.(SnapshotAborter); !ok { return core.NewError(core.FAILED_PRECONDITION, - "session flow %q: detach requires a session store implementing SnapshotAborter", rt.name) + "agent %q: detach requires a session store implementing SnapshotAborter", rt.name) } return nil } -// drainAndWait performs a synchronous shutdown: cancel work, wait for the -// intake reader/forwarder to finish, drain fnDone, and close the router. -// Returns the fn's result for callers that need to surface its error. -func (rt *sessionFlowRuntime[Stream, State]) drainAndWait(cancelWork context.CancelFunc) fnDoneResult[State] { +// drainAndWait performs a synchronous shutdown: cancel work, stop router +// writes (so a fn mid-send doesn't deadlock once outCh's consumer is +// gone), wait for the intake reader/forwarder to finish, drain fnDone, +// and close the router. Returns the fn's result for callers that need +// to surface its error. +func (rt *agentRuntime[Stream, State]) drainAndWait(cancelWork context.CancelFunc) fnDoneResult[State] { cancelWork() + // Switch the router to side-effects-only mode before waiting on fn. + // Without this, a fn mid-SendStatus blocks on the router's r.in + // receive while the router blocks on r.out send (consumer is gone), + // so fn never observes ctx and we deadlock waiting on fnDone. + rt.router.stopAndWait() rt.intake.stopAndWait() res := <-rt.fnDone rt.router.close() @@ -259,11 +538,11 @@ func (rt *sessionFlowRuntime[Stream, State]) drainAndWait(cancelWork context.Can // handleFnDone is the synchronous-completion path: fn returned before any // detach signal. Capture an invocation-end snapshot if state advanced past // the last turn-end snapshot, then assemble the output. -func (rt *sessionFlowRuntime[Stream, State]) handleFnDone( +func (rt *agentRuntime[Stream, State]) handleFnDone( ctx context.Context, cancelWork context.CancelFunc, res fnDoneResult[State], -) (*SessionFlowOutput[State], error) { +) (*AgentOutput[State], error) { cancelWork() rt.intake.stopAndWait() rt.router.close() @@ -272,14 +551,14 @@ func (rt *sessionFlowRuntime[Stream, State]) handleFnDone( return nil, res.err } - snapshotID := rt.runner.maybeSnapshot(ctx, SnapshotEventInvocationEnd) - if snapshotID == "" && rt.runner.lastSnapshot != nil { + snapshotID := rt.sess.maybeSnapshot(ctx, SnapshotEventInvocationEnd) + if snapshotID == "" && rt.sess.lastSnapshot != nil { // State unchanged since the last turn-end snapshot — reuse it so // the response always carries an ID when a store is configured. - snapshotID = rt.runner.lastSnapshot.SnapshotID + snapshotID = rt.sess.lastSnapshot.SnapshotID } - out := &SessionFlowOutput[State]{SnapshotID: snapshotID} + out := &AgentOutput[State]{SnapshotID: snapshotID} if res.result != nil { out.Message = res.result.Message out.Artifacts = res.result.Artifacts @@ -297,18 +576,18 @@ func (rt *sessionFlowRuntime[Stream, State]) handleFnDone( // stops writing to outCh but keeps applying in-process side effects // (e.g. artifacts added via Responder.SendArtifact) so user code does // not have to branch on detach. -func (rt *sessionFlowRuntime[Stream, State]) handleDetach( +func (rt *agentRuntime[Stream, State]) handleDetach( clientCtx, workCtx context.Context, cancelWork context.CancelFunc, markDetached func(), -) (*SessionFlowOutput[State], error) { +) (*AgentOutput[State], error) { // Stop mirroring clientCtx. From here, only the abort subscription or // fn completion can cancel workCtx. markDetached() rt.intake.suspend() - parentID := rt.runner.parentSnapshotID() + parentID := rt.sess.parentSnapshotID() // Detach intends to outlive the client connection. If clientCtx was // already cancelled (or cancels mid-write), we still want the pending @@ -324,7 +603,7 @@ func (rt *sessionFlowRuntime[Stream, State]) handleDetach( if err != nil { rt.drainAndWait(cancelWork) return nil, core.NewError(core.INTERNAL, - "session flow %q: detach: save pending snapshot: %v", rt.name, err) + "agent %q: detach: save pending snapshot: %v", rt.name, err) } // The router can no longer write to outCh once we return; the bidi @@ -356,7 +635,7 @@ func (rt *sessionFlowRuntime[Stream, State]) handleDetach( cancelWork() }() - return &SessionFlowOutput[State]{SnapshotID: pending.SnapshotID}, nil + return &AgentOutput[State]{SnapshotID: pending.SnapshotID}, nil } // finalizePendingSnapshot rewrites the pending snapshot row with the @@ -366,7 +645,7 @@ func (rt *sessionFlowRuntime[Stream, State]) handleDetach( // so the read-and-rewrite is one atomic step: if the row has already // transitioned to canceled (a late abort racing this finalize), // SaveSnapshot sees it inside fn and we leave the row untouched. -func (rt *sessionFlowRuntime[Stream, State]) finalizePendingSnapshot( +func (rt *agentRuntime[Stream, State]) finalizePendingSnapshot( ctx context.Context, pending *SessionSnapshot[State], fnErr error, @@ -403,7 +682,7 @@ func (rt *sessionFlowRuntime[Stream, State]) finalizePendingSnapshot( }, nil }) if err != nil { - logger.FromContext(ctx).Error("session flow: failed to finalize pending snapshot", + logger.FromContext(ctx).Error("agent: failed to finalize pending snapshot", "snapshotId", pending.SnapshotID, "err", err) } } @@ -413,7 +692,7 @@ func (rt *sessionFlowRuntime[Stream, State]) finalizePendingSnapshot( // snapshot too so the runtime can chain ParentID off it. func loadSession[State any]( ctx context.Context, - init *SessionFlowInit[State], + init *AgentInit[State], store SessionStore[State], ) (*Session[State], *SessionSnapshot[State], error) { s := &Session[State]{store: store} @@ -462,194 +741,6 @@ func loadSession[State any]( return s, snap, nil } -// --- SessionRunner --- - -// SessionRunner extends Session with session-flow-specific functionality: -// turn management, snapshot persistence, and input channel handling. -type SessionRunner[State any] struct { - *Session[State] - - // InputCh is the channel that delivers per-turn inputs from the client. - // It is consumed automatically by [SessionRunner.Run], but is exposed - // for advanced use cases that need direct access to the input stream - // (e.g., custom turn loops or fan-out patterns). - InputCh <-chan *SessionFlowInput - // TurnIndex is the zero-based index of the current conversation turn. - // It is incremented automatically by [SessionRunner.Run], but is exposed - // for advanced use cases that need to track or manipulate turn ordering - // directly. - TurnIndex int - - snapshotCallback SnapshotCallback[State] - onEndTurn func(ctx context.Context) - lastSnapshot *SessionSnapshot[State] - lastSnapshotVersion uint64 - collectTurnOutput func() any - - // intake is the source of truth for in-flight tracking, queue state, - // and suspended state. The runner consults it via beginTurnEnd (in - // maybeSnapshot) so per-turn snapshot writes and detach captures - // cannot race over the same input. - intake *detachIntake -} - -// parentSnapshotID returns the ID of the most recent snapshot in this -// invocation (used to chain new snapshots via ParentID), or "" if no -// snapshot has been written yet. -func (s *SessionRunner[State]) parentSnapshotID() string { - if s.lastSnapshot == nil { - return "" - } - return s.lastSnapshot.SnapshotID -} - -// Run loops over the input channel, calling fn for each turn. Each turn is -// wrapped in a trace span for observability. Input messages are automatically -// added to the session before fn is called. After fn returns successfully, a -// TurnEnd chunk is sent and a snapshot check is triggered. -func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Context, input *SessionFlowInput) error) error { - for input := range s.InputCh { - spanMeta := &tracing.SpanMetadata{ - Name: fmt.Sprintf("sessionFlow/turn/%d", s.TurnIndex), - Type: "flowStep", - Subtype: "flowStep", - } - _, err := tracing.RunInNewSpan(ctx, spanMeta, input, - func(ctx context.Context, input *SessionFlowInput) (any, error) { - s.AddMessages(input.Messages...) - if err := fn(ctx, input); err != nil { - return nil, err - } - s.onEndTurn(ctx) - s.TurnIndex++ - if s.collectTurnOutput != nil { - return s.collectTurnOutput(), nil - } - return nil, nil - }, - ) - if err != nil { - return err - } - } - return nil -} - -// Result returns an [SessionFlowResult] populated from the current session state: -// the last message in the conversation history and all artifacts. -// It is a convenience for custom session flows that don't need to construct the -// result manually. -func (s *SessionRunner[State]) Result() *SessionFlowResult { - s.mu.RLock() - defer s.mu.RUnlock() - - result := &SessionFlowResult{} - if msgs := s.state.Messages; len(msgs) > 0 { - result.Message = msgs[len(msgs)-1] - } - if len(s.state.Artifacts) > 0 { - arts := make([]*Artifact, len(s.state.Artifacts)) - copy(arts, s.state.Artifacts) - result.Artifacts = arts - } - return result -} - -// maybeSnapshot creates a snapshot if conditions are met (store configured, -// callback approves, state changed, detach has not suspended snapshots). -// Returns the snapshot ID or empty string. -// -// For turn-end events, the runner asks the intake whether snapshots -// have been suspended (i.e. detach has landed). If so, the runner skips -// the turn-end snapshot — the pending row already captures the -// invocation and a single finalize rewrite will record the cumulative -// state once the queued inputs drain. -func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { - if event == SnapshotEventTurnEnd && s.intake != nil { - if suspended := s.intake.beginTurnEnd(); suspended { - return "" - } - } - - if s.store == nil { - return "" - } - - s.mu.RLock() - currentVersion := s.version - currentState := s.copyStateLocked() - s.mu.RUnlock() - - // Skip if state hasn't changed since the last snapshot. This avoids - // redundant snapshots, e.g. the invocation-end snapshot after a - // single-turn Run where the turn-end snapshot already captured the - // same state. - if s.lastSnapshot != nil && currentVersion == s.lastSnapshotVersion { - return "" - } - - if s.snapshotCallback != nil { - var prevState *SessionState[State] - if s.lastSnapshot != nil { - prevState = &s.lastSnapshot.State - } - if !s.snapshotCallback(ctx, &SnapshotContext[State]{ - State: ¤tState, - PrevState: prevState, - TurnIndex: s.TurnIndex, - Event: event, - }) { - return "" - } - } - - parentID := s.parentSnapshotID() - - saved, err := s.store.SaveSnapshot(ctx, "", - func(_ *SessionSnapshot[State]) (*SessionSnapshot[State], error) { - return &SessionSnapshot[State]{ - ParentID: parentID, - Event: event, - Status: SnapshotStatusComplete, - State: currentState, - }, nil - }) - if err != nil { - logger.FromContext(ctx).Error("session flow: failed to save snapshot", "err", err) - return "" - } - - s.lastSnapshot = saved - s.lastSnapshotVersion = currentVersion - return saved.SnapshotID -} - -// --- Responder --- - -// Responder is the output channel for an session flow. Artifacts sent through -// it are automatically added to the session before being forwarded to the -// client. -type Responder[Stream any] chan<- *SessionFlowStreamChunk[Stream] - -// SendModelChunk sends a generation chunk (token-level streaming). -func (r Responder[Stream]) SendModelChunk(chunk *ai.ModelResponseChunk) { - r <- &SessionFlowStreamChunk[Stream]{ModelChunk: chunk} -} - -// SendStatus sends a user-defined status update. -func (r Responder[Stream]) SendStatus(status Stream) { - r <- &SessionFlowStreamChunk[Stream]{Status: status} -} - -// SendArtifact sends an artifact to the stream and adds it to the session. -// If an artifact with the same name already exists in the session, it is -// replaced. The session-level side effect happens whether or not detach -// has landed; only the wire forward to the client is suppressed -// post-detach, when there is no longer a client to receive it. -func (r Responder[Stream]) SendArtifact(artifact *Artifact) { - r <- &SessionFlowStreamChunk[Stream]{Artifact: artifact} -} - // --- chunkRouter --- // // chunkRouter owns the intermediate stream channel that all chunks flow @@ -663,12 +754,12 @@ func (r Responder[Stream]) SendArtifact(artifact *Artifact) { // send. type chunkRouter[Stream, State any] struct { - in chan *SessionFlowStreamChunk[Stream] - out chan<- *SessionFlowStreamChunk[Stream] + in chan *AgentStreamChunk[Stream] + out chan<- *AgentStreamChunk[Stream] session *Session[State] turnMu sync.Mutex - turnChunks []*SessionFlowStreamChunk[Stream] + turnChunks []*AgentStreamChunk[Stream] done chan struct{} stopWriting chan struct{} @@ -677,10 +768,10 @@ type chunkRouter[Stream, State any] struct { func startChunkRouter[Stream, State any]( session *Session[State], - out chan<- *SessionFlowStreamChunk[Stream], + out chan<- *AgentStreamChunk[Stream], ) *chunkRouter[Stream, State] { r := &chunkRouter[Stream, State]{ - in: make(chan *SessionFlowStreamChunk[Stream]), + in: make(chan *AgentStreamChunk[Stream]), out: out, session: session, done: make(chan struct{}), @@ -709,7 +800,7 @@ func (r *chunkRouter[Stream, State]) run() { // applySideEffects records the chunk's effect on session state and turn // span output. Invoked from both forward (pre-detach) and the post-detach // drain so a Send call is observably the same in either mode. -func (r *chunkRouter[Stream, State]) applySideEffects(chunk *SessionFlowStreamChunk[Stream]) { +func (r *chunkRouter[Stream, State]) applySideEffects(chunk *AgentStreamChunk[Stream]) { if chunk.Artifact != nil { r.session.AddArtifacts(chunk.Artifact) } @@ -747,13 +838,13 @@ func (r *chunkRouter[Stream, State]) responder() Responder[Stream] { } // send returns the internal chunk channel for producers other than the user -// flow function (e.g. the runtime's emitTurnEnd). -func (r *chunkRouter[Stream, State]) send() chan<- *SessionFlowStreamChunk[Stream] { +// agent function (e.g. the runtime's emitTurnEnd). +func (r *chunkRouter[Stream, State]) send() chan<- *AgentStreamChunk[Stream] { return r.in } // collectTurnChunks returns and resets accumulated turn chunks. -func (r *chunkRouter[Stream, State]) collectTurnChunks() []*SessionFlowStreamChunk[Stream] { +func (r *chunkRouter[Stream, State]) collectTurnChunks() []*AgentStreamChunk[Stream] { r.turnMu.Lock() defer r.turnMu.Unlock() result := r.turnChunks @@ -782,9 +873,9 @@ func (r *chunkRouter[Stream, State]) close() { // // The reader goroutine pulls from the bidi framework's inCh as soon as // inputs arrive and appends them to an internal queue. This is what makes -// detach detection immediate: the moment an input with -// [SessionFlowInput.Detach] lands in src, the reader sees it without -// waiting for the runner to finish whatever it's processing. +// detach detection immediate: the moment an input with [AgentInput.Detach] +// lands in src, the reader sees it without waiting for the runner to +// finish whatever it's processing. // // The forwarder goroutine pops the queue and writes to dst, blocking on // the runner via turnDone so it stays in step with turn pacing. @@ -800,8 +891,8 @@ func (r *chunkRouter[Stream, State]) close() { // beginTurnEnd that returns after suspend completes sees suspended=true. type detachIntake struct { - src <-chan *SessionFlowInput - dst chan *SessionFlowInput + src <-chan *AgentInput + dst chan *AgentInput notify chan struct{} // buffered size 1; wakes forwarder when queue grows // turnDone is signaled by beginTurnEnd to release the forwarder so it @@ -811,7 +902,7 @@ type detachIntake struct { mu sync.Mutex suspended bool - queue []*SessionFlowInput + queue []*AgentInput readDone atomic.Bool detachCh chan struct{} // signaled by reader when detach observed @@ -821,10 +912,10 @@ type detachIntake struct { done chan struct{} } -func startDetachIntake(src <-chan *SessionFlowInput) *detachIntake { +func startDetachIntake(src <-chan *AgentInput) *detachIntake { i := &detachIntake{ src: src, - dst: make(chan *SessionFlowInput), + dst: make(chan *AgentInput), notify: make(chan struct{}, 1), turnDone: make(chan struct{}, 1), detachCh: make(chan struct{}, 1), @@ -886,7 +977,7 @@ func (i *detachIntake) read() { } } -func (i *detachIntake) enqueue(input *SessionFlowInput) { +func (i *detachIntake) enqueue(input *AgentInput) { i.mu.Lock() i.queue = append(i.queue, input) i.mu.Unlock() @@ -901,9 +992,9 @@ func (i *detachIntake) enqueue(input *SessionFlowInput) { // than enqueued: it carries no payload to process, so it would just // trigger a no-op turn. Callers that want to ride a final input on the // detach signal can do so by calling -// Send(&SessionFlowInput{Detach: true, Messages: ...}) explicitly. -func (i *detachIntake) handleDetach(first *SessionFlowInput) { - var drained []*SessionFlowInput +// Send(&AgentInput{Detach: true, Messages: ...}) explicitly. +func (i *detachIntake) handleDetach(first *AgentInput) { + var drained []*AgentInput if hasInputPayload(first) { drained = append(drained, first) } @@ -936,7 +1027,7 @@ drainLoop: // hasInputPayload reports whether the input carries data the runner would // otherwise process. Used to filter pure detach signals out of the // queue so they don't trigger no-op turns. -func hasInputPayload(in *SessionFlowInput) bool { +func hasInputPayload(in *AgentInput) bool { return in != nil && (len(in.Messages) > 0 || len(in.ToolRestarts) > 0) } @@ -970,7 +1061,7 @@ func (i *detachIntake) forward() { // awaitInput blocks until the queue has an input, the reader is done, or // stop is signaled. Returns the popped input or nil if no further inputs // will arrive. -func (i *detachIntake) awaitInput() *SessionFlowInput { +func (i *detachIntake) awaitInput() *AgentInput { for { i.mu.Lock() if len(i.queue) > 0 { @@ -1002,7 +1093,7 @@ func (i *detachIntake) releaseForward() { } } -func (i *detachIntake) out() <-chan *SessionFlowInput { +func (i *detachIntake) out() <-chan *AgentInput { return i.dst } @@ -1010,7 +1101,7 @@ func (i *detachIntake) detachSignal() <-chan struct{} { return i.detachCh } -// beginTurnEnd is called by [SessionRunner.maybeSnapshot] before writing +// beginTurnEnd is called by [AgentSession.maybeSnapshot] before writing // a turn-end snapshot. If the intake has been suspended (detach landed), // it returns suspended=true and the runner skips the snapshot. // @@ -1042,40 +1133,129 @@ func (i *detachIntake) stopAndWait() { <-i.done } -// --- SessionFlow client API --- +// promptMessageKey is the metadata key used to tag base messages from the +// agent config (system prompt, prompt template output, etc.) so they can be +// excluded from session history after generation. +const promptMessageKey = "_genkit_prompt" -// StreamBidi starts a new session flow invocation with bidirectional streaming. +// agentLoop returns the per-turn function for a prompt-backed agent. Each +// turn renders the prompt, appends conversation history, calls the model +// with streaming, and updates the session. +// +// defaultInput is the prompt input passed to Render on every turn. It is +// nil for [DefineAgent], where the inline-defined prompt has no per-turn +// input. +func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) AgentFunc[any, State] { + return func(ctx context.Context, resp Responder[any], sess *AgentSession[State]) (*AgentResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + actionOpts, err := prompt.Render(ctx, defaultInput) + if err != nil { + return fmt.Errorf("prompt render: %w", err) + } + + // Tag base messages so they can be filtered out of session + // history after generation. + for _, m := range actionOpts.Messages { + if m.Metadata == nil { + m.Metadata = make(map[string]any) + } + m.Metadata[promptMessageKey] = true + } + + // Append conversation history after the base messages. + actionOpts.Messages = append(actionOpts.Messages, sess.Messages()...) + + // If tool restarts were provided, set the resume option so + // handleResumeOption re-executes the interrupted tools. + if len(input.ToolRestarts) > 0 { + for _, p := range input.ToolRestarts { + if !p.IsToolRequest() { + return core.NewError(core.INVALID_ARGUMENT, "ToolRestarts: part is not a tool request") + } + } + actionOpts.Resume = ai.NewResume(input.ToolRestarts, nil) + } + + modelResp, err := ai.GenerateWithRequest(ctx, r, actionOpts, nil, + func(ctx context.Context, chunk *ai.ModelResponseChunk) error { + resp.SendModelChunk(chunk) + return nil + }, + ) + if err != nil { + return fmt.Errorf("generate: %w", err) + } + + // Replace session messages with the full history minus base + // messages. This captures intermediate tool call/response + // messages from the tool loop, not just the final response. + if modelResp.Request != nil { + history := modelResp.History() + msgs := make([]*ai.Message, 0, len(history)) + for _, m := range history { + if m.Metadata != nil && m.Metadata[promptMessageKey] == true { + continue + } + msgs = append(msgs, m) + } + sess.SetMessages(msgs) + } else if modelResp.Message != nil { + sess.AddMessages(modelResp.Message) + } + + // Stream interrupt parts so the client can detect and + // handle them (e.g. prompt the user for confirmation). + if modelResp.FinishReason == ai.FinishReasonInterrupted { + if parts := modelResp.Interrupts(); len(parts) > 0 { + resp.SendModelChunk(&ai.ModelResponseChunk{ + Role: ai.RoleTool, + Content: parts, + }) + } + } + + return nil + }); err != nil { + return nil, err + } + return sess.Result(), nil + } +} + +// --- Agent client API --- + +// StreamBidi starts a new agent invocation with bidirectional streaming. // Use this for multi-turn interactions where you need to send multiple inputs // and receive streaming chunks. For single-turn usage, see Run and RunText. -func (af *SessionFlow[Stream, State]) StreamBidi( +func (a *Agent[Stream, State]) StreamBidi( ctx context.Context, opts ...InvocationOption[State], -) (*SessionFlowConnection[Stream, State], error) { - init, err := af.resolveOptions(opts) +) (*AgentConnection[Stream, State], error) { + init, err := a.resolveOptions(opts) if err != nil { return nil, err } - conn, err := af.action.StreamBidi(ctx, init) + conn, err := a.action.StreamBidi(ctx, init) if err != nil { return nil, err } - return &SessionFlowConnection[Stream, State]{conn: conn}, nil + return &AgentConnection[Stream, State]{conn: conn}, nil } -// Run starts a single-turn session flow invocation with the given input. -// It sends the input, waits for the flow to complete, and returns the output. +// Run starts a single-turn agent invocation with the given input. +// It sends the input, waits for the agent to complete, and returns the output. // For multi-turn interactions or streaming, use StreamBidi instead. -func (af *SessionFlow[Stream, State]) Run( +func (a *Agent[Stream, State]) Run( ctx context.Context, - input *SessionFlowInput, + input *AgentInput, opts ...InvocationOption[State], -) (*SessionFlowOutput[State], error) { - conn, err := af.StreamBidi(ctx, opts...) +) (*AgentOutput[State], error) { + conn, err := a.StreamBidi(ctx, opts...) if err != nil { return nil, err } // If the bidi function fails fast (e.g. resuming from an errored - // snapshot rejects in newSessionFlowRuntime), Send / Close / Receive + // snapshot rejects in newAgentRuntime), Send / Close / Receive // see a closed connection and return generic "action has completed" // errors. The real fn error is on Output(). Prefer it whenever it's // non-nil so callers get the meaningful failure. @@ -1102,51 +1282,45 @@ func (af *SessionFlow[Stream, State]) Run( return conn.Output() } -// RunText is a convenience method that starts a single-turn session flow -// invocation with a user text message. It is equivalent to calling Run with -// an SessionFlowInput containing a single user text message. -func (af *SessionFlow[Stream, State]) RunText( +// RunText is a convenience method that starts a single-turn agent invocation +// with a user text message. It is equivalent to calling Run with an +// AgentInput containing a single user text message. +func (a *Agent[Stream, State]) RunText( ctx context.Context, text string, opts ...InvocationOption[State], -) (*SessionFlowOutput[State], error) { - return af.Run(ctx, &SessionFlowInput{ +) (*AgentOutput[State], error) { + return a.Run(ctx, &AgentInput{ Messages: []*ai.Message{ai.NewUserTextMessage(text)}, }, opts...) } // resolveOptions applies invocation options and returns the init struct. -func (af *SessionFlow[Stream, State]) resolveOptions(opts []InvocationOption[State]) (*SessionFlowInit[State], error) { - cfg := &invocationOptions[State]{} +func (a *Agent[Stream, State]) resolveOptions(opts []InvocationOption[State]) (*AgentInit[State], error) { + invOpts := &invocationOptions[State]{} for _, opt := range opts { - if err := opt.applyInvocation(cfg); err != nil { - return nil, fmt.Errorf("SessionFlow %q: %w", af.action.Name(), err) + if err := opt.applyInvocation(invOpts); err != nil { + return nil, fmt.Errorf("Agent %q: %w", a.action.Name(), err) } } - init := &SessionFlowInit[State]{ - SnapshotID: cfg.snapshotID, - State: cfg.state, - } - if cfg.promptInput != nil { - if init.State == nil { - init.State = &SessionState[State]{} - } - init.State.InputVariables = cfg.promptInput - } - return init, nil + + return &AgentInit[State]{ + SnapshotID: invOpts.snapshotID, + State: invOpts.state, + }, nil } -// --- SessionFlowConnection --- +// --- AgentConnection --- -// SessionFlowConnection wraps BidiConnection with session flow-specific functionality. -// It provides a Receive() iterator that supports multi-turn patterns: breaking out -// of the iterator between turns does not cancel the underlying connection. -type SessionFlowConnection[Stream, State any] struct { - conn *core.BidiConnection[*SessionFlowInput, *SessionFlowStreamChunk[Stream], *SessionFlowOutput[State]] +// AgentConnection wraps BidiConnection with agent-specific functionality. +// It provides a Receive() iterator that supports multi-turn patterns: breaking +// out of the iterator between turns does not cancel the underlying connection. +type AgentConnection[Stream, State any] struct { + conn *core.BidiConnection[*AgentInput, *AgentStreamChunk[Stream], *AgentOutput[State]] // chunks buffers stream chunks from the underlying connection so that // breaking from Receive() between turns doesn't cancel the context. - chunks chan *SessionFlowStreamChunk[Stream] + chunks chan *AgentStreamChunk[Stream] chunkErr error initOnce sync.Once } @@ -1154,9 +1328,9 @@ type SessionFlowConnection[Stream, State any] struct { // initReceiver starts a goroutine that drains the underlying BidiConnection's // Receive into a channel. This goroutine never breaks from the underlying // iterator, preventing context cancellation. -func (c *SessionFlowConnection[Stream, State]) initReceiver() { +func (c *AgentConnection[Stream, State]) initReceiver() { c.initOnce.Do(func() { - c.chunks = make(chan *SessionFlowStreamChunk[Stream], 1) + c.chunks = make(chan *AgentStreamChunk[Stream], 1) go func() { defer close(c.chunks) for chunk, err := range c.conn.Receive() { @@ -1170,27 +1344,27 @@ func (c *SessionFlowConnection[Stream, State]) initReceiver() { }) } -// Send sends an SessionFlowInput to the session flow. -func (c *SessionFlowConnection[Stream, State]) Send(input *SessionFlowInput) error { +// Send sends an AgentInput to the agent. +func (c *AgentConnection[Stream, State]) Send(input *AgentInput) error { return c.conn.Send(input) } -// SendMessages sends messages to the session flow. -func (c *SessionFlowConnection[Stream, State]) SendMessages(messages ...*ai.Message) error { - return c.conn.Send(&SessionFlowInput{Messages: messages}) +// SendMessages sends messages to the agent. +func (c *AgentConnection[Stream, State]) SendMessages(messages ...*ai.Message) error { + return c.conn.Send(&AgentInput{Messages: messages}) } -// SendText sends a single user text message to the session flow. -func (c *SessionFlowConnection[Stream, State]) SendText(text string) error { - return c.conn.Send(&SessionFlowInput{ +// SendText sends a single user text message to the agent. +func (c *AgentConnection[Stream, State]) SendText(text string) error { + return c.conn.Send(&AgentInput{ Messages: []*ai.Message{ai.NewUserTextMessage(text)}, }) } // SendToolRestarts sends tool restart parts to resume interrupted tool calls. // Parts should be created via [ai.ToolDef.RestartWith]. -func (c *SessionFlowConnection[Stream, State]) SendToolRestarts(parts ...*ai.Part) error { - return c.conn.Send(&SessionFlowInput{ToolRestarts: parts}) +func (c *AgentConnection[Stream, State]) SendToolRestarts(parts ...*ai.Part) error { + return c.conn.Send(&AgentInput{ToolRestarts: parts}) } // Detach asks the server to write a pending snapshot, close the @@ -1207,13 +1381,13 @@ func (c *SessionFlowConnection[Stream, State]) SendToolRestarts(parts ...*ai.Par // session and end up in the final snapshot's state. // // To send a final input as part of the same wire message, use -// Send(&SessionFlowInput{Detach: true, Messages: ...}) directly. -func (c *SessionFlowConnection[Stream, State]) Detach() error { - return c.conn.Send(&SessionFlowInput{Detach: true}) +// Send(&AgentInput{Detach: true, Messages: ...}) directly. +func (c *AgentConnection[Stream, State]) Detach() error { + return c.conn.Send(&AgentInput{Detach: true}) } // Close signals that no more inputs will be sent. -func (c *SessionFlowConnection[Stream, State]) Close() error { +func (c *AgentConnection[Stream, State]) Close() error { return c.conn.Close() } @@ -1221,9 +1395,9 @@ func (c *SessionFlowConnection[Stream, State]) Close() error { // Unlike the underlying BidiConnection.Receive, breaking out of this iterator // does not cancel the connection. This enables multi-turn patterns where the // caller breaks on TurnEnd, sends the next input, then calls Receive again. -func (c *SessionFlowConnection[Stream, State]) Receive() iter.Seq2[*SessionFlowStreamChunk[Stream], error] { +func (c *AgentConnection[Stream, State]) Receive() iter.Seq2[*AgentStreamChunk[Stream], error] { c.initReceiver() - return func(yield func(*SessionFlowStreamChunk[Stream], error) bool) { + return func(yield func(*AgentStreamChunk[Stream], error) bool) { for chunk := range c.chunks { if !yield(chunk, nil) { return @@ -1235,148 +1409,19 @@ func (c *SessionFlowConnection[Stream, State]) Receive() iter.Seq2[*SessionFlowS } } -// Output returns the final response after the session flow completes. +// Output returns the final response after the agent completes. // -// Unlike the underlying BidiConnection, Output waits for the flow to +// Unlike the underlying BidiConnection, Output waits for the agent to // finalize before returning. This is important for detached invocations: -// when the client sends Detach, the flow function returns promptly with a +// when the client sends Detach, the agent function returns promptly with a // pending snapshot ID, and callers need to observe that output rather than // the context cancellation error. -func (c *SessionFlowConnection[Stream, State]) Output() (*SessionFlowOutput[State], error) { +func (c *AgentConnection[Stream, State]) Output() (*AgentOutput[State], error) { <-c.conn.Done() return c.conn.Output() } // Done returns a channel closed when the connection completes. -func (c *SessionFlowConnection[Stream, State]) Done() <-chan struct{} { +func (c *AgentConnection[Stream, State]) Done() <-chan struct{} { return c.conn.Done() } - -// --- DefineSessionFlowFromPrompt --- - -// promptMessageKey tags prompt-rendered messages so they can be excluded -// from session history after generation. They're rendered fresh each turn -// from the registered prompt, so persisting them in history would cause -// duplication on resume. -const promptMessageKey = "_genkit_prompt" - -// DefineSessionFlowFromPrompt creates a prompt-backed SessionFlow with an -// automatic conversation loop. Each turn renders the prompt, appends -// conversation history, calls GenerateWithRequest, streams chunks to the -// client, and adds the model response to the session. -// -// The prompt is looked up by name from the registry using -// [ai.LookupDataPrompt]. The defaultInput is used for prompt rendering -// unless overridden per invocation via WithInputVariables. -func DefineSessionFlowFromPrompt[State, PromptIn any]( - r api.Registry, - promptName string, - defaultInput PromptIn, - opts ...SessionFlowOption[State], -) *SessionFlow[any, State] { - p := ai.LookupDataPrompt[PromptIn, string](r, promptName) - if p == nil { - panic(fmt.Sprintf("DefineSessionFlowFromPrompt: prompt %q not found", promptName)) - } - - turn := func(ctx context.Context, resp Responder[any], sess *SessionRunner[State], input *SessionFlowInput) error { - genOpts, err := renderPromptForTurn(ctx, p, sess, defaultInput) - if err != nil { - return err - } - - if len(input.ToolRestarts) > 0 { - for _, part := range input.ToolRestarts { - if !part.IsToolRequest() { - return core.NewError(core.INVALID_ARGUMENT, "ToolRestarts: part is not a tool request") - } - } - genOpts.Resume = ai.NewResume(input.ToolRestarts, nil) - } - - modelResp, err := ai.GenerateWithRequest(ctx, r, genOpts, nil, - func(ctx context.Context, chunk *ai.ModelResponseChunk) error { - resp.SendModelChunk(chunk) - return nil - }, - ) - if err != nil { - return fmt.Errorf("generate: %w", err) - } - - // Replace session messages with the full history minus prompt - // messages. This captures intermediate tool call/response messages - // from the tool loop, not just the final response. - if modelResp.Request != nil { - var msgs []*ai.Message - for _, m := range modelResp.History() { - if m.Metadata[promptMessageKey] == true { - continue - } - msgs = append(msgs, m) - } - sess.SetMessages(msgs) - } else if modelResp.Message != nil { - sess.AddMessages(modelResp.Message) - } - - // Stream interrupt parts so the client can detect and handle them - // (e.g. prompt the user for confirmation). - if modelResp.FinishReason == ai.FinishReasonInterrupted { - if parts := modelResp.Interrupts(); len(parts) > 0 { - resp.SendModelChunk(&ai.ModelResponseChunk{ - Role: ai.RoleTool, - Content: parts, - }) - } - } - return nil - } - - fn := func(ctx context.Context, resp Responder[any], sess *SessionRunner[State]) (*SessionFlowResult, error) { - err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { - return turn(ctx, resp, sess, input) - }) - if err != nil { - return nil, err - } - return sess.Result(), nil - } - - return DefineSessionFlow(r, promptName, fn, opts...) -} - -// renderPromptForTurn renders the prompt with the active input variables -// (session override > default), tags the prompt-rendered messages so they -// can be excluded from history, and appends conversation history. -func renderPromptForTurn[State, PromptIn any]( - ctx context.Context, - p *ai.DataPrompt[PromptIn, string], - sess *SessionRunner[State], - defaultInput PromptIn, -) (*ai.GenerateActionOptions, error) { - promptInput := defaultInput - if stored := sess.InputVariables(); stored != nil { - typed, ok := base.ConvertTo[PromptIn](stored) - if !ok { - return nil, core.NewError(core.INVALID_ARGUMENT, - "input variables type mismatch: got %T, want %T", stored, promptInput) - } - promptInput = typed - } - - genOpts, err := p.Render(ctx, promptInput) - if err != nil { - return nil, fmt.Errorf("prompt render: %w", err) - } - - for _, m := range genOpts.Messages { - if m.Metadata == nil { - m.Metadata = make(map[string]any) - } - m.Metadata[promptMessageKey] = true - } - - genOpts.Messages = append(genOpts.Messages, sess.Messages()...) - return genOpts, nil -} diff --git a/go/ai/exp/session_flow_test.go b/go/ai/exp/agent_test.go similarity index 83% rename from go/ai/exp/session_flow_test.go rename to go/ai/exp/agent_test.go index 26bf20d842..20c07f662d 100644 --- a/go/ai/exp/session_flow_test.go +++ b/go/ai/exp/agent_test.go @@ -44,13 +44,13 @@ func newTestRegistry(t *testing.T) *registry.Registry { return registry.New() } -func TestSessionFlow_BasicMultiTurn(t *testing.T) { +func TestAgent_BasicMultiTurn(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineSessionFlow(reg, "basicFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "basicFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { resp.SendStatus(testStatus{Phase: "generating"}) // Echo back the user's message. if len(input.Messages) > 0 { @@ -119,14 +119,14 @@ func TestSessionFlow_BasicMultiTurn(t *testing.T) { } } -func TestSessionFlow_WithSessionStore(t *testing.T) { +func TestAgent_WithSessionStore(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineSessionFlow(reg, "snapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "snapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -189,14 +189,14 @@ func TestSessionFlow_WithSessionStore(t *testing.T) { } } -func TestSessionFlow_ResumeFromSnapshot(t *testing.T) { +func TestAgent_ResumeFromSnapshot(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineSessionFlow(reg, "resumeFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "resumeFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -279,13 +279,13 @@ func TestSessionFlow_ResumeFromSnapshot(t *testing.T) { } } -func TestSessionFlow_ClientManagedState(t *testing.T) { +func TestAgent_ClientManagedState(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineSessionFlow(reg, "clientStateFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "clientStateFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -342,13 +342,13 @@ func TestSessionFlow_ClientManagedState(t *testing.T) { } } -func TestSessionFlow_Artifacts(t *testing.T) { +func TestAgent_Artifacts(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineSessionFlow(reg, "artifactFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "artifactFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { resp.SendArtifact(&Artifact{ Name: "code.go", @@ -373,7 +373,7 @@ func TestSessionFlow_Artifacts(t *testing.T) { if err != nil { return nil, err } - return &SessionFlowResult{Artifacts: sess.Artifacts()}, nil + return &AgentResult{Artifacts: sess.Artifacts()}, nil }, ) @@ -412,16 +412,16 @@ func TestSessionFlow_Artifacts(t *testing.T) { } } -func TestSessionFlow_SnapshotCallback(t *testing.T) { +func TestAgent_SnapshotCallback(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() // Only snapshot on even turns. callbackCalls := 0 - af := DefineSessionFlow(reg, "callbackFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "callbackFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -471,13 +471,13 @@ func TestSessionFlow_SnapshotCallback(t *testing.T) { } } -func TestSessionFlow_SendMessages(t *testing.T) { +func TestAgent_SendMessages(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineSessionFlow(reg, "sendMsgsFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "sendMsgsFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { return nil }) }, @@ -517,14 +517,14 @@ func TestSessionFlow_SendMessages(t *testing.T) { } } -func TestSessionFlow_SessionContext(t *testing.T) { +func TestAgent_SessionContext(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) var retrievedCounter int - af := DefineSessionFlow(reg, "contextFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "contextFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { // Session should be retrievable from context. ctxSess := SessionFromContext[testState](ctx) if ctxSess == nil { @@ -563,13 +563,13 @@ func TestSessionFlow_SessionContext(t *testing.T) { } } -func TestSessionFlow_ErrorInTurn(t *testing.T) { +func TestAgent_ErrorInTurn(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineSessionFlow(reg, "errorFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "errorFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { return fmt.Errorf("turn failed") }) }, @@ -589,13 +589,13 @@ func TestSessionFlow_ErrorInTurn(t *testing.T) { } } -func TestSessionFlow_SetMessages(t *testing.T) { +func TestAgent_SetMessages(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineSessionFlow(reg, "setMsgsFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "setMsgsFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { // Replace all messages with just one. sess.SetMessages([]*ai.Message{ai.NewModelTextMessage("replaced")}) return nil @@ -759,14 +759,14 @@ func TestInMemorySessionStore(t *testing.T) { }) } -func TestSessionFlow_TurnSpanOutput(t *testing.T) { +func TestAgent_TurnSpanOutput(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) var capturedOutputs []any - af := DefineSessionFlow(reg, "turnOutputFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + af := DefineCustomAgent(reg, "turnOutputFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { // Wrap collectTurnOutput to capture what each turn produces. originalCollect := sess.collectTurnOutput sess.collectTurnOutput = func() any { @@ -775,7 +775,7 @@ func TestSessionFlow_TurnSpanOutput(t *testing.T) { return output } - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { resp.SendStatus(testStatus{Phase: "thinking"}) resp.SendModelChunk(&ai.ModelResponseChunk{ Content: []*ai.Part{ai.NewTextPart("reply")}, @@ -821,9 +821,9 @@ func TestSessionFlow_TurnSpanOutput(t *testing.T) { } for i, output := range capturedOutputs { - chunks, ok := output.([]*SessionFlowStreamChunk[testStatus]) + chunks, ok := output.([]*AgentStreamChunk[testStatus]) if !ok { - t.Fatalf("turn %d: expected []*SessionFlowStreamChunk[testStatus], got %T", i, output) + t.Fatalf("turn %d: expected []*AgentStreamChunk[testStatus], got %T", i, output) } // 3 content chunks per turn: status + model chunk + artifact. if len(chunks) != 3 { @@ -837,15 +837,15 @@ func TestSessionFlow_TurnSpanOutput(t *testing.T) { } } -func TestSessionFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { +func TestAgent_TurnSpanOutput_WithSnapshots(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() var capturedOutputs []any - af := DefineSessionFlow(reg, "turnOutputSnapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + af := DefineCustomAgent(reg, "turnOutputSnapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { originalCollect := sess.collectTurnOutput sess.collectTurnOutput = func() any { output := originalCollect() @@ -853,7 +853,7 @@ func TestSessionFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { return output } - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { resp.SendStatus(testStatus{Phase: "working"}) sess.AddMessages(ai.NewModelTextMessage("reply")) return nil @@ -891,7 +891,7 @@ func TestSessionFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { if len(capturedOutputs) != 1 { t.Fatalf("expected 1 captured output, got %d", len(capturedOutputs)) } - chunks := capturedOutputs[0].([]*SessionFlowStreamChunk[testStatus]) + chunks := capturedOutputs[0].([]*AgentStreamChunk[testStatus]) if len(chunks) != 1 { t.Errorf("expected 1 content chunk, got %d", len(chunks)) } @@ -950,9 +950,7 @@ func TestPromptAgent_Basic(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefineSessionFlowFromPrompt[testState, any]( - reg, "testPrompt", nil, - ) + af := DefinePromptAgent[testState, any](reg, "testPrompt", nil) conn, err := af.StreamBidi(ctx) if err != nil { @@ -1009,67 +1007,6 @@ func TestPromptAgent_Basic(t *testing.T) { } } -func TestPromptAgent_PromptInputOverride(t *testing.T) { - ctx := context.Background() - reg := setupPromptTestRegistry(t) - - type greetInput struct { - Name string `json:"name"` - } - - ai.DefineDataPrompt[greetInput, string](reg, "greetPrompt", - ai.WithModelName("test/echo"), - ai.WithPrompt("Hello {{name}}!"), - ) - - af := DefineSessionFlowFromPrompt[testState]( - reg, "greetPrompt", greetInput{Name: "default"}, - ) - - // Use WithPromptInput to override. - conn, err := af.StreamBidi(ctx, - WithInputVariables[testState](greetInput{Name: "override"}), - ) - if err != nil { - t.Fatalf("StreamBidi failed: %v", err) - } - - if err := conn.SendText("hi"); err != nil { - t.Fatalf("SendText failed: %v", err) - } - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } - conn.Close() - - response, err := conn.Output() - if err != nil { - t.Fatalf("Output failed: %v", err) - } - - // Verify the override was stored in session state. - if response.State.InputVariables == nil { - t.Fatal("expected PromptInput in state") - } - - // The model echoes the last user message, which is "hi". - // But the prompt was rendered with "override" so "Hello override!" should appear - // in the messages sent to the model (verified via the echo). - // We primarily verify the state was set correctly. - inputMap, ok := response.State.InputVariables.(map[string]any) - if !ok { - t.Fatalf("expected PromptInput to be map[string]any, got %T", response.State.InputVariables) - } - if name, _ := inputMap["name"].(string); name != "override" { - t.Errorf("expected PromptInput name='override', got %q", name) - } -} - func TestPromptAgent_MultiTurnHistory(t *testing.T) { ctx := context.Background() reg := setupPromptTestRegistry(t) @@ -1100,9 +1037,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { ai.WithSystem("system prompt"), ) - af := DefineSessionFlowFromPrompt[testState, any]( - reg, "historyPrompt", nil, - ) + af := DefinePromptAgent[testState, any](reg, "historyPrompt", nil) conn, err := af.StreamBidi(ctx) if err != nil { @@ -1166,7 +1101,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { } } -func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { +func TestPromptAgent_SnapshotResumePreservesHistory(t *testing.T) { ctx := context.Background() reg := setupPromptTestRegistry(t) store := NewInMemorySessionStore[testState]() @@ -1176,15 +1111,11 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefineSessionFlowFromPrompt[testState, any]( - reg, "snapPrompt", nil, + af := DefinePromptAgent[testState, any](reg, "snapPrompt", nil, WithSessionStore(store), ) - // Start with prompt input. - conn, err := af.StreamBidi(ctx, - WithInputVariables[testState](map[string]any{"key": "value"}), - ) + conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } @@ -1204,21 +1135,10 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { if err != nil { t.Fatalf("Output failed: %v", err) } - if resp.SnapshotID == "" { t.Fatal("expected snapshot ID") } - // Verify the snapshot contains PromptInput. - snap, err := store.GetSnapshot(ctx, resp.SnapshotID) - if err != nil { - t.Fatalf("GetSnapshot failed: %v", err) - } - if snap.State.InputVariables == nil { - t.Error("expected InputVariables in snapshot state") - } - - // Resume from snapshot — the PromptInput should be preserved. conn2, err := af.StreamBidi(ctx, WithSnapshotID[testState](resp.SnapshotID)) if err != nil { t.Fatalf("StreamBidi with snapshot failed: %v", err) @@ -1240,7 +1160,6 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { t.Fatalf("Output failed: %v", err) } - // Verify state via snapshot (server-managed state). snap2, err := store.GetSnapshot(ctx, resp2.SnapshotID) if err != nil { t.Fatalf("GetSnapshot failed: %v", err) @@ -1248,9 +1167,6 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { if got := len(snap2.State.Messages); got != 4 { t.Errorf("expected 4 messages after resume, got %d", got) } - if snap2.State.InputVariables == nil { - t.Error("expected PromptInput preserved after resume") - } } func TestPromptAgent_ToolLoopMessages(t *testing.T) { @@ -1336,7 +1252,7 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { ai.WithTools(ai.ToolName("greet"), ai.ToolName("farewell")), ) - af := DefineSessionFlowFromPrompt[testState, any](reg, "toolPrompt", nil) + af := DefinePromptAgent[testState, any](reg, "toolPrompt", nil) conn, err := af.StreamBidi(ctx) if err != nil { @@ -1413,13 +1329,13 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { } } -func TestSessionFlow_RunText(t *testing.T) { +func TestAgent_RunText(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineSessionFlow(reg, "runTextFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "runTextFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("echo: " + input.Messages[0].Content[0].Text)) } @@ -1446,13 +1362,13 @@ func TestSessionFlow_RunText(t *testing.T) { } } -func TestSessionFlow_Run(t *testing.T) { +func TestAgent_Run(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineSessionFlow(reg, "runFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "runFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -1461,7 +1377,7 @@ func TestSessionFlow_Run(t *testing.T) { }, ) - input := &SessionFlowInput{ + input := &AgentInput{ Messages: []*ai.Message{ ai.NewUserTextMessage("msg1"), ai.NewUserTextMessage("msg2"), @@ -1479,13 +1395,13 @@ func TestSessionFlow_Run(t *testing.T) { } } -func TestSessionFlow_RunText_WithState(t *testing.T) { +func TestAgent_RunText_WithState(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineSessionFlow(reg, "runStateFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "runStateFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -1519,14 +1435,14 @@ func TestSessionFlow_RunText_WithState(t *testing.T) { } } -func TestSessionFlow_RunText_WithSnapshot(t *testing.T) { +func TestAgent_RunText_WithSnapshot(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineSessionFlow(reg, "runSnapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "runSnapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -1575,7 +1491,7 @@ func TestPromptAgent_RunText(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefineSessionFlowFromPrompt[testState, any](reg, "runTextPrompt", nil) + af := DefinePromptAgent[testState, any](reg, "runTextPrompt", nil) response, err := af.RunText(ctx, "hello") if err != nil { @@ -1591,14 +1507,14 @@ func TestPromptAgent_RunText(t *testing.T) { } } -func TestSessionFlow_SingleTurnSnapshotDedup(t *testing.T) { +func TestAgent_SingleTurnSnapshotDedup(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineSessionFlow(reg, "dedupFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "dedupFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -1635,14 +1551,14 @@ func TestSessionFlow_SingleTurnSnapshotDedup(t *testing.T) { } } -func TestSessionFlow_MultiTurnSnapshotDedup(t *testing.T) { +func TestAgent_MultiTurnSnapshotDedup(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineSessionFlow(reg, "multiDedupFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "multiDedupFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -1697,14 +1613,14 @@ func TestSessionFlow_MultiTurnSnapshotDedup(t *testing.T) { } } -func TestSessionFlow_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { +func TestAgent_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineSessionFlow(reg, "postRunMutateFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - if err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "postRunMutateFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) return nil }); err != nil { @@ -1748,6 +1664,128 @@ func TestSessionFlow_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) } } +// TestAgent_FnPanicReturnsError verifies that a panic inside the agent +// function is recovered and surfaced as an error, rather than crashing the +// process or hanging the streaming goroutine. +func TestAgent_FnPanicReturnsError(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "panicFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + resp.SendStatus(testStatus{Phase: "before-panic"}) + panic("boom") + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + if err := conn.SendText("trigger"); err != nil { + t.Fatalf("SendText: %v", err) + } + + done := make(chan error, 1) + go func() { + for chunk, err := range conn.Receive() { + _ = chunk + if err != nil { + done <- err + return + } + } + _, outErr := conn.Output() + done <- outErr + }() + + select { + case err := <-done: + if err == nil { + t.Fatal("expected error from panicking fn") + } + if !strings.Contains(err.Error(), "panicked") { + t.Errorf("expected panic error, got: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Receive/Output hung; streaming goroutine likely leaked") + } + + conn.Close() +} + +// TestAgent_CancelDuringStreamReleasesGoroutine verifies that cancelling the +// context mid-stream does not deadlock the streaming goroutine on outCh send. +func TestAgent_CancelDuringStreamReleasesGoroutine(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + reg := newTestRegistry(t) + + emitting := make(chan struct{}) + fnDone := make(chan struct{}) + af := DefineCustomAgent(reg, "cancelFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + defer close(fnDone) + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + close(emitting) + // Emit until ctx cancels. Without the goroutine's + // ctx-aware drain, this would deadlock once the consumer + // stops reading. + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + resp.SendStatus(testStatus{Phase: "tick"}) + } + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + if err := conn.SendText("go"); err != nil { + t.Fatalf("SendText: %v", err) + } + + <-emitting + cancel() + + select { + case <-fnDone: + case <-time.After(2 * time.Second): + t.Fatal("agent fn did not return after ctx cancel; goroutine deadlock") + } + conn.Close() +} + +// TestAgent_DefineAgent_StateMismatchPanics verifies that passing a typed +// AgentOption (e.g., a session store) with a different State than the +// declared one on DefineAgent panics with a clear message. +func TestAgent_DefineAgent_StateMismatchPanics(t *testing.T) { + type otherState struct{ X int } + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + defer func() { + r := recover() + if r == nil { + t.Fatal("expected panic on State mismatch") + } + msg := fmt.Sprintf("%v", r) + if !strings.Contains(msg, "does not match agent State") { + t.Errorf("panic message missing expected substring: %s", msg) + } + }() + + _ = DefineAgent[otherState](reg, "mismatch", WithSessionStore(store)) +} + // --- Detach, transform, and getSnapshot tests --- // waitForSnapshot polls the store for a snapshot matching the predicate, @@ -1775,15 +1813,15 @@ func waitForSnapshot[State any]( return nil } -func TestSessionFlow_TurnEnd_CarriesSnapshotID(t *testing.T) { +func TestAgent_TurnEnd_CarriesSnapshotID(t *testing.T) { // Sanity: each TurnEnd chunk carries the snapshot ID of the turn-end // snapshot, and the snapshots themselves are persisted. reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineSessionFlow(reg, "turnEndSnapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "turnEndSnapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("ok")) return nil }) @@ -1834,7 +1872,7 @@ func TestSessionFlow_TurnEnd_CarriesSnapshotID(t *testing.T) { } } -func TestSessionFlow_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { +func TestAgent_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { // Detach lands while turn 0 (input A) is mid-fn and an extra turn // (the detach input D itself) is waiting. The pending snapshot must: // - Be written with empty state and no parent (A was suspended, so @@ -1847,9 +1885,9 @@ func TestSessionFlow_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) entered := make(chan struct{}, 4) release := make(chan struct{}) - af := DefineSessionFlow(reg, "detachInFlight", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "detachInFlight", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { entered <- struct{}{} <-release sess.AddMessages(ai.NewModelTextMessage("reply-" + input.Messages[0].Text())) @@ -1937,7 +1975,7 @@ func TestSessionFlow_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) } } -func TestSessionFlow_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { +func TestAgent_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { // Run two normal turns first, then detach during a third (in-flight) // turn. The pending snapshot must chain off the second turn's snapshot. reg := newTestRegistry(t) @@ -1946,9 +1984,9 @@ func TestSessionFlow_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { enter := make(chan struct{}, 4) release := make(chan struct{}, 4) - af := DefineSessionFlow(reg, "detachChainParent", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "detachChainParent", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { enter <- struct{}{} <-release sess.AddMessages(ai.NewModelTextMessage("ok")) @@ -2021,12 +2059,12 @@ func TestSessionFlow_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { }) } -func TestSessionFlow_Detach_RequiresStore(t *testing.T) { +func TestAgent_Detach_RequiresStore(t *testing.T) { reg := newTestRegistry(t) - af := DefineSessionFlow(reg, "detachNoStore", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "detachNoStore", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { return nil }) }, @@ -2050,7 +2088,7 @@ func TestSessionFlow_Detach_RequiresStore(t *testing.T) { } } -func TestSessionFlow_Detach_PendingThenComplete(t *testing.T) { +func TestAgent_Detach_PendingThenComplete(t *testing.T) { // Client detaches mid-flow; flow finishes naturally; pending snapshot // flips to status=complete with the full session state. reg := newTestRegistry(t) @@ -2059,9 +2097,9 @@ func TestSessionFlow_Detach_PendingThenComplete(t *testing.T) { release := make(chan struct{}) entered := make(chan struct{}) - af := DefineSessionFlow(reg, "detachComplete", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "detachComplete", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { select { case entered <- struct{}{}: case <-ctx.Done(): @@ -2143,7 +2181,7 @@ func TestSessionFlow_Detach_PendingThenComplete(t *testing.T) { } } -func TestSessionFlow_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { +func TestAgent_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { // SendArtifact must behave the same way regardless of whether detach // has landed: the artifact is added to the session and shows up in // the finalized snapshot's state. The wire forward is the only thing @@ -2154,9 +2192,9 @@ func TestSessionFlow_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) detached := make(chan struct{}) release := make(chan struct{}) - af := DefineSessionFlow(reg, "detachArtifact", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "detachArtifact", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { resp.SendArtifact(&Artifact{ Name: "before.txt", Parts: []*ai.Part{ai.NewTextPart("pre-detach")}, @@ -2223,7 +2261,7 @@ func TestSessionFlow_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) } } -func TestSessionFlow_Detach_FlowErrorsBecomesError(t *testing.T) { +func TestAgent_Detach_FlowErrorsBecomesError(t *testing.T) { reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() @@ -2231,9 +2269,9 @@ func TestSessionFlow_Detach_FlowErrorsBecomesError(t *testing.T) { entered := make(chan struct{}) boom := errors.New("kaboom") - af := DefineSessionFlow(reg, "detachErr", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "detachErr", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { select { case entered <- struct{}{}: case <-time.After(time.Second): @@ -2292,7 +2330,7 @@ func TestSessionFlow_Detach_FlowErrorsBecomesError(t *testing.T) { } } -func TestSessionFlow_Detach_AbortSnapshotStopsFlow(t *testing.T) { +func TestAgent_Detach_AbortSnapshotStopsFlow(t *testing.T) { // Client detaches, then calls AbortSnapshot. The store's status // subscriber notifies the runtime, which cancels the work context, and // the finalizer rewrites the snapshot with status=canceled. @@ -2301,9 +2339,9 @@ func TestSessionFlow_Detach_AbortSnapshotStopsFlow(t *testing.T) { entered := make(chan struct{}) - af := DefineSessionFlow(reg, "detachAbort", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "detachAbort", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { select { case entered <- struct{}{}: case <-time.After(time.Second): @@ -2364,15 +2402,15 @@ func TestSessionFlow_Detach_AbortSnapshotStopsFlow(t *testing.T) { } } -func TestSessionFlow_Detach_NormalCompletionStillEmitsTurnEnd(t *testing.T) { +func TestAgent_Detach_NormalCompletionStillEmitsTurnEnd(t *testing.T) { // Sanity: a non-detached invocation against a store-backed flow still // behaves like a synchronous flow (turn-end snapshots, no pending row). reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineSessionFlow(reg, "syncStillWorks", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "syncStillWorks", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("ok")) return nil }) @@ -2418,7 +2456,7 @@ func TestSessionFlow_Detach_NormalCompletionStillEmitsTurnEnd(t *testing.T) { } } -func TestSessionFlow_Detach_ClientDisconnectBeforeDetachCancels(t *testing.T) { +func TestAgent_Detach_ClientDisconnectBeforeDetachCancels(t *testing.T) { // Without detach, a client cancel still cancels the work — this is // the regression guard for "until detach=true is called, this is a // normal HTTP/WS connection that cancels on close." @@ -2428,9 +2466,9 @@ func TestSessionFlow_Detach_ClientDisconnectBeforeDetachCancels(t *testing.T) { entered := make(chan struct{}) exited := make(chan error, 1) - af := DefineSessionFlow(reg, "syncCancel", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "syncCancel", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { select { case entered <- struct{}{}: case <-ctx.Done(): @@ -2473,7 +2511,7 @@ func TestSessionFlow_Detach_ClientDisconnectBeforeDetachCancels(t *testing.T) { } } -func TestSessionFlow_ResumeFromErrorSnapshot_Rejected(t *testing.T) { +func TestAgent_ResumeFromErrorSnapshot_Rejected(t *testing.T) { reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() @@ -2490,8 +2528,8 @@ func TestSessionFlow_ResumeFromErrorSnapshot_Rejected(t *testing.T) { t.Fatalf("SaveSnapshot: %v", err) } - af := DefineSessionFlow(reg, "resumeErrored", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + af := DefineCustomAgent(reg, "resumeErrored", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { return nil, nil }, WithSessionStore(store), @@ -2506,7 +2544,7 @@ func TestSessionFlow_ResumeFromErrorSnapshot_Rejected(t *testing.T) { } } -func TestSessionFlow_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { +func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() @@ -2522,9 +2560,9 @@ func TestSessionFlow_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { return s } - af := DefineSessionFlow(reg, "transformedFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "transformedFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("the secret is out")) return nil }) @@ -2599,11 +2637,11 @@ func TestInMemorySessionStore_GetSnapshot_NotFound(t *testing.T) { } } -func TestSessionFlow_GetSnapshotAction_NoStore(t *testing.T) { +func TestAgent_GetSnapshotAction_NoStore(t *testing.T) { reg := newTestRegistry(t) - DefineSessionFlow(reg, "noStoreFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + DefineCustomAgent(reg, "noStoreFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { return nil, nil }, ) @@ -2639,11 +2677,11 @@ func (minimalStore[State]) SaveSnapshot( return nil, nil } -func TestSessionFlow_AgentMetadata(t *testing.T) { +func TestAgent_AgentMetadata(t *testing.T) { // Verify the metadata["agent"] payload on the flow's action descriptor // correctly reports stateManagement and abortable for each combination // of store capabilities. - noopFn := func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + noopFn := func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { return nil, nil } @@ -2656,7 +2694,7 @@ func TestSessionFlow_AgentMetadata(t *testing.T) { { name: "no store → client-managed, not abortable", define: func(reg api.Registry, flowName string) { - DefineSessionFlow(reg, flowName, noopFn) + DefineCustomAgent(reg, flowName, noopFn) }, wantMgmt: AgentMetadataStateManagementClient, wantAbortab: false, @@ -2664,7 +2702,7 @@ func TestSessionFlow_AgentMetadata(t *testing.T) { { name: "store missing abort capabilities → server-managed, not abortable", define: func(reg api.Registry, flowName string) { - DefineSessionFlow(reg, flowName, noopFn, + DefineCustomAgent(reg, flowName, noopFn, WithSessionStore[testState](minimalStore[testState]{})) }, wantMgmt: AgentMetadataStateManagementServer, @@ -2673,7 +2711,7 @@ func TestSessionFlow_AgentMetadata(t *testing.T) { { name: "store with full capabilities → server-managed, abortable", define: func(reg api.Registry, flowName string) { - DefineSessionFlow(reg, flowName, noopFn, + DefineCustomAgent(reg, flowName, noopFn, WithSessionStore(NewInMemorySessionStore[testState]())) }, wantMgmt: AgentMetadataStateManagementServer, @@ -2687,7 +2725,7 @@ func TestSessionFlow_AgentMetadata(t *testing.T) { flowName := "metaFlow" tc.define(reg, flowName) - act := core.ResolveActionFor[*SessionFlowInit[testState], *SessionFlowOutput[testState], *SessionFlowStreamChunk[testStatus], *SessionFlowInput]( + act := core.ResolveActionFor[*AgentInit[testState], *AgentOutput[testState], *AgentStreamChunk[testStatus], *AgentInput]( reg, api.ActionTypeFlow, flowName) if act == nil { t.Fatal("flow action not registered") @@ -2711,15 +2749,15 @@ func TestSessionFlow_AgentMetadata(t *testing.T) { } } -func TestSessionFlow_AbortAction_GatedOnCapabilities(t *testing.T) { +func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { // Verify the abort companion action is only registered when the // store implements SnapshotAborter. The getSnapshot action is // registered regardless. t.Run("aborter capability → both registered", func(t *testing.T) { reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() // implements SnapshotAborter - DefineSessionFlow(reg, "fullCaps", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + DefineCustomAgent(reg, "fullCaps", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { return nil, nil }, WithSessionStore(store), @@ -2738,8 +2776,8 @@ func TestSessionFlow_AbortAction_GatedOnCapabilities(t *testing.T) { t.Run("no aborter capability → abort not registered", func(t *testing.T) { reg := newTestRegistry(t) - DefineSessionFlow(reg, "minCaps", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { + DefineCustomAgent(reg, "minCaps", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { return nil, nil }, WithSessionStore[testState](minimalStore[testState]{}), @@ -2757,19 +2795,19 @@ func TestSessionFlow_AbortAction_GatedOnCapabilities(t *testing.T) { }) } -func TestSessionFlow_StateTransform_ClientManagedState(t *testing.T) { +func TestAgent_StateTransform_ClientManagedState(t *testing.T) { reg := newTestRegistry(t) - // Client-managed state: transform should be applied to SessionFlowOutput.State. + // Client-managed state: transform should be applied to AgentOutput.State. transform := func(_ context.Context, s SessionState[testState]) SessionState[testState] { // Zero out the counter to demonstrate the transform is applied. s.Custom.Counter = -1 return s } - af := DefineSessionFlow(reg, "clientXformFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "clientXformFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.UpdateCustom(func(s testState) testState { s.Counter = 7 return s @@ -2792,15 +2830,15 @@ func TestSessionFlow_StateTransform_ClientManagedState(t *testing.T) { } } -func TestSessionFlow_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { +func TestAgent_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { // End-to-end: run a flow that the client detaches from, let it // finalize, then resume from its snapshot as if reconnecting later. reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineSessionFlow(reg, "resumeDetachedFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "resumeDetachedFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -2923,7 +2961,7 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { } } -func TestSessionFlow_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { +func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { // An abort that lands while fn is still running but does not actually // stop fn (because fn does not observe ctx) must still result in // status=canceled — the finalizer must not clobber canceled with @@ -2935,9 +2973,9 @@ func TestSessionFlow_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { fnRelease := make(chan struct{}) entered := make(chan struct{}) - af := DefineSessionFlow(reg, "raceFinalize", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "raceFinalize", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { select { case entered <- struct{}{}: case <-time.After(time.Second): @@ -3061,15 +3099,15 @@ func TestInMemorySessionStore_OnSnapshotStatusChange(t *testing.T) { t.Fatal("channel did not close after subscription ctx cancel") } -func TestSessionFlow_AbortSnapshot_NoOpOnTerminal(t *testing.T) { +func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { // Calling AbortSnapshot on an already-terminal snapshot is a no-op // that returns the existing status. reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() - af := DefineSessionFlow(reg, "abortNoop", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + af := DefineCustomAgent(reg, "abortNoop", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) return nil }) diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index 1b28cdcd74..830d62c3d1 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -22,11 +22,11 @@ import ( "github.com/firebase/genkit/go/ai" ) -// AgentMetadata is the value placed under metadata["agent"] on a session -// flow's action descriptor. It exposes capability information so the Dev -// UI and other reflective callers can render the right surface (e.g. -// hide the Abort button when the configured store doesn't support it) -// without round-tripping through the reflection API. +// AgentMetadata is the value placed under metadata["agent"] on an agent's +// action descriptor. It exposes capability information so the Dev UI and +// other reflective callers can render the right surface (e.g. hide the +// Abort button when the configured store doesn't support it) without +// round-tripping through the reflection API. type AgentMetadata struct { // Abortable reports whether the agent's invocations can be aborted // (true when the store implements [SnapshotAborter]). @@ -62,9 +62,9 @@ type Artifact struct { Parts []*ai.Part `json:"parts"` } -// SessionFlowInit is the input for starting an session flow invocation. +// AgentInit is the input for starting an agent invocation. // Provide either SnapshotID (to load from store) or State (direct state). -type SessionFlowInit[State any] struct { +type AgentInit[State any] struct { // SnapshotID loads state from a persisted snapshot. // Mutually exclusive with State. SnapshotID string `json:"snapshotId,omitempty"` @@ -73,15 +73,15 @@ type SessionFlowInit[State any] struct { State *SessionState[State] `json:"state,omitempty"` } -// SessionFlowInput is the input sent to an session flow during a conversation turn. -type SessionFlowInput struct { +// AgentInput is the input sent to an agent during a conversation turn. +type AgentInput struct { // Detach signals that the client wishes to disconnect after this input is // accepted. The server writes a single pending snapshot (with empty - // state), returns [SessionFlowOutput] with that snapshot ID, and - // continues processing any already-buffered inputs in a background - // context. The pending snapshot is finalized with the cumulative final - // state once all queued inputs are processed (or the snapshot is - // cancelled via cancelSnapshot). + // state), returns [AgentOutput] with that snapshot ID, and continues + // processing any already-buffered inputs in a background context. The + // pending snapshot is finalized with the cumulative final state once all + // queued inputs are processed (or the snapshot is cancelled via + // cancelSnapshot). Detach bool `json:"detach,omitempty"` // Messages contains the user's input for this turn. Messages []*ai.Message `json:"messages,omitempty"` @@ -92,9 +92,9 @@ type SessionFlowInput struct { ToolRestarts []*ai.Part `json:"toolRestarts,omitempty"` } -// SessionFlowOutput is the output when an session flow invocation completes. -// It wraps SessionFlowResult with framework-managed fields. -type SessionFlowOutput[State any] struct { +// AgentOutput is the output when an agent invocation completes. +// It wraps AgentResult with framework-managed fields. +type AgentOutput[State any] struct { // Artifacts contains artifacts produced during the session. Artifacts []*Artifact `json:"artifacts,omitempty"` // Message is the last model response message from the conversation. @@ -107,18 +107,18 @@ type SessionFlowOutput[State any] struct { State *SessionState[State] `json:"state,omitempty"` } -// SessionFlowResult is the return value from an SessionFlowFunc. +// AgentResult is the return value from an AgentFunc. // It contains the user-specified outputs of the agent invocation. -type SessionFlowResult struct { +type AgentResult struct { // Artifacts contains artifacts produced during the session. Artifacts []*Artifact `json:"artifacts,omitempty"` // Message is the last model response message from the conversation. Message *ai.Message `json:"message,omitempty"` } -// SessionFlowStreamChunk represents a single item in the session flow's output stream. +// AgentStreamChunk represents a single item in the agent's output stream. // Multiple fields can be populated in a single chunk. -type SessionFlowStreamChunk[Stream any] struct { +type AgentStreamChunk[Stream any] struct { // Artifact contains a newly produced artifact. Artifact *Artifact `json:"artifact,omitempty"` // ModelChunk contains generation tokens from the model. @@ -126,7 +126,7 @@ type SessionFlowStreamChunk[Stream any] struct { // Status contains user-defined structured status information. // The Stream type parameter defines the shape of this data. Status Stream `json:"status,omitempty"` - // TurnEnd is non-nil when the session flow has finished processing the current + // TurnEnd is non-nil when the agent has finished processing the current // input. It groups all turn-end signals (snapshot ID, etc.) so callers can // check a single field. When set, the client should stop iterating and may // send the next input. @@ -140,11 +140,8 @@ type SessionState[State any] struct { Artifacts []*Artifact `json:"artifacts,omitempty"` // Custom is the user-defined state associated with this conversation. Custom State `json:"custom,omitempty"` - // InputVariables is the input used for session flows that require input variables - // (e.g. prompt-backed session flows). - InputVariables any `json:"inputVariables,omitempty"` // Messages is the conversation history (user/model exchanges). - // Does NOT include prompt-rendered messages — those are rendered fresh each turn. + // Does NOT include prompt-rendered messages, those are rendered fresh each turn. Messages []*ai.Message `json:"messages,omitempty"` } @@ -163,7 +160,7 @@ const ( SnapshotEventDetach SnapshotEvent = "detach" ) -// TurnEnd groups the signals emitted when a session flow turn finishes. +// TurnEnd groups the signals emitted when an agent turn finishes. // A TurnEnd value is emitted exactly once per turn, regardless of whether a // snapshot was persisted. type TurnEnd struct { diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index 630d58edaa..69c9fc7e65 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -21,18 +21,36 @@ import ( "errors" ) -// --- SessionFlowOption --- +// --- AgentDefineOption --- -// SessionFlowOption configures an SessionFlow. -type SessionFlowOption[State any] interface { - applySessionFlow(*sessionFlowOptions[State]) error +// AgentDefineOption is the marker interface for any option that can be passed +// to [DefineAgent]. It is satisfied by every [github.com/firebase/genkit/go/ai.PromptOption] +// (which configures the underlying prompt) and by every [AgentOption] (which +// configures the agent itself). +// +// The State type parameter is phantom: a single concrete option satisfies +// [AgentDefineOption] for any State, so type inference cannot pick a State +// from the variadic. Callers of [DefineAgent] must specify [State] explicitly +// (use [any] when no typed Custom state is needed). +type AgentDefineOption[State any] interface { + isAgentDefineOption() +} + +// --- AgentOption --- + +// AgentOption configures an agent at definition time. It also satisfies +// [AgentDefineOption] so it can be passed to [DefineAgent] alongside +// [github.com/firebase/genkit/go/ai.PromptOption] values. +type AgentOption[State any] interface { + AgentDefineOption[State] + applyAgent(*agentOptions[State]) error } // StateTransform rewrites session state on its way out to a client. It // is applied to the State returned by the getSnapshot companion action -// and to [SessionFlowOutput.State] when state is client-managed (no -// store). It is not applied to state persisted in the store or to -// state passed to the user flow function. +// and to [AgentResult.State] when state is client-managed (no store). +// It is not applied to state persisted in the store or to state passed +// to the user agent function. // // ctx is the request or invocation context: cancellation, deadlines, // and context-scoped values (e.g. the caller's identity for RBAC-aware @@ -42,13 +60,15 @@ type SessionFlowOption[State any] interface { // may mutate and return it, or return a freshly-constructed value. type StateTransform[State any] = func(ctx context.Context, state SessionState[State]) SessionState[State] -type sessionFlowOptions[State any] struct { +type agentOptions[State any] struct { store SessionStore[State] callback SnapshotCallback[State] transform StateTransform[State] } -func (o *sessionFlowOptions[State]) applySessionFlow(opts *sessionFlowOptions[State]) error { +func (*agentOptions[State]) isAgentDefineOption() {} + +func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { if o.store != nil { if opts.store != nil { return errors.New("cannot set session store more than once (WithSessionStore)") @@ -74,20 +94,20 @@ func (o *sessionFlowOptions[State]) applySessionFlow(opts *sessionFlowOptions[St // implement [SnapshotReader] and [SnapshotWriter] at minimum. Detach // support also requires [SnapshotAborter]; detach attempts on a store // that lacks that interface are rejected at runtime. -func WithSessionStore[State any](store SessionStore[State]) SessionFlowOption[State] { - return &sessionFlowOptions[State]{store: store} +func WithSessionStore[State any](store SessionStore[State]) AgentOption[State] { + return &agentOptions[State]{store: store} } // WithSnapshotCallback configures when snapshots are created. // If not provided and a store is configured, snapshots are always created. -func WithSnapshotCallback[State any](cb SnapshotCallback[State]) SessionFlowOption[State] { - return &sessionFlowOptions[State]{callback: cb} +func WithSnapshotCallback[State any](cb SnapshotCallback[State]) AgentOption[State] { + return &agentOptions[State]{callback: cb} } // WithSnapshotOn configures snapshots to be created only for the specified events. // For example, WithSnapshotOn[MyState](SnapshotEventTurnEnd) skips the // invocation-end snapshot. -func WithSnapshotOn[State any](events ...SnapshotEvent) SessionFlowOption[State] { +func WithSnapshotOn[State any](events ...SnapshotEvent) AgentOption[State] { set := make(map[SnapshotEvent]struct{}, len(events)) for _, e := range events { set[e] = struct{}{} @@ -100,25 +120,23 @@ func WithSnapshotOn[State any](events ...SnapshotEvent) SessionFlowOption[State] // WithStateTransform registers a transform applied to session state on // its way out to a client via the getSnapshot companion action or via -// [SessionFlowOutput.State] when state is client-managed. Typical use -// is PII redaction or stripping secrets. The transform is not applied -// to state persisted in the store or to state passed to the user flow -// function. -func WithStateTransform[State any](transform StateTransform[State]) SessionFlowOption[State] { - return &sessionFlowOptions[State]{transform: transform} +// [AgentResult.State] when state is client-managed. Typical use is PII +// redaction or stripping secrets. The transform is not applied to state +// persisted in the store or to state passed to the user agent function. +func WithStateTransform[State any](transform StateTransform[State]) AgentOption[State] { + return &agentOptions[State]{transform: transform} } // --- InvocationOption --- -// InvocationOption configures an session flow invocation (StreamBidi, Run, or RunText). +// InvocationOption configures an agent invocation (StreamBidi, Run, or RunText). type InvocationOption[State any] interface { applyInvocation(*invocationOptions[State]) error } type invocationOptions[State any] struct { - state *SessionState[State] - snapshotID string - promptInput any + state *SessionState[State] + snapshotID string } func (o *invocationOptions[State]) applyInvocation(opts *invocationOptions[State]) error { @@ -140,12 +158,6 @@ func (o *invocationOptions[State]) applyInvocation(opts *invocationOptions[State } opts.snapshotID = o.snapshotID } - if o.promptInput != nil { - if opts.promptInput != nil { - return errors.New("cannot set prompt input more than once (WithPromptInput)") - } - opts.promptInput = o.promptInput - } return nil } @@ -160,9 +172,3 @@ func WithState[State any](state *SessionState[State]) InvocationOption[State] { func WithSnapshotID[State any](id string) InvocationOption[State] { return &invocationOptions[State]{snapshotID: id} } - -// WithInputVariables overrides the default input variables for a prompt-backed session flow. -// Used with DefineSessionFlowFromPrompt to customize the input variables per invocation. -func WithInputVariables[State any](input any) InvocationOption[State] { - return &invocationOptions[State]{promptInput: input} -} diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 4b57282efa..0fecacec0c 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -37,11 +37,11 @@ import ( // written for synchronous turns or invocations are always [SnapshotStatusComplete] // (an empty value is also treated as complete for backwards compatibility). // -// When a client sets [SessionFlowInput.Detach], the server writes a single +// When a client sets [AgentInput.Detach], the server writes a single // snapshot with [SnapshotStatusPending] (and empty state) and returns its // ID immediately. Background processing then either rewrites that snapshot // with the cumulative final state and [SnapshotStatusComplete] / -// [SnapshotStatusError] when the flow finishes, or with +// [SnapshotStatusError] when the agent finishes, or with // [SnapshotStatusCanceled] if the client called abortSnapshot in the // meantime. type SnapshotStatus string @@ -182,7 +182,7 @@ type SnapshotWriter[State any] interface { } // SnapshotAborter is the optional capability layered on [SessionStore] -// that lets a session flow's invocations be aborted. It bundles the two +// that lets an agent's invocations be aborted. It bundles the two // methods that must be implemented together for the abort lifecycle to // function: // @@ -190,9 +190,9 @@ type SnapshotWriter[State any] interface { // to canceled (typically called by the abortSnapshot companion // action or directly by a Go caller holding the store). // -// - [SnapshotAborter.OnSnapshotStatusChange] lets the session flow -// runtime observe the flip without polling, so it can promptly -// cancel the work context. +// - [SnapshotAborter.OnSnapshotStatusChange] lets the agent runtime +// observe the flip without polling, so it can promptly cancel the +// work context. // // They are bundled because neither is useful alone: flipping status // with no observer means the running fn never learns it was aborted; @@ -207,7 +207,7 @@ type SnapshotAborter interface { // nil if the snapshot is not found. // // Implementations must perform the read-and-write atomically (e.g., a - // transaction or a compare-and-swap). The session flow's abortSnapshot + // transaction or a compare-and-swap). The agent's abortSnapshot // action and finalizer rely on this to avoid a pending row being // clobbered by a racing terminal write. AbortSnapshot(ctx context.Context, snapshotID string) (*SnapshotMetadata, error) @@ -427,9 +427,10 @@ func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[Sta // --- Snapshot companion actions --- -// GetSnapshotRequest is the input for a session flow's getSnapshot companion -// action. The action is registered at `{flowName}/getSnapshot` when the flow -// is defined and is intended for Dev UI and client-side reconnect flows. +// GetSnapshotRequest is the input for an agent's getSnapshot companion +// action. The action is registered at `{agentName}/getSnapshot` when the +// agent is defined and is intended for Dev UI and client-side reconnect +// flows. type GetSnapshotRequest struct { // SnapshotID identifies the snapshot to fetch. SnapshotID string `json:"snapshotId"` @@ -476,30 +477,30 @@ type AbortSnapshotResponse struct { Status SnapshotStatus `json:"status,omitempty"` } -// registerSnapshotActions registers the session flow's companion actions: +// registerSnapshotActions registers the agent's companion actions: // -// - The flow's name under [api.ActionTypeAgentSnapshot] — getSnapshot, +// - The agent's name under [api.ActionTypeAgentSnapshot] — getSnapshot, // registered whenever a [SessionStore] is configured. The action is // the remote counterpart to [SessionStore.GetSnapshot] for Dev UI and // non-Go clients; local Go callers use the store reference directly. // -// - The flow's name under [api.ActionTypeAgentAbort] — abortSnapshot, -// registered only when the store implements [SnapshotAborter] -// (which bundles both the abort trigger and the status-change -// subscription needed for the runtime to react). Surfacing the -// action only when the capability is present keeps the reflected -// API aligned with what the store can actually do. +// - The agent's name under [api.ActionTypeAgentAbort] — abortSnapshot, +// registered only when the store implements [SnapshotAborter] (which +// bundles both the abort trigger and the status-change subscription +// needed for the runtime to react). Surfacing the action only when +// the capability is present keeps the reflected API aligned with +// what the store can actually do. func registerSnapshotActions[State any]( r api.Registry, - flowName string, + agentName string, store SessionStore[State], transform StateTransform[State], ) { - core.DefineAction(r, flowName, api.ActionTypeAgentSnapshot, nil, nil, + core.DefineAction(r, agentName, api.ActionTypeAgentSnapshot, nil, nil, func(ctx context.Context, req *GetSnapshotRequest) (*GetSnapshotResponse[State], error) { if store == nil { return nil, core.NewError(core.FAILED_PRECONDITION, - "getSnapshot: session flow %q has no session store configured", flowName) + "getSnapshot: agent %q has no session store configured", agentName) } if req == nil || req.SnapshotID == "" { return nil, core.NewError(core.INVALID_ARGUMENT, "getSnapshot: snapshotId is required") @@ -540,7 +541,7 @@ func registerSnapshotActions[State any]( // action. return } - core.DefineAction(r, flowName, api.ActionTypeAgentAbort, nil, nil, + core.DefineAction(r, agentName, api.ActionTypeAgentAbort, nil, nil, func(ctx context.Context, req *AbortSnapshotRequest) (*AbortSnapshotResponse, error) { if req == nil || req.SnapshotID == "" { return nil, core.NewError(core.INVALID_ARGUMENT, "abortSnapshot: snapshotId is required") @@ -625,13 +626,6 @@ func (s *Session[State]) UpdateCustom(fn func(State) State) { s.version++ } -// InputVariables returns the prompt input stored in the session state. -func (s *Session[State]) InputVariables() any { - s.mu.RLock() - defer s.mu.RUnlock() - return s.state.InputVariables -} - // Artifacts returns the current artifacts. func (s *Session[State]) Artifacts() []*Artifact { s.mu.RLock() @@ -677,11 +671,11 @@ func (s *Session[State]) UpdateArtifacts(fn func([]*Artifact) []*Artifact) { func (s *Session[State]) copyStateLocked() SessionState[State] { bytes, err := json.Marshal(s.state) if err != nil { - panic(fmt.Sprintf("session flow: failed to marshal state: %v", err)) + panic(fmt.Sprintf("agent: failed to marshal state: %v", err)) } var copied SessionState[State] if err := json.Unmarshal(bytes, &copied); err != nil { - panic(fmt.Sprintf("session flow: failed to unmarshal state: %v", err)) + panic(fmt.Sprintf("agent: failed to unmarshal state: %v", err)) } return copied } diff --git a/go/ai/option.go b/go/ai/option.go index d51ca85a4a..fce215d45a 100644 --- a/go/ai/option.go +++ b/go/ai/option.go @@ -31,6 +31,16 @@ type PromptFn = func(context.Context, any) (string, error) // MessagesFn is a function that generates messages. type MessagesFn = func(context.Context, any) ([]*Message, error) +// No-op marker so the exp package's AgentDefineOption marker interface is +// satisfied by every PromptOption. Lets DefineAgent accept a mixed variadic +// of prompt options and agent-only options. +func (*configOptions) isAgentDefineOption() {} +func (*commonGenOptions) isAgentDefineOption() {} +func (*inputOptions) isAgentDefineOption() {} +func (*promptOptions) isAgentDefineOption() {} +func (*promptingOptions) isAgentDefineOption() {} +func (*outputOptions) isAgentDefineOption() {} + // configOptions holds configuration options. type configOptions struct { Config any // Primitive (model, embedder, retriever, etc) configuration. diff --git a/go/core/action.go b/go/core/action.go index f2ce0cc877..11b8336300 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -541,6 +541,11 @@ type BidiConnection[StreamIn, StreamOut, Out any] struct { // Send sends an input message to the bidi action. // Returns an error if the connection is closed or the context is cancelled. func (c *BidiConnection[StreamIn, StreamOut, Out]) Send(input StreamIn) (err error) { + // Recover from "send on closed channel" panic. A check-then-send under the + // mutex would race with Close, and holding the mutex across the send would + // deadlock against Close when the buffer is full. Closing inputCh (rather + // than a separate signal channel) is required so receivers can use the + // canonical `for ... range inputCh` idiom. defer func() { if r := recover(); r != nil { err = NewError(FAILED_PRECONDITION, "connection is closed") diff --git a/go/core/api/action.go b/go/core/api/action.go index 14020d3a5b..a258b9a09f 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -66,7 +66,7 @@ const ( ActionTypeAgentAbort ActionType = "agent-abort" ActionTypeCheckOperation ActionType = "check-operation" ActionTypeCancelOperation ActionType = "cancel-operation" - ActionTypeSessionFlow ActionType = "session-flow" + ActionTypeAgent ActionType = "agent" ) // ActionDesc is a descriptor of an action. diff --git a/go/core/schemas.config b/go/core/schemas.config index 49098ded15..be8e797a53 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1196,32 +1196,32 @@ Metadata contains additional artifact-specific data. . # ---------------------------------------------------------------------------- -# SessionFlowInput +# AgentInput # ---------------------------------------------------------------------------- -SessionFlowInput pkg ai/exp +AgentInput pkg ai/exp -SessionFlowInput doc -SessionFlowInput is the input sent to an session flow during a conversation turn. +AgentInput doc +AgentInput is the input sent to an agent during a conversation turn. . -SessionFlowInput.detach doc +AgentInput.detach doc Detach signals that the client wishes to disconnect after this input is accepted. The server writes a single pending snapshot (with empty -state), returns [SessionFlowOutput] with that snapshot ID, and -continues processing any already-buffered inputs in a background -context. The pending snapshot is finalized with the cumulative final -state once all queued inputs are processed (or the snapshot is -cancelled via cancelSnapshot). +state), returns [AgentOutput] with that snapshot ID, and continues +processing any already-buffered inputs in a background context. The +pending snapshot is finalized with the cumulative final state once all +queued inputs are processed (or the snapshot is cancelled via +cancelSnapshot). . -SessionFlowInput.messages type []*ai.Message -SessionFlowInput.messages doc +AgentInput.messages type []*ai.Message +AgentInput.messages doc Messages contains the user's input for this turn. . -SessionFlowInput.toolRestarts type []*ai.Part -SessionFlowInput.toolRestarts doc +AgentInput.toolRestarts type []*ai.Part +AgentInput.toolRestarts doc ToolRestarts contains tool request parts to re-execute interrupted tools. Use [ai.ToolDef.RestartWith] to create these parts from an interrupted tool request. When set, the generate call resumes with these restarts @@ -1229,110 +1229,110 @@ instead of treating Messages as tool responses. . # ---------------------------------------------------------------------------- -# SessionFlowInit +# AgentInit # ---------------------------------------------------------------------------- -SessionFlowInit pkg ai/exp -SessionFlowInit typeparams [State any] +AgentInit pkg ai/exp +AgentInit typeparams [State any] -SessionFlowInit doc -SessionFlowInit is the input for starting an session flow invocation. +AgentInit doc +AgentInit is the input for starting an agent invocation. Provide either SnapshotID (to load from store) or State (direct state). . -SessionFlowInit.snapshotId doc +AgentInit.snapshotId doc SnapshotID loads state from a persisted snapshot. Mutually exclusive with State. . -SessionFlowInit.state type *SessionState[State] -SessionFlowInit.state doc +AgentInit.state type *SessionState[State] +AgentInit.state doc State provides direct state for the invocation. Mutually exclusive with SnapshotID. . # ---------------------------------------------------------------------------- -# SessionFlowResult +# AgentResult # ---------------------------------------------------------------------------- -SessionFlowResult pkg ai/exp +AgentResult pkg ai/exp -SessionFlowResult doc -SessionFlowResult is the return value from an SessionFlowFunc. +AgentResult doc +AgentResult is the return value from an AgentFunc. It contains the user-specified outputs of the agent invocation. . -SessionFlowResult.message type *ai.Message -SessionFlowResult.message doc +AgentResult.message type *ai.Message +AgentResult.message doc Message is the last model response message from the conversation. . -SessionFlowResult.artifacts doc +AgentResult.artifacts doc Artifacts contains artifacts produced during the session. . # ---------------------------------------------------------------------------- -# SessionFlowOutput +# AgentOutput # ---------------------------------------------------------------------------- -SessionFlowOutput pkg ai/exp -SessionFlowOutput typeparams [State any] +AgentOutput pkg ai/exp +AgentOutput typeparams [State any] -SessionFlowOutput doc -SessionFlowOutput is the output when an session flow invocation completes. -It wraps SessionFlowResult with framework-managed fields. +AgentOutput doc +AgentOutput is the output when an agent invocation completes. +It wraps AgentResult with framework-managed fields. . -SessionFlowOutput.snapshotId doc +AgentOutput.snapshotId doc SnapshotID is the ID of the snapshot created at the end of this invocation. Empty if no snapshot was created (callback returned false or no store configured). . -SessionFlowOutput.state type *SessionState[State] -SessionFlowOutput.state doc +AgentOutput.state type *SessionState[State] +AgentOutput.state doc State contains the final conversation state. Only populated when state is client-managed (no store configured). . -SessionFlowOutput.message type *ai.Message -SessionFlowOutput.message doc +AgentOutput.message type *ai.Message +AgentOutput.message doc Message is the last model response message from the conversation. . -SessionFlowOutput.artifacts doc +AgentOutput.artifacts doc Artifacts contains artifacts produced during the session. . # ---------------------------------------------------------------------------- -# SessionFlowStreamChunk +# AgentStreamChunk # ---------------------------------------------------------------------------- -SessionFlowStreamChunk pkg ai/exp -SessionFlowStreamChunk typeparams [Stream any] +AgentStreamChunk pkg ai/exp +AgentStreamChunk typeparams [Stream any] -SessionFlowStreamChunk doc -SessionFlowStreamChunk represents a single item in the session flow's output stream. +AgentStreamChunk doc +AgentStreamChunk represents a single item in the agent's output stream. Multiple fields can be populated in a single chunk. . -SessionFlowStreamChunk.modelChunk type *ai.ModelResponseChunk -SessionFlowStreamChunk.modelChunk doc +AgentStreamChunk.modelChunk type *ai.ModelResponseChunk +AgentStreamChunk.modelChunk doc ModelChunk contains generation tokens from the model. . -SessionFlowStreamChunk.status type Stream -SessionFlowStreamChunk.status doc +AgentStreamChunk.status type Stream +AgentStreamChunk.status doc Status contains user-defined structured status information. The Stream type parameter defines the shape of this data. . -SessionFlowStreamChunk.artifact doc +AgentStreamChunk.artifact doc Artifact contains a newly produced artifact. . -SessionFlowStreamChunk.turnEnd type *TurnEnd -SessionFlowStreamChunk.turnEnd doc -TurnEnd is non-nil when the session flow has finished processing the current +AgentStreamChunk.turnEnd type *TurnEnd +AgentStreamChunk.turnEnd doc +TurnEnd is non-nil when the agent has finished processing the current input. It groups all turn-end signals (snapshot ID, etc.) so callers can check a single field. When set, the client should stop iterating and may send the next input. @@ -1345,7 +1345,7 @@ send the next input. TurnEnd pkg ai/exp TurnEnd doc -TurnEnd groups the signals emitted when a session flow turn finishes. +TurnEnd groups the signals emitted when an agent turn finishes. A TurnEnd value is emitted exactly once per turn, regardless of whether a snapshot was persisted. . @@ -1383,11 +1383,6 @@ SessionState.artifacts doc Artifacts are named collections of parts produced during the conversation. . -SessionState.inputVariables doc -InputVariables is the input used for session flows that require input variables -(e.g. prompt-backed session flows). -. - # ---------------------------------------------------------------------------- # SnapshotEvent # ---------------------------------------------------------------------------- @@ -1436,11 +1431,11 @@ AbortSnapshotResponse omit AgentMetadata pkg ai/exp AgentMetadata doc -AgentMetadata is the value placed under metadata["agent"] on a session -flow's action descriptor. It exposes capability information so the Dev -UI and other reflective callers can render the right surface (e.g. -hide the Abort button when the configured store doesn't support it) -without round-tripping through the reflection API. +AgentMetadata is the value placed under metadata["agent"] on an agent's +action descriptor. It exposes capability information so the Dev UI and +other reflective callers can render the right surface (e.g. hide the +Abort button when the configured store doesn't support it) without +round-tripping through the reflection API. . AgentMetadata.stateManagement doc diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 429ad95a50..0916ab0225 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -430,74 +430,47 @@ func DefineBidiFlow[In, Out, StreamOut, StreamIn any](g *Genkit, name string, fn return core.DefineBidiFlow(g.reg, name, fn) } -// DefineSessionFlow defines a custom session flow with full control over the -// conversation loop, registers it as a [core.Action] of type Flow, and -// returns an [aix.SessionFlow]. +// DefineAgent defines an agent that wraps a prompt defined inline from the +// given options, registers both under name as actions on the registry, and +// returns an [aix.Agent]. // // Experimental: This API is under active development and may change in any // minor version release. // -// An SessionFlow is a stateful, multi-turn conversational flow. It builds on +// An Agent is a stateful, multi-turn conversational flow. It builds on // bidirectional streaming to enable ongoing conversations where each turn's // input and output are streamed between client and server. The framework // handles session state, conversation history, and optional snapshot // persistence automatically. // -// The provided function fn receives a [aix.Responder] for streaming output -// to the client and an [aix.SessionRunner] for accessing conversation state. -// Call [aix.SessionRunner.Run] to enter the turn loop, which blocks until the -// client sends the next message. +// opts is a mixed list of [ai.PromptOption] values (which configure the +// underlying prompt) and [aix.AgentOption] values (which configure the agent +// itself). The State type parameter must be specified explicitly: use [any] +// when no typed Custom state is needed; use [Foo] when an +// [aix.SessionStore[Foo]] is provided. Mismatches panic at definition time. // -// For prompt-backed agents that follow a standard render-generate-stream loop, -// use [DefineSessionFlowFromPrompt] instead. +// For an agent backed by an existing prompt, use [DefinePromptAgent]. For +// full control over the per-turn loop, use [DefineCustomAgent]. // // # Options // -// - [aix.WithSessionStore]: Enable snapshot persistence with a [aix.SessionStore] +// - any [ai.PromptOption]: e.g., [ai.WithModel], [ai.WithSystem], [ai.WithTools] +// - [aix.WithSessionStore]: Enable snapshot persistence // - [aix.WithSnapshotCallback]: Control when snapshots are created // - [aix.WithSnapshotOn]: Create snapshots only for specific [aix.SnapshotEvent] types // -// Type parameters: -// - Stream: Type for custom status updates sent via [aix.Responder.SendStatus] -// - State: Type for user-defined state persisted in snapshots -// // Example: // -// chatAgent := genkit.DefineSessionFlow(g, "chat", -// func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.SessionFlowResult, error) { -// var lastMessage *ai.Message -// err := sess.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { -// sess.AddMessages(input.Messages...) -// for result, err := range genkit.GenerateStream(ctx, g, -// ai.WithModelName("googleai/gemini-3-flash-preview"), -// ai.WithMessages(sess.Messages()...), -// ) { -// if err != nil { -// return err -// } -// if result.Done { -// lastMessage = result.Response.Message -// sess.AddMessages(lastMessage) -// } else { -// resp.SendModelChunk(result.Chunk) -// } -// } -// return nil -// }) -// if err != nil { -// return nil, err -// } -// return &aix.SessionFlowResult{Message: lastMessage}, nil -// }, +// chatAgent := genkit.DefineAgent[any](g, "chat", +// ai.WithModelName("googleai/gemini-3-flash-preview"), +// ai.WithSystem("You are a helpful assistant."), +// aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), // ) // -// // Start a conversation: // conn, err := chatAgent.StreamBidi(ctx) // if err != nil { // // handle error // } -// -// // Send a message and stream the response: // conn.SendText("Hello!") // for chunk, err := range conn.Receive() { // if chunk.TurnEnd != nil { @@ -506,33 +479,26 @@ func DefineBidiFlow[In, Out, StreamOut, StreamIn any](g *Genkit, name string, fn // fmt.Print(chunk.ModelChunk.Text()) // } // conn.Close() -func DefineSessionFlow[Stream, State any]( +func DefineAgent[State any]( g *Genkit, name string, - fn aix.SessionFlowFunc[Stream, State], - opts ...aix.SessionFlowOption[State], -) *aix.SessionFlow[Stream, State] { - return aix.DefineSessionFlow(g.reg, name, fn, opts...) + opts ...aix.AgentDefineOption[State], +) *aix.Agent[any, State] { + return aix.DefineAgent(g.reg, name, opts...) } -// DefineSessionFlowFromPrompt defines a prompt-backed session flow, registers it as a -// [core.Action] of type Flow, and returns an [aix.SessionFlow]. +// DefinePromptAgent defines an agent backed by a prompt already registered +// with the registry (via [DefinePrompt] or loaded from a .prompt file). The +// agent is registered under the same name as the prompt. // // Experimental: This API is under active development and may change in any // minor version release. // -// This is a higher-level alternative to [DefineSessionFlow] for agents backed -// by a prompt (defined via [DefinePrompt] or loaded from a .prompt file). The -// conversation loop is handled automatically: each turn renders the prompt, -// appends conversation history, calls the model with streaming, and updates -// session state. -// -// The prompt is looked up by promptName from the registry. The defaultInput -// provides template variables for prompt rendering (e.g., personality, tone) -// and can be overridden per invocation via [aix.WithInputVariables]. +// defaultInput is used to render the prompt on every turn. PromptIn is +// captured for compile-time type checking on defaultInput. // -// DefineSessionFlowFromPrompt accepts the same options as [DefineSessionFlow]. See -// [DefineSessionFlow] for available options. +// For an agent that defines its prompt inline, use [DefineAgent]. For full +// control over the per-turn loop, use [DefineCustomAgent]. // // Type parameters: // - State: Type for user-defined state persisted in snapshots @@ -540,47 +506,83 @@ func DefineSessionFlow[Stream, State any]( // // Example: // -// // Given a .prompt file "chat.prompt" loaded via WithPromptDir: -// // --- -// // model: googleai/gemini-3-flash-preview -// // input: -// // schema: -// // personality: string -// // --- -// // {{role "system"}} -// // You are {{personality}}. -// // type ChatInput struct { // Personality string `json:"personality"` // } // -// chatAgent := genkit.DefineSessionFlowFromPrompt(g, "chat", -// ChatInput{Personality: "a helpful assistant"}, +// chatAgent := genkit.DefinePromptAgent[any](g, "chat", +// ChatInput{Personality: "a sarcastic pirate"}, // aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), // ) -// -// // Start a conversation: -// conn, err := chatAgent.StreamBidi(ctx) -// if err != nil { -// // handle error -// } -// -// // Send a message and stream the response: -// conn.SendText("Hello!") -// for chunk, err := range conn.Receive() { -// if chunk.TurnEnd != nil { -// break -// } -// fmt.Print(chunk.ModelChunk.Text()) -// } -// conn.Close() -func DefineSessionFlowFromPrompt[State, PromptIn any]( +func DefinePromptAgent[State, PromptIn any]( g *Genkit, promptName string, defaultInput PromptIn, - opts ...aix.SessionFlowOption[State], -) *aix.SessionFlow[any, State] { - return aix.DefineSessionFlowFromPrompt(g.reg, promptName, defaultInput, opts...) + opts ...aix.AgentOption[State], +) *aix.Agent[any, State] { + return aix.DefinePromptAgent(g.reg, promptName, defaultInput, opts...) +} + +// DefineCustomAgent defines an agent with full control over the conversation +// loop, registers it as a [core.Action] of type Flow, and returns an +// [aix.Agent]. +// +// Experimental: This API is under active development and may change in any +// minor version release. +// +// The provided function fn receives a [aix.Responder] for streaming output +// to the client and an [aix.AgentSession] for accessing conversation state. +// Call [aix.AgentSession.Run] to enter the turn loop, which blocks until the +// client sends the next message. +// +// For agents backed by a prompt, use [DefineAgent] (inline) or +// [DefinePromptAgent] (existing prompt) instead. +// +// # Options +// +// - [aix.WithSessionStore]: Enable snapshot persistence with a [aix.SessionStore] +// - [aix.WithSnapshotCallback]: Control when snapshots are created +// - [aix.WithSnapshotOn]: Create snapshots only for specific [aix.SnapshotEvent] types +// +// Type parameters: +// - Stream: Type for custom status updates sent via [aix.Responder.SendStatus] +// - State: Type for user-defined state persisted in snapshots +// +// Example: +// +// chatAgent := genkit.DefineCustomAgent(g, "chat", +// func(ctx context.Context, resp aix.Responder[any], sess *aix.AgentSession[any]) (*aix.AgentResult, error) { +// var lastMessage *ai.Message +// err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) error { +// for result, err := range genkit.GenerateStream(ctx, g, +// ai.WithModelName("googleai/gemini-3-flash-preview"), +// ai.WithMessages(sess.Messages()...), +// ) { +// if err != nil { +// return err +// } +// if result.Done { +// lastMessage = result.Response.Message +// sess.AddMessages(lastMessage) +// } else { +// resp.SendModelChunk(result.Chunk) +// } +// } +// return nil +// }) +// if err != nil { +// return nil, err +// } +// return &aix.AgentResult{Message: lastMessage}, nil +// }, +// ) +func DefineCustomAgent[Stream, State any]( + g *Genkit, + name string, + fn aix.AgentFunc[Stream, State], + opts ...aix.AgentOption[State], +) *aix.Agent[Stream, State] { + return aix.DefineCustomAgent(g.reg, name, fn, opts...) } // Run executes the given function `fn` within the context of the current flow run, diff --git a/go/samples/custom-agent/main.go b/go/samples/custom-agent/main.go index a05c908557..0e647499d3 100644 --- a/go/samples/custom-agent/main.go +++ b/go/samples/custom-agent/main.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This sample demonstrates the SessionFlow API for multi-turn conversation +// This sample demonstrates the custom Agent API for multi-turn conversation // with token-level streaming. It runs a CLI REPL where conversation history // is managed automatically by the session. package main @@ -35,9 +35,9 @@ func main() { ctx := context.Background() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - chatFlow := genkit.DefineSessionFlow(g, "chat", - func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.SessionFlowResult, error) { - if err := sess.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { + chatAgent := genkit.DefineCustomAgent(g, "chat", + func(ctx context.Context, resp aix.Responder[any], sess *aix.AgentSession[any]) (*aix.AgentResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) error { for chunk, err := range genkit.GenerateStream(ctx, g, ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ @@ -67,10 +67,10 @@ func main() { aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), ) - fmt.Println("Session Flow Chat (type 'quit' to exit)") + fmt.Println("Agent Chat (type 'quit' to exit)") fmt.Println() - conn, err := chatFlow.StreamBidi(ctx) + conn, err := chatAgent.StreamBidi(ctx) if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) diff --git a/go/samples/prompt-agent/main.go b/go/samples/prompt-agent/main.go index b501872520..6d0e2f8697 100644 --- a/go/samples/prompt-agent/main.go +++ b/go/samples/prompt-agent/main.go @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This sample demonstrates DefineSessionFlowFromPrompt, which creates a -// multi-turn conversational session flow backed by a .prompt file. The -// conversation loop (render prompt, call model, stream chunks, update history) -// is handled automatically. Compare with custom-agent which wires -// the same loop manually. +// This sample demonstrates DefinePromptAgent, which creates a multi-turn +// conversational agent backed by a .prompt file. The conversation loop +// (render prompt, call model, stream chunks, update history) is handled +// automatically. Compare with custom-agent which wires the same loop +// manually. package main import ( @@ -39,18 +39,18 @@ func main() { ctx := context.Background() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - chatFlow := genkit.DefineSessionFlowFromPrompt( - g, "chat", ChatPromptInput{Personality: "a sarcastic pirate"}, + chatAgent := genkit.DefinePromptAgent[any](g, "chat", + ChatPromptInput{Personality: "a sarcastic pirate"}, aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[any]) bool { return sc.Event == aix.SnapshotEventInvocationEnd || sc.TurnIndex%5 == 0 }), ) - fmt.Println("Session Flow Chat (type 'quit' to exit)") + fmt.Println("Agent Chat (type 'quit' to exit)") fmt.Println() - conn, err := chatFlow.StreamBidi(ctx) + conn, err := chatAgent.StreamBidi(ctx) if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index c9542c64e0..e234053d42 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -118,85 +118,85 @@ class AbortSnapshotResponse(GenkitModel): status: SnapshotStatus | None = None -class AgentMetadata(GenkitModel): - """Model for agentmetadata data.""" +class AgentInit(GenkitModel): + """Model for agentinit data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - state_management: AgentMetadataStateManagement = Field(...) - abortable: bool = Field(...) + snapshot_id: str | None = None + state: SessionState | None = None -class Artifact(GenkitModel): - """Model for artifact data.""" +class AgentInput(GenkitModel): + """Model for agentinput data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - name: str | None = None - parts: list[Part] = Field(...) - metadata: Metadata | None = None + detach: bool | None = None + messages: list[MessageData] | None = None + tool_restarts: list[Part] | None = None -class GetSnapshotRequest(GenkitModel): - """Model for getsnapshotrequest data.""" +class AgentMetadata(GenkitModel): + """Model for agentmetadata data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - snapshot_id: str = Field(...) + state_management: AgentMetadataStateManagement = Field(...) + abortable: bool = Field(...) -class GetSnapshotResponse(GenkitModel): - """Model for getsnapshotresponse data.""" +class AgentOutput(GenkitModel): + """Model for agentoutput data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - snapshot_id: str = Field(...) - created_at: str | None = None - updated_at: str | None = None - status: SnapshotStatus | None = None - error: str | None = None + snapshot_id: str | None = None state: SessionState | None = None + message: MessageData | None = None + artifacts: list[Artifact] | None = None -class SessionFlowInit(GenkitModel): - """Model for sessionflowinit data.""" +class AgentResult(GenkitModel): + """Model for agentresult data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - snapshot_id: str | None = None - state: SessionState | None = None + message: MessageData | None = None + artifacts: list[Artifact] | None = None -class SessionFlowInput(GenkitModel): - """Model for sessionflowinput data.""" +class AgentStreamChunk(GenkitModel): + """Model for agentstreamchunk data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - detach: bool | None = None - messages: list[MessageData] | None = None - tool_restarts: list[Part] | None = None + model_chunk: ModelResponseChunk | None = None + status: Any | None = Field(default=None) + artifact: Artifact | None = None + turn_end: TurnEnd | None = None -class SessionFlowOutput(GenkitModel): - """Model for sessionflowoutput data.""" +class Artifact(GenkitModel): + """Model for artifact data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - snapshot_id: str | None = None - state: SessionState | None = None - message: MessageData | None = None - artifacts: list[Artifact] | None = None + name: str | None = None + parts: list[Part] = Field(...) + metadata: Metadata | None = None -class SessionFlowResult(GenkitModel): - """Model for sessionflowresult data.""" +class GetSnapshotRequest(GenkitModel): + """Model for getsnapshotrequest data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - message: MessageData | None = None - artifacts: list[Artifact] | None = None + snapshot_id: str = Field(...) -class SessionFlowStreamChunk(GenkitModel): - """Model for sessionflowstreamchunk data.""" +class GetSnapshotResponse(GenkitModel): + """Model for getsnapshotresponse data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - model_chunk: ModelResponseChunk | None = None - status: Any | None = Field(default=None) - artifact: Artifact | None = None - turn_end: TurnEnd | None = None + snapshot_id: str = Field(...) + created_at: str | None = None + updated_at: str | None = None + status: SnapshotStatus | None = None + error: str | None = None + state: SessionState | None = None class SessionSnapshot(GenkitModel): @@ -220,7 +220,6 @@ class SessionState(GenkitModel): messages: list[MessageData] | None = None custom: Any | None = Field(default=None) artifacts: list[Artifact] | None = None - input_variables: Any | None = Field(default=None) class SnapshotMetadata(GenkitModel): From 29e64f4ded69eff15ab4dfa25d81c2b02b2bb350 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 11 May 2026 18:37:32 -0700 Subject: [PATCH 063/141] refactor(go/exp): structured snapshot errors, ToolResume, status rename, generator type-params * `core.GenkitError` gains `MarshalJSON`/`UnmarshalJSON` so it serializes as the canonical `{status, message, details}` wire shape, plus an `AsGenkitError` helper that wraps non-GenkitError values with status `INTERNAL`. Snapshot-related types now hold `*core.GenkitError` directly instead of a string. * `SnapshotStatus` renamed to `pending`, `succeeded`, `aborted`, `failed` (was `pending`, `complete`, `canceled`, `error`). All Go consts, doc, and test strings updated. * `INTERNAL` `StatusName` value changed from `"INTERNAL_SERVER_ERROR"` to the canonical `"INTERNAL"` so Go matches the JSON schema and JS runtime. The googlegenai workaround is dropped. * `AgentInput.toolRestarts` replaced with a structured `Resume` field. In Go this is `*exp.ToolResume` (`{Respond, Restart}`); the resume schema is inlined in the JS schema and renamed via schemas.config. * `SessionSnapshot` is now hand-written in `go/ai/exp/session.go` and removed from the JS schema (it's runtime-internal). State on snapshots becomes a pointer so pending snapshots can hold nil. * `AgentMetadataStateManagement` renamed to `AgentStateManagement`, `AgentSession` renamed to `SessionRunner`. * `jsonschemagen` learns to auto-forward type parameters on field references: a ref to a type with `typeparams [State any]` now resolves to `*Foo[State]` automatically. The previously-omitted generic snapshot types are now generated through schemas.config. Note: API surface changes here are intentional; the `exp` package is explicitly experimental. --- genkit-tools/common/src/types/agent.ts | 57 ++--- genkit-tools/genkit-schema.json | 98 +++---- go/ai/exp/agent.go | 135 +++++----- go/ai/exp/agent_test.go | 211 +++++++-------- go/ai/exp/gen.go | 199 ++++++++++++--- go/ai/exp/session.go | 144 ++--------- go/core/error.go | 69 ++++- go/core/error_test.go | 90 +++++++ go/core/schemas.config | 240 ++++++++++++++++-- go/core/status_types.go | 2 +- go/genkit/genkit.go | 6 +- go/genkit/servers_test.go | 2 +- .../cmd/jsonschemagen/jsonschemagen.go | 46 +++- .../cmd/jsonschemagen/jsonschemagen_test.go | 21 ++ go/plugins/googlegenai/errors.go | 8 +- go/plugins/middleware/retry.go | 2 +- go/samples/custom-agent/main.go | 2 +- 17 files changed, 860 insertions(+), 472 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 5db3c7b774..63b649b8c7 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -16,7 +16,11 @@ import { z } from 'zod'; import { MessageSchema, ModelResponseChunkSchema } from './model'; -import { PartSchema } from './parts'; +import { + PartSchema, + ToolRequestPartSchema, + ToolResponsePartSchema, +} from './parts'; /** * Zod schema for an artifact produced during a session. @@ -55,17 +59,17 @@ export type SnapshotEvent = z.infer; * - `pending`: a detached invocation is still processing the queued inputs. * The snapshot's state is empty until the flow exits, at which point it * is rewritten with the cumulative final state and a terminal status. - * - `complete`: the snapshot captures a settled state. - * - `canceled`: the snapshot's invocation was aborted via the + * - `succeeded`: the snapshot captures a settled state. + * - `aborted`: the snapshot's invocation was aborted via the * `abortSnapshot` companion action while detached. - * - `error`: the invocation terminated with an error. The snapshot's `error` + * - `failed`: the invocation terminated with an error. The snapshot's `error` * field describes the failure and resume is rejected with that same error. */ export const SnapshotStatusSchema = z.enum([ 'pending', - 'complete', - 'canceled', - 'error', + 'succeeded', + 'aborted', + 'failed', ]); export type SnapshotStatus = z.infer; @@ -98,8 +102,13 @@ export const AgentInputSchema = z.object({ detach: z.boolean().optional(), /** User's input messages for this turn. */ messages: z.array(MessageSchema).optional(), - /** Tool request parts to re-execute interrupted tools. */ - toolRestarts: z.array(PartSchema).optional(), + /** Options for resuming an interrupted generation. */ + resume: z + .object({ + respond: z.array(ToolResponsePartSchema).optional(), + restart: z.array(ToolRequestPartSchema).optional(), + }) + .optional(), }); export type AgentInput = z.infer; @@ -195,25 +204,11 @@ export const SnapshotMetadataSchema = z.object({ event: SnapshotEventSchema, /** Lifecycle state of this snapshot. Empty is treated as `complete`. */ status: SnapshotStatusSchema.optional(), - /** Failure message for a snapshot in `error` status. */ - error: z.string().optional(), + /** Structured failure information for a snapshot in `error` status. */ + error: z.any().optional(), }); export type SnapshotMetadata = z.infer; -/** - * Zod schema for a persisted point-in-time capture of session state. - */ -export const SessionSnapshotSchema = SnapshotMetadataSchema.extend({ - /** - * Conversation state. Empty on a pending snapshot (the live state is - * not yet committed; the background invocation is still processing - * queued inputs); populated on terminal snapshots with the cumulative - * final state. - */ - state: SessionStateSchema, -}); -export type SessionSnapshot = z.infer; - /** * Zod schema for the input of an agent's `getSnapshot` companion action. * The action is registered at `{agentName}/getSnapshot` when the agent @@ -239,8 +234,8 @@ export const GetSnapshotResponseSchema = z.object({ updatedAt: z.string().optional(), /** Lifecycle state of the snapshot. */ status: SnapshotStatusSchema.optional(), - /** Populated when status is `error`. */ - error: z.string().optional(), + /** Structured failure information; populated when status is `error`. */ + error: z.any().optional(), /** * Session state captured by the snapshot, after any configured transform. * Empty when status is `pending` or `error`. @@ -281,9 +276,9 @@ export type AbortSnapshotResponse = z.infer; * - `client`: no store; state flows through the agent's invocation init * and output payloads. */ -export const AgentMetadataStateManagementSchema = z.enum(['server', 'client']); -export type AgentMetadataStateManagement = z.infer< - typeof AgentMetadataStateManagementSchema +export const AgentStateManagementSchema = z.enum(['server', 'client']); +export type AgentStateManagement = z.infer< + typeof AgentStateManagementSchema >; /** @@ -295,7 +290,7 @@ export type AgentMetadataStateManagement = z.infer< */ export const AgentMetadataSchema = z.object({ /** Who owns session state for this agent. */ - stateManagement: AgentMetadataStateManagementSchema, + stateManagement: AgentStateManagementSchema, /** * Whether the agent's invocations can be aborted. True only when the * configured store implements the abort lifecycle. diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 69da18df17..a2d4fe9c33 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -52,11 +52,23 @@ "$ref": "#/$defs/Message" } }, - "toolRestarts": { - "type": "array", - "items": { - "$ref": "#/$defs/Part" - } + "resume": { + "type": "object", + "properties": { + "respond": { + "type": "array", + "items": { + "$ref": "#/$defs/ToolResponsePart" + } + }, + "restart": { + "type": "array", + "items": { + "$ref": "#/$defs/ToolRequestPart" + } + } + }, + "additionalProperties": false } }, "additionalProperties": false @@ -65,7 +77,7 @@ "type": "object", "properties": { "stateManagement": { - "$ref": "#/$defs/AgentMetadataStateManagement" + "$ref": "#/$defs/AgentStateManagement" }, "abortable": { "type": "boolean" @@ -77,13 +89,6 @@ ], "additionalProperties": false }, - "AgentMetadataStateManagement": { - "type": "string", - "enum": [ - "server", - "client" - ] - }, "AgentOutput": { "type": "object", "properties": { @@ -120,6 +125,13 @@ }, "additionalProperties": false }, + "AgentStateManagement": { + "type": "string", + "enum": [ + "server", + "client" + ] + }, "AgentStreamChunk": { "type": "object", "properties": { @@ -185,9 +197,7 @@ "status": { "$ref": "#/$defs/SnapshotStatus" }, - "error": { - "type": "string" - }, + "error": {}, "state": { "$ref": "#/$defs/SessionState" } @@ -197,42 +207,6 @@ ], "additionalProperties": false }, - "SessionSnapshot": { - "type": "object", - "properties": { - "snapshotId": { - "type": "string" - }, - "parentId": { - "type": "string" - }, - "createdAt": { - "type": "string" - }, - "updatedAt": { - "type": "string" - }, - "event": { - "$ref": "#/$defs/SnapshotEvent" - }, - "status": { - "$ref": "#/$defs/SnapshotStatus" - }, - "error": { - "type": "string" - }, - "state": { - "$ref": "#/$defs/SessionState" - } - }, - "required": [ - "snapshotId", - "createdAt", - "event", - "state" - ], - "additionalProperties": false - }, "SessionState": { "type": "object", "properties": { @@ -264,26 +238,24 @@ "type": "object", "properties": { "snapshotId": { - "$ref": "#/$defs/SessionSnapshot/properties/snapshotId" + "type": "string" }, "parentId": { - "$ref": "#/$defs/SessionSnapshot/properties/parentId" + "type": "string" }, "createdAt": { - "$ref": "#/$defs/SessionSnapshot/properties/createdAt" + "type": "string" }, "updatedAt": { - "$ref": "#/$defs/SessionSnapshot/properties/updatedAt" + "type": "string" }, "event": { "$ref": "#/$defs/SnapshotEvent" }, "status": { - "$ref": "#/$defs/SessionSnapshot/properties/status" + "$ref": "#/$defs/SnapshotStatus" }, - "error": { - "$ref": "#/$defs/SessionSnapshot/properties/error" - } + "error": {} }, "required": [ "snapshotId", @@ -296,9 +268,9 @@ "type": "string", "enum": [ "pending", - "complete", - "canceled", - "error" + "succeeded", + "aborted", + "failed" ] }, "TurnEnd": { diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 2616ac258b..902d2a1880 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -35,20 +35,20 @@ import ( "github.com/firebase/genkit/go/core/tracing" ) -// --- AgentSession --- +// --- SessionRunner --- -// AgentSession extends Session with agent-runtime functionality: +// SessionRunner extends Session with agent-runtime functionality: // turn management, snapshot persistence, and input channel handling. -type AgentSession[State any] struct { +type SessionRunner[State any] struct { *Session[State] // InputCh is the channel that delivers per-turn inputs from the client. - // It is consumed automatically by [AgentSession.Run], but is exposed + // It is consumed automatically by [SessionRunner.Run], but is exposed // for advanced use cases that need direct access to the input stream // (e.g., custom turn loops or fan-out patterns). InputCh <-chan *AgentInput // TurnIndex is the zero-based index of the current conversation turn. - // It is incremented automatically by [AgentSession.Run], but is exposed + // It is incremented automatically by [SessionRunner.Run], but is exposed // for advanced use cases that need to track or manipulate turn ordering // directly. TurnIndex int @@ -69,7 +69,7 @@ type AgentSession[State any] struct { // parentSnapshotID returns the ID of the most recent snapshot in this // invocation (used to chain new snapshots via ParentID), or "" if no // snapshot has been written yet. -func (s *AgentSession[State]) parentSnapshotID() string { +func (s *SessionRunner[State]) parentSnapshotID() string { if s.lastSnapshot == nil { return "" } @@ -80,7 +80,7 @@ func (s *AgentSession[State]) parentSnapshotID() string { // wrapped in a trace span for observability. Input messages are automatically // added to the session before fn is called. After fn returns successfully, a // TurnEnd chunk is sent and a snapshot check is triggered. -func (s *AgentSession[State]) Run(ctx context.Context, fn func(ctx context.Context, input *AgentInput) error) error { +func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Context, input *AgentInput) error) error { for input := range s.InputCh { spanMeta := &tracing.SpanMetadata{ Name: fmt.Sprintf("agent/turn/%d", s.TurnIndex), @@ -112,7 +112,7 @@ func (s *AgentSession[State]) Run(ctx context.Context, fn func(ctx context.Conte // the last message in the conversation history and all artifacts. // It is a convenience for custom agents that don't need to construct the // result manually. -func (s *AgentSession[State]) Result() *AgentResult { +func (s *SessionRunner[State]) Result() *AgentResult { s.mu.RLock() defer s.mu.RUnlock() @@ -137,7 +137,7 @@ func (s *AgentSession[State]) Result() *AgentResult { // the turn-end snapshot — the pending row already captures the // invocation and a single finalize rewrite will record the cumulative // state once the queued inputs drain. -func (s *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { +func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { if event == SnapshotEventTurnEnd && s.intake != nil { if suspended := s.intake.beginTurnEnd(); suspended { return "" @@ -164,7 +164,7 @@ func (s *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotE if s.snapshotCallback != nil { var prevState *SessionState[State] if s.lastSnapshot != nil { - prevState = &s.lastSnapshot.State + prevState = s.lastSnapshot.State } if !s.snapshotCallback(ctx, &SnapshotContext[State]{ State: ¤tState, @@ -183,8 +183,8 @@ func (s *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotE return &SessionSnapshot[State]{ ParentID: parentID, Event: event, - Status: SnapshotStatusComplete, - State: currentState, + Status: SnapshotStatusSucceeded, + State: ¤tState, }, nil }) if err != nil { @@ -236,7 +236,7 @@ func (r Responder[Stream]) SendArtifact(artifact *Artifact) { // Type parameters: // - Stream: Type for status updates sent via the responder // - State: Type for user-defined state in snapshots -type AgentFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *AgentSession[State]) (*AgentResult, error) +type AgentFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *SessionRunner[State]) (*AgentResult, error) // Agent is a bidirectional streaming agent with automatic snapshot management. type Agent[Stream, State any] struct { @@ -353,10 +353,10 @@ func DefineCustomAgent[Stream, State any]( // itself is generated from agent.ts; this constructor is hand-written // because it inspects the configured store's optional capabilities. func agentMetadataFor[State any](store SessionStore[State]) AgentMetadata { - mgmt := AgentMetadataStateManagementClient + mgmt := AgentStateManagementClient abortable := false if store != nil { - mgmt = AgentMetadataStateManagementServer + mgmt = AgentStateManagementServer _, abortable = store.(SnapshotAborter) } return AgentMetadata{ @@ -376,7 +376,7 @@ type agentRuntime[Stream, State any] struct { cfg *agentOptions[State] session *Session[State] - sess *AgentSession[State] + sess *SessionRunner[State] router *chunkRouter[Stream, State] intake *detachIntake @@ -412,7 +412,7 @@ func newAgentRuntime[Stream, State any]( fnDone: make(chan fnDoneResult[State], 1), } - rt.sess = &AgentSession[State]{ + rt.sess = &SessionRunner[State]{ Session: session, InputCh: rt.intake.out(), snapshotCallback: cfg.callback, @@ -611,14 +611,14 @@ func (rt *agentRuntime[Stream, State]) handleDetach( // trashes any further chunks. rt.router.stopAndWait() - canceledByUser := &atomic.Bool{} + abortedByUser := &atomic.Bool{} subCtx, stopSub := context.WithCancel(workCtx) aborter := rt.cfg.store.(SnapshotAborter) // safe: checkDetachCapabilities ran already statusCh := aborter.OnSnapshotStatusChange(subCtx, pending.SnapshotID) go func() { for status := range statusCh { - if status == SnapshotStatusCanceled { - canceledByUser.Store(true) + if status == SnapshotStatusAborted { + abortedByUser.Store(true) cancelWork() return } @@ -631,7 +631,7 @@ func (rt *agentRuntime[Stream, State]) handleDetach( stopSub() rt.intake.stopAndWait() rt.router.close() - rt.finalizePendingSnapshot(finalizeCtx, pending, res.err, canceledByUser.Load()) + rt.finalizePendingSnapshot(finalizeCtx, pending, res.err, abortedByUser.Load()) cancelWork() }() @@ -639,46 +639,46 @@ func (rt *agentRuntime[Stream, State]) handleDetach( } // finalizePendingSnapshot rewrites the pending snapshot row with the -// terminal state and status. canceledByUser distinguishes a context -// cancellation from abortSnapshot (status=canceled) from an internal -// failure (status=error). The write is funneled through SaveSnapshot +// terminal state and status. abortedByUser distinguishes a context +// cancellation from abortSnapshot (status=aborted) from an internal +// failure (status=failed). The write is funneled through SaveSnapshot // so the read-and-rewrite is one atomic step: if the row has already -// transitioned to canceled (a late abort racing this finalize), +// transitioned to aborted (a late abort racing this finalize), // SaveSnapshot sees it inside fn and we leave the row untouched. func (rt *agentRuntime[Stream, State]) finalizePendingSnapshot( ctx context.Context, pending *SessionSnapshot[State], fnErr error, - canceledByUser bool, + abortedByUser bool, ) { finalState := *rt.session.State() _, err := rt.cfg.store.SaveSnapshot(ctx, pending.SnapshotID, func(existing *SessionSnapshot[State]) (*SessionSnapshot[State], error) { // Late abort wins over the terminal we were about to land. - if existing != nil && existing.Status == SnapshotStatusCanceled { + if existing != nil && existing.Status == SnapshotStatusAborted { return nil, nil } - status := SnapshotStatusComplete - errMsg := "" + status := SnapshotStatusSucceeded + var snapErr *core.GenkitError switch { - case canceledByUser: - status = SnapshotStatusCanceled + case abortedByUser: + status = SnapshotStatusAborted if fnErr != nil { - errMsg = fnErr.Error() // canceled wins, preserve text + snapErr = core.AsGenkitError(fnErr) // aborted wins, preserve text } case fnErr != nil: - status = SnapshotStatusError - errMsg = fnErr.Error() + status = SnapshotStatusFailed + snapErr = core.AsGenkitError(fnErr) } return &SessionSnapshot[State]{ ParentID: pending.ParentID, Event: SnapshotEventDetach, Status: status, - Error: errMsg, - State: finalState, + Error: snapErr, + State: &finalState, }, nil }) if err != nil { @@ -723,21 +723,23 @@ func loadSession[State any]( return nil, nil, core.NewError(core.NOT_FOUND, "snapshot %q not found", init.SnapshotID) } switch snap.Status { - case SnapshotStatusError: - msg := snap.Error - if msg == "" { - msg = "snapshot recorded an error" + case SnapshotStatusFailed: + msg := "snapshot recorded an error" + if snap.Error != nil && snap.Error.Message != "" { + msg = snap.Error.Message } return nil, nil, core.NewError(core.FAILED_PRECONDITION, "snapshot %q terminated with error: %s", init.SnapshotID, msg) case SnapshotStatusPending: return nil, nil, core.NewError(core.FAILED_PRECONDITION, "snapshot %q is still pending; wait for it to finalize before resuming", init.SnapshotID) - case SnapshotStatusCanceled: + case SnapshotStatusAborted: return nil, nil, core.NewError(core.FAILED_PRECONDITION, - "snapshot %q was canceled", init.SnapshotID) + "snapshot %q was aborted", init.SnapshotID) + } + if snap.State != nil { + s.state = *snap.State } - s.state = snap.State return s, snap, nil } @@ -988,10 +990,10 @@ func (i *detachIntake) enqueue(input *AgentInput) { // the detach handler. The detach handler then calls suspend to halt // turn-end snapshots while the queued inputs finish processing. // -// A pure detach signal (no Messages, no ToolRestarts) is dropped rather -// than enqueued: it carries no payload to process, so it would just -// trigger a no-op turn. Callers that want to ride a final input on the -// detach signal can do so by calling +// A pure detach signal (no Messages, no Resume payload) is dropped +// rather than enqueued: it carries no payload to process, so it would +// just trigger a no-op turn. Callers that want to ride a final input +// on the detach signal can do so by calling // Send(&AgentInput{Detach: true, Messages: ...}) explicitly. func (i *detachIntake) handleDetach(first *AgentInput) { var drained []*AgentInput @@ -1028,7 +1030,16 @@ drainLoop: // otherwise process. Used to filter pure detach signals out of the // queue so they don't trigger no-op turns. func hasInputPayload(in *AgentInput) bool { - return in != nil && (len(in.Messages) > 0 || len(in.ToolRestarts) > 0) + if in == nil { + return false + } + if len(in.Messages) > 0 { + return true + } + if in.Resume != nil && (len(in.Resume.Respond) > 0 || len(in.Resume.Restart) > 0) { + return true + } + return false } // forward pops the queue and writes to dst at the runner's pace. The @@ -1101,7 +1112,7 @@ func (i *detachIntake) detachSignal() <-chan struct{} { return i.detachCh } -// beginTurnEnd is called by [AgentSession.maybeSnapshot] before writing +// beginTurnEnd is called by [SessionRunner.maybeSnapshot] before writing // a turn-end snapshot. If the intake has been suspended (detach landed), // it returns suspended=true and the runner skips the snapshot. // @@ -1146,7 +1157,7 @@ const promptMessageKey = "_genkit_prompt" // nil for [DefineAgent], where the inline-defined prompt has no per-turn // input. func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) AgentFunc[any, State] { - return func(ctx context.Context, resp Responder[any], sess *AgentSession[State]) (*AgentResult, error) { + return func(ctx context.Context, resp Responder[any], sess *SessionRunner[State]) (*AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { actionOpts, err := prompt.Render(ctx, defaultInput) if err != nil { @@ -1165,15 +1176,14 @@ func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) Ag // Append conversation history after the base messages. actionOpts.Messages = append(actionOpts.Messages, sess.Messages()...) - // If tool restarts were provided, set the resume option so - // handleResumeOption re-executes the interrupted tools. - if len(input.ToolRestarts) > 0 { - for _, p := range input.ToolRestarts { - if !p.IsToolRequest() { - return core.NewError(core.INVALID_ARGUMENT, "ToolRestarts: part is not a tool request") - } + // If a resume payload was provided, forward it to the + // generate call so handleResumeOption re-executes the + // interrupted tools and / or applies the responses. + if input.Resume != nil { + actionOpts.Resume = &ai.GenerateActionResume{ + Respond: input.Resume.Respond, + Restart: input.Resume.Restart, } - actionOpts.Resume = ai.NewResume(input.ToolRestarts, nil) } modelResp, err := ai.GenerateWithRequest(ctx, r, actionOpts, nil, @@ -1361,10 +1371,11 @@ func (c *AgentConnection[Stream, State]) SendText(text string) error { }) } -// SendToolRestarts sends tool restart parts to resume interrupted tool calls. -// Parts should be created via [ai.ToolDef.RestartWith]. -func (c *AgentConnection[Stream, State]) SendToolRestarts(parts ...*ai.Part) error { - return c.conn.Send(&AgentInput{ToolRestarts: parts}) +// SendResume sends a resume payload to continue an interrupted generation. +// Construct the payload with [ai.ToolDef.RestartWith] or +// [ai.ToolDef.RespondWith] parts. +func (c *AgentConnection[Stream, State]) SendResume(resume *ToolResume) error { + return c.conn.Send(&AgentInput{Resume: resume}) } // Detach asks the server to write a pending snapshot, close the diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 20c07f662d..32bfb39e48 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -49,7 +49,7 @@ func TestAgent_BasicMultiTurn(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "basicFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { resp.SendStatus(testStatus{Phase: "generating"}) // Echo back the user's message. @@ -125,7 +125,7 @@ func TestAgent_WithSessionStore(t *testing.T) { store := NewInMemorySessionStore[testState]() af := DefineCustomAgent(reg, "snapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) @@ -195,7 +195,7 @@ func TestAgent_ResumeFromSnapshot(t *testing.T) { store := NewInMemorySessionStore[testState]() af := DefineCustomAgent(reg, "resumeFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) @@ -284,7 +284,7 @@ func TestAgent_ClientManagedState(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "clientStateFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) @@ -347,7 +347,7 @@ func TestAgent_Artifacts(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "artifactFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { resp.SendArtifact(&Artifact{ @@ -420,7 +420,7 @@ func TestAgent_SnapshotCallback(t *testing.T) { // Only snapshot on even turns. callbackCalls := 0 af := DefineCustomAgent(reg, "callbackFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { @@ -476,7 +476,7 @@ func TestAgent_SendMessages(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "sendMsgsFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { return nil }) @@ -523,7 +523,7 @@ func TestAgent_SessionContext(t *testing.T) { var retrievedCounter int af := DefineCustomAgent(reg, "contextFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { // Session should be retrievable from context. ctxSess := SessionFromContext[testState](ctx) @@ -568,7 +568,7 @@ func TestAgent_ErrorInTurn(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "errorFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { return fmt.Errorf("turn failed") }) @@ -594,7 +594,7 @@ func TestAgent_SetMessages(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "setMsgsFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { // Replace all messages with just one. sess.SetMessages([]*ai.Message{ai.NewModelTextMessage("replaced")}) @@ -650,8 +650,8 @@ func TestInMemorySessionStore(t *testing.T) { t.Errorf("expected nil existing on first save, got %+v", existing) } return &SessionSnapshot[testState]{ - Status: SnapshotStatusComplete, - State: SessionState[testState]{Custom: testState{Counter: 1}}, + Status: SnapshotStatusSucceeded, + State: &SessionState[testState]{Custom: testState{Counter: 1}}, }, nil }) if err != nil { @@ -671,8 +671,8 @@ func TestInMemorySessionStore(t *testing.T) { if _, err := store.SaveSnapshot(context.Background(), "snap-1", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ - Status: SnapshotStatusComplete, - State: SessionState[testState]{Custom: testState{Counter: 1}}, + Status: SnapshotStatusSucceeded, + State: &SessionState[testState]{Custom: testState{Counter: 1}}, }, nil }); err != nil { t.Fatalf("SaveSnapshot: %v", err) @@ -697,7 +697,7 @@ func TestInMemorySessionStore(t *testing.T) { if saved.SnapshotID == "" { t.Error("expected store to generate SnapshotID") } - if saved.Status != SnapshotStatusComplete { + if saved.Status != SnapshotStatusSucceeded { t.Errorf("expected Status=complete by default, got %q", saved.Status) } }) @@ -706,7 +706,7 @@ func TestInMemorySessionStore(t *testing.T) { store := NewInMemorySessionStore[testState]() if _, err := store.SaveSnapshot(context.Background(), "snap-1", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { - return &SessionSnapshot[testState]{Status: SnapshotStatusComplete}, nil + return &SessionSnapshot[testState]{Status: SnapshotStatusSucceeded}, nil }); err != nil { t.Fatalf("seed: %v", err) } @@ -731,7 +731,7 @@ func TestInMemorySessionStore(t *testing.T) { store := NewInMemorySessionStore[testState]() saved, err := store.SaveSnapshot(context.Background(), "snap-1", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { - return &SessionSnapshot[testState]{Status: SnapshotStatusComplete}, nil + return &SessionSnapshot[testState]{Status: SnapshotStatusSucceeded}, nil }) if err != nil { t.Fatalf("seed: %v", err) @@ -743,8 +743,8 @@ func TestInMemorySessionStore(t *testing.T) { t.Fatal("expected non-nil existing on update") } return &SessionSnapshot[testState]{ - Status: SnapshotStatusComplete, - State: SessionState[testState]{Custom: testState{Counter: 2}}, + Status: SnapshotStatusSucceeded, + State: &SessionState[testState]{Custom: testState{Counter: 2}}, }, nil }) if err != nil { @@ -766,7 +766,7 @@ func TestAgent_TurnSpanOutput(t *testing.T) { var capturedOutputs []any af := DefineCustomAgent(reg, "turnOutputFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { // Wrap collectTurnOutput to capture what each turn produces. originalCollect := sess.collectTurnOutput sess.collectTurnOutput = func() any { @@ -845,7 +845,7 @@ func TestAgent_TurnSpanOutput_WithSnapshots(t *testing.T) { var capturedOutputs []any af := DefineCustomAgent(reg, "turnOutputSnapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { originalCollect := sess.collectTurnOutput sess.collectTurnOutput = func() any { output := originalCollect() @@ -1334,7 +1334,7 @@ func TestAgent_RunText(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "runTextFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("echo: " + input.Messages[0].Content[0].Text)) @@ -1367,7 +1367,7 @@ func TestAgent_Run(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "runFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { if len(input.Messages) > 0 { sess.AddMessages(ai.NewModelTextMessage("reply")) @@ -1400,7 +1400,7 @@ func TestAgent_RunText_WithState(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "runStateFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { @@ -1441,7 +1441,7 @@ func TestAgent_RunText_WithSnapshot(t *testing.T) { store := NewInMemorySessionStore[testState]() af := DefineCustomAgent(reg, "runSnapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { @@ -1513,7 +1513,7 @@ func TestAgent_SingleTurnSnapshotDedup(t *testing.T) { store := NewInMemorySessionStore[testState]() af := DefineCustomAgent(reg, "dedupFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { @@ -1557,7 +1557,7 @@ func TestAgent_MultiTurnSnapshotDedup(t *testing.T) { store := NewInMemorySessionStore[testState]() af := DefineCustomAgent(reg, "multiDedupFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { @@ -1619,7 +1619,7 @@ func TestAgent_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { store := NewInMemorySessionStore[testState]() af := DefineCustomAgent(reg, "postRunMutateFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) return nil @@ -1672,7 +1672,7 @@ func TestAgent_FnPanicReturnsError(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "panicFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { resp.SendStatus(testStatus{Phase: "before-panic"}) panic("boom") @@ -1726,7 +1726,7 @@ func TestAgent_CancelDuringStreamReleasesGoroutine(t *testing.T) { emitting := make(chan struct{}) fnDone := make(chan struct{}) af := DefineCustomAgent(reg, "cancelFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { defer close(fnDone) return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { close(emitting) @@ -1820,7 +1820,7 @@ func TestAgent_TurnEnd_CarriesSnapshotID(t *testing.T) { store := NewInMemorySessionStore[testState]() af := DefineCustomAgent(reg, "turnEndSnapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("ok")) return nil @@ -1886,7 +1886,7 @@ func TestAgent_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { release := make(chan struct{}) af := DefineCustomAgent(reg, "detachInFlight", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { entered <- struct{}{} <-release @@ -1950,8 +1950,8 @@ func TestAgent_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { if pending.Status != SnapshotStatusPending { t.Errorf("pending snapshot status = %q, want pending", pending.Status) } - if got := len(pending.State.Messages); got != 0 { - t.Errorf("pending state messages = %d, want 0 (live state not yet committed)", got) + if pending.State != nil { + t.Errorf("pending state = %+v, want nil (live state not yet committed)", pending.State) } // No separate turn-end snapshot for A should have been written. @@ -1964,7 +1964,7 @@ func TestAgent_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { close(release) final := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusComplete + return s.Status == SnapshotStatusSucceeded }) if final.State.Custom.Counter != 2 { t.Errorf("final counter = %d, want 2 (A + D both processed)", final.State.Custom.Counter) @@ -1985,7 +1985,7 @@ func TestAgent_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { release := make(chan struct{}, 4) af := DefineCustomAgent(reg, "detachChainParent", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { enter <- struct{}{} <-release @@ -2055,7 +2055,7 @@ func TestAgent_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { // Release remaining turns and let finalize run. close(release) waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusComplete + return s.Status == SnapshotStatusSucceeded }) } @@ -2063,7 +2063,7 @@ func TestAgent_Detach_RequiresStore(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "detachNoStore", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { return nil }) @@ -2090,7 +2090,7 @@ func TestAgent_Detach_RequiresStore(t *testing.T) { func TestAgent_Detach_PendingThenComplete(t *testing.T) { // Client detaches mid-flow; flow finishes naturally; pending snapshot - // flips to status=complete with the full session state. + // flips to status=succeeded with the full session state. reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() @@ -2098,7 +2098,7 @@ func TestAgent_Detach_PendingThenComplete(t *testing.T) { entered := make(chan struct{}) af := DefineCustomAgent(reg, "detachComplete", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { select { case entered <- struct{}{}: @@ -2163,15 +2163,15 @@ func TestAgent_Detach_PendingThenComplete(t *testing.T) { if pending.Status != SnapshotStatusPending { t.Errorf("expected status=%q, got %q", SnapshotStatusPending, pending.Status) } - if got := len(pending.State.Messages); got != 0 { - t.Errorf("pending snapshot should not carry message history, got %d messages", got) + if pending.State != nil { + t.Errorf("pending snapshot should not carry session state, got %+v", pending.State) } // Release; finalizer rewrites the snapshot with the terminal state. close(release) finalSnap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusComplete + return s.Status == SnapshotStatusSucceeded }) if finalSnap.State.Custom.Counter != 42 { t.Errorf("expected counter=42 in final snapshot, got %d", finalSnap.State.Custom.Counter) @@ -2193,7 +2193,7 @@ func TestAgent_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { release := make(chan struct{}) af := DefineCustomAgent(reg, "detachArtifact", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { resp.SendArtifact(&Artifact{ Name: "before.txt", @@ -2246,7 +2246,7 @@ func TestAgent_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { close(release) final := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusComplete + return s.Status == SnapshotStatusSucceeded }) names := make(map[string]bool, len(final.State.Artifacts)) @@ -2270,7 +2270,7 @@ func TestAgent_Detach_FlowErrorsBecomesError(t *testing.T) { boom := errors.New("kaboom") af := DefineCustomAgent(reg, "detachErr", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { select { case entered <- struct{}{}: @@ -2314,10 +2314,10 @@ func TestAgent_Detach_FlowErrorsBecomesError(t *testing.T) { close(release) snap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusError + return s.Status == SnapshotStatusFailed }) - if !strings.Contains(snap.Error, "kaboom") { - t.Errorf("expected snapshot.Error to contain %q, got %q", "kaboom", snap.Error) + if snap.Error == nil || !strings.Contains(snap.Error.Message, "kaboom") { + t.Errorf("expected snapshot.Error.Message to contain %q, got %+v", "kaboom", snap.Error) } // Resuming from an errored detached snapshot is rejected. @@ -2333,14 +2333,14 @@ func TestAgent_Detach_FlowErrorsBecomesError(t *testing.T) { func TestAgent_Detach_AbortSnapshotStopsFlow(t *testing.T) { // Client detaches, then calls AbortSnapshot. The store's status // subscriber notifies the runtime, which cancels the work context, and - // the finalizer rewrites the snapshot with status=canceled. + // the finalizer rewrites the snapshot with status=aborted. reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() entered := make(chan struct{}) af := DefineCustomAgent(reg, "detachAbort", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { select { case entered <- struct{}{}: @@ -2387,18 +2387,20 @@ func TestAgent_Detach_AbortSnapshotStopsFlow(t *testing.T) { if err != nil { t.Fatalf("AbortSnapshot: %v", err) } - if meta.Status != SnapshotStatusCanceled { - t.Errorf("AbortSnapshot status = %q, want canceled", meta.Status) + if meta.Status != SnapshotStatusAborted { + t.Errorf("AbortSnapshot status = %q, want aborted", meta.Status) } // The subscriber wakes the runtime, cancels work, and the finalizer - // rewrites the snapshot with the canceled status. + // rewrites the snapshot with the aborted status. finalSnap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusCanceled && s.UpdatedAt.After(s.CreatedAt) + return s.Status == SnapshotStatusAborted && s.UpdatedAt.After(s.CreatedAt) }) - if finalSnap.State.Custom.Counter != 0 { - // The flow only blocked on ctx — no state mutation expected. - t.Errorf("unexpected counter value in canceled snapshot: %d", finalSnap.State.Custom.Counter) + // The flow only blocked on ctx — no state mutation expected. State + // may be nil (when AbortSnapshot landed before the finalizer's write + // could populate it) or a populated zero-value struct. + if finalSnap.State != nil && finalSnap.State.Custom.Counter != 0 { + t.Errorf("unexpected counter value in aborted snapshot: %d", finalSnap.State.Custom.Counter) } } @@ -2409,7 +2411,7 @@ func TestAgent_Detach_NormalCompletionStillEmitsTurnEnd(t *testing.T) { store := NewInMemorySessionStore[testState]() af := DefineCustomAgent(reg, "syncStillWorks", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("ok")) return nil @@ -2448,8 +2450,8 @@ func TestAgent_Detach_NormalCompletionStillEmitsTurnEnd(t *testing.T) { if err != nil { t.Fatalf("GetSnapshot: %v", err) } - if snap.Status != SnapshotStatusComplete { - t.Errorf("turn-end snapshot status = %q, want complete", snap.Status) + if snap.Status != SnapshotStatusSucceeded { + t.Errorf("turn-end snapshot status = %q, want succeeded", snap.Status) } if snap.Event != SnapshotEventTurnEnd { t.Errorf("turn-end snapshot event = %q, want %q", snap.Event, SnapshotEventTurnEnd) @@ -2467,7 +2469,7 @@ func TestAgent_Detach_ClientDisconnectBeforeDetachCancels(t *testing.T) { exited := make(chan error, 1) af := DefineCustomAgent(reg, "syncCancel", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { select { case entered <- struct{}{}: @@ -2520,16 +2522,19 @@ func TestAgent_ResumeFromErrorSnapshot_Rejected(t *testing.T) { func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ Event: SnapshotEventInvocationEnd, - Status: SnapshotStatusError, - Error: "underlying failure", - State: SessionState[testState]{}, + Status: SnapshotStatusFailed, + Error: &core.GenkitError{ + Status: core.INTERNAL, + Message: "underlying failure", + }, + State: &SessionState[testState]{}, }, nil }); err != nil { t.Fatalf("SaveSnapshot: %v", err) } af := DefineCustomAgent(reg, "resumeErrored", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil }, WithSessionStore(store), @@ -2561,7 +2566,7 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { } af := DefineCustomAgent(reg, "transformedFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("the secret is out")) return nil @@ -2591,8 +2596,8 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { if resp.SnapshotID != out.SnapshotID { t.Errorf("SnapshotID mismatch: got %q want %q", resp.SnapshotID, out.SnapshotID) } - if resp.Status != SnapshotStatusComplete { - t.Errorf("expected status=complete, got %q", resp.Status) + if resp.Status != SnapshotStatusSucceeded { + t.Errorf("expected status=succeeded, got %q", resp.Status) } if resp.State == nil { t.Fatal("expected state in response") @@ -2641,7 +2646,7 @@ func TestAgent_GetSnapshotAction_NoStore(t *testing.T) { reg := newTestRegistry(t) DefineCustomAgent(reg, "noStoreFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil }, ) @@ -2681,14 +2686,14 @@ func TestAgent_AgentMetadata(t *testing.T) { // Verify the metadata["agent"] payload on the flow's action descriptor // correctly reports stateManagement and abortable for each combination // of store capabilities. - noopFn := func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + noopFn := func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil } cases := []struct { name string define func(reg api.Registry, flowName string) - wantMgmt AgentMetadataStateManagement + wantMgmt AgentStateManagement wantAbortab bool }{ { @@ -2696,7 +2701,7 @@ func TestAgent_AgentMetadata(t *testing.T) { define: func(reg api.Registry, flowName string) { DefineCustomAgent(reg, flowName, noopFn) }, - wantMgmt: AgentMetadataStateManagementClient, + wantMgmt: AgentStateManagementClient, wantAbortab: false, }, { @@ -2705,7 +2710,7 @@ func TestAgent_AgentMetadata(t *testing.T) { DefineCustomAgent(reg, flowName, noopFn, WithSessionStore[testState](minimalStore[testState]{})) }, - wantMgmt: AgentMetadataStateManagementServer, + wantMgmt: AgentStateManagementServer, wantAbortab: false, }, { @@ -2714,7 +2719,7 @@ func TestAgent_AgentMetadata(t *testing.T) { DefineCustomAgent(reg, flowName, noopFn, WithSessionStore(NewInMemorySessionStore[testState]())) }, - wantMgmt: AgentMetadataStateManagementServer, + wantMgmt: AgentStateManagementServer, wantAbortab: true, }, } @@ -2757,7 +2762,7 @@ func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { reg := newTestRegistry(t) store := NewInMemorySessionStore[testState]() // implements SnapshotAborter DefineCustomAgent(reg, "fullCaps", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil }, WithSessionStore(store), @@ -2777,7 +2782,7 @@ func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { t.Run("no aborter capability → abort not registered", func(t *testing.T) { reg := newTestRegistry(t) DefineCustomAgent(reg, "minCaps", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil }, WithSessionStore[testState](minimalStore[testState]{}), @@ -2806,7 +2811,7 @@ func TestAgent_StateTransform_ClientManagedState(t *testing.T) { } af := DefineCustomAgent(reg, "clientXformFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.UpdateCustom(func(s testState) testState { s.Counter = 7 @@ -2837,7 +2842,7 @@ func TestAgent_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { store := NewInMemorySessionStore[testState]() af := DefineCustomAgent(reg, "resumeDetachedFlow", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { @@ -2876,7 +2881,7 @@ func TestAgent_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { t.Fatalf("Output: %v", err) } finalSnap := waitForSnapshot(t, store, first.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusComplete + return s.Status == SnapshotStatusSucceeded }) if finalSnap.State.Custom.Counter != 1 { t.Fatalf("expected counter=1 in finalized snapshot, got %d", finalSnap.State.Custom.Counter) @@ -2906,7 +2911,7 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { t.Fatalf("AbortSnapshot(missing) = %+v, %v; want nil, nil", meta, err) } - // Pending → canceled, UpdatedAt advances. + // Pending → aborted, UpdatedAt advances. pending, err := store.SaveSnapshot(ctx, "snap-cas", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ @@ -2922,21 +2927,21 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { if err != nil { t.Fatalf("AbortSnapshot: %v", err) } - if meta.Status != SnapshotStatusCanceled { - t.Errorf("status after first abort = %q, want canceled", meta.Status) + if meta.Status != SnapshotStatusAborted { + t.Errorf("status after first abort = %q, want aborted", meta.Status) } if !meta.UpdatedAt.After(pending.UpdatedAt) { t.Errorf("UpdatedAt did not advance: %v vs %v", meta.UpdatedAt, pending.UpdatedAt) } - // Idempotent: second abort returns canceled, no error, no further mutation. + // Idempotent: second abort returns aborted, no error, no further mutation. firstUpdate := meta.UpdatedAt meta2, err := store.AbortSnapshot(ctx, "snap-cas") if err != nil { t.Fatalf("AbortSnapshot (second): %v", err) } - if meta2.Status != SnapshotStatusCanceled { - t.Errorf("status after second abort = %q, want canceled", meta2.Status) + if meta2.Status != SnapshotStatusAborted { + t.Errorf("status after second abort = %q, want aborted", meta2.Status) } if !meta2.UpdatedAt.Equal(firstUpdate) { t.Errorf("UpdatedAt advanced on idempotent abort: %v vs %v", meta2.UpdatedAt, firstUpdate) @@ -2947,7 +2952,7 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ Event: SnapshotEventTurnEnd, - Status: SnapshotStatusComplete, + Status: SnapshotStatusSucceeded, }, nil }); err != nil { t.Fatalf("SaveSnapshot: %v", err) @@ -2956,15 +2961,15 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { if err != nil { t.Fatalf("AbortSnapshot on complete: %v", err) } - if meta3.Status != SnapshotStatusComplete { - t.Errorf("abort on complete returned status=%q, want complete", meta3.Status) + if meta3.Status != SnapshotStatusSucceeded { + t.Errorf("abort on complete returned status=%q, want succeeded", meta3.Status) } } func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { // An abort that lands while fn is still running but does not actually // stop fn (because fn does not observe ctx) must still result in - // status=canceled — the finalizer must not clobber canceled with + // status=aborted — the finalizer must not clobber aborted with // complete. The subscriber observes the status flip and the finalizer // reads the resulting flag. reg := newTestRegistry(t) @@ -2974,7 +2979,7 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { entered := make(chan struct{}) af := DefineCustomAgent(reg, "raceFinalize", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { select { case entered <- struct{}{}: @@ -2982,7 +2987,7 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { } <-fnRelease // Return cleanly without observing ctx. Without the - // subscriber/recheck, this would land status=complete and + // subscriber/recheck, this would land status=succeeded and // clobber the abort. return nil }) @@ -3023,10 +3028,10 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { close(fnRelease) finalSnap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusCanceled || s.Status == SnapshotStatusComplete + return s.Status == SnapshotStatusAborted || s.Status == SnapshotStatusSucceeded }) - if finalSnap.Status != SnapshotStatusCanceled { - t.Errorf("finalize clobbered canceled with %q", finalSnap.Status) + if finalSnap.Status != SnapshotStatusAborted { + t.Errorf("finalize clobbered aborted with %q", finalSnap.Status) } } @@ -3071,7 +3076,7 @@ func TestInMemorySessionStore_OnSnapshotStatusChange(t *testing.T) { t.Fatal("did not receive initial status") } - // Abort flips status; subscriber observes canceled. + // Abort flips status; subscriber observes aborted. if _, err := store.AbortSnapshot(ctx, "snap-sub"); err != nil { t.Fatalf("AbortSnapshot: %v", err) } @@ -3080,8 +3085,8 @@ func TestInMemorySessionStore_OnSnapshotStatusChange(t *testing.T) { if !ok { t.Fatal("channel closed before abort notification") } - if status != SnapshotStatusCanceled { - t.Errorf("status notification = %q, want canceled", status) + if status != SnapshotStatusAborted { + t.Errorf("status notification = %q, want aborted", status) } case <-time.After(time.Second): t.Fatal("did not receive abort notification") @@ -3106,7 +3111,7 @@ func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { store := NewInMemorySessionStore[testState]() af := DefineCustomAgent(reg, "abortNoop", - func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { sess.AddMessages(ai.NewModelTextMessage("reply")) return nil @@ -3125,8 +3130,8 @@ func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { if err != nil { t.Fatalf("AbortSnapshot: %v", err) } - if resp.Status != SnapshotStatusComplete { - t.Errorf("expected status=%q (existing terminal), got %q", SnapshotStatusComplete, resp.Status) + if resp.Status != SnapshotStatusSucceeded { + t.Errorf("expected status=%q (existing terminal), got %q", SnapshotStatusSucceeded, resp.Status) } // Confirm the store snapshot was not flipped. @@ -3134,7 +3139,7 @@ func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { if err != nil { t.Fatalf("GetSnapshot: %v", err) } - if snap.Status != SnapshotStatusComplete { - t.Errorf("snapshot status = %q after abort-on-terminal, want complete", snap.Status) + if snap.Status != SnapshotStatusSucceeded { + t.Errorf("snapshot status = %q after abort-on-terminal, want succeeded", snap.Status) } } diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index 830d62c3d1..e73c310b18 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -20,46 +20,25 @@ package exp import ( "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "time" ) -// AgentMetadata is the value placed under metadata["agent"] on an agent's -// action descriptor. It exposes capability information so the Dev UI and -// other reflective callers can render the right surface (e.g. hide the -// Abort button when the configured store doesn't support it) without -// round-tripping through the reflection API. -type AgentMetadata struct { - // Abortable reports whether the agent's invocations can be aborted - // (true when the store implements [SnapshotAborter]). - Abortable bool `json:"abortable,omitempty"` - // StateManagement reports who owns session state. - StateManagement AgentMetadataStateManagement `json:"stateManagement,omitempty"` +// AbortSnapshotRequest is the input for the abortSnapshot companion action. +type AbortSnapshotRequest struct { + // SnapshotID identifies the snapshot whose invocation should be aborted. + SnapshotID string `json:"snapshotId"` } -// AgentMetadataStateManagement enumerates who owns session state for an -// agent: "server" (a [SessionStore] is configured and snapshots are -// persisted server-side) or "client" (no store; state flows through -// invocation init / output). -type AgentMetadataStateManagement string - -const ( - // AgentMetadataStateManagementServer indicates the agent is wired with - // a [SessionStore] and persists snapshots server-side. - AgentMetadataStateManagementServer AgentMetadataStateManagement = "server" - // AgentMetadataStateManagementClient indicates the agent has no store; - // session state is client-managed and round-trips through invocation - // init and output. - AgentMetadataStateManagementClient AgentMetadataStateManagement = "client" -) - -// Artifact represents a named collection of parts produced during a session. -// Examples: generated files, images, code snippets, diagrams, etc. -type Artifact struct { - // Metadata contains additional artifact-specific data. - Metadata map[string]any `json:"metadata,omitempty"` - // Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). - Name string `json:"name,omitempty"` - // Parts contains the artifact content (text, media, etc.). - Parts []*ai.Part `json:"parts"` +// AbortSnapshotResponse is the output of the abortSnapshot companion action. +type AbortSnapshotResponse struct { + // SnapshotID echoes the requested snapshot ID. + SnapshotID string `json:"snapshotId"` + // Status is the snapshot's status after the abort attempt. For a + // pending snapshot this is [SnapshotStatusAborted]. For an + // already-terminal snapshot this is the existing terminal status (the + // abort is a no-op). + Status SnapshotStatus `json:"status,omitempty"` } // AgentInit is the input for starting an agent invocation. @@ -85,11 +64,34 @@ type AgentInput struct { Detach bool `json:"detach,omitempty"` // Messages contains the user's input for this turn. Messages []*ai.Message `json:"messages,omitempty"` - // ToolRestarts contains tool request parts to re-execute interrupted tools. - // Use [ai.ToolDef.RestartWith] to create these parts from an interrupted - // tool request. When set, the generate call resumes with these restarts - // instead of treating Messages as tool responses. - ToolRestarts []*ai.Part `json:"toolRestarts,omitempty"` + // Resume provides options for resuming an interrupted generation. + // Construct using [ai.ToolDef.RestartWith] / [ai.ToolDef.RespondWith] + // parts. When set, the generate call resumes with these parts instead + // of treating Messages as tool responses. + Resume *ToolResume `json:"resume,omitempty"` +} + +// ToolResume holds the parts that resume an interrupted agent turn. +// Mirrors [ai.GenerateActionResume] but is named for the tool-call +// callsite where it is set on an [AgentInput]. +type ToolResume struct { + // Respond contains tool response parts to send to the model when resuming. + Respond []*ai.Part `json:"respond,omitempty"` + // Restart contains tool request parts to restart when resuming. + Restart []*ai.Part `json:"restart,omitempty"` +} + +// AgentMetadata is the value placed under metadata["agent"] on an agent's +// action descriptor. It exposes capability information so the Dev UI and +// other reflective callers can render the right surface (e.g. hide the +// Abort button when the configured store doesn't support it) without +// round-tripping through the reflection API. +type AgentMetadata struct { + // Abortable reports whether the agent's invocations can be aborted + // (true when the store implements [SnapshotAborter]). + Abortable bool `json:"abortable,omitempty"` + // StateManagement reports who owns session state. + StateManagement AgentStateManagement `json:"stateManagement,omitempty"` } // AgentOutput is the output when an agent invocation completes. @@ -116,6 +118,22 @@ type AgentResult struct { Message *ai.Message `json:"message,omitempty"` } +// AgentStateManagement enumerates who owns session state for an +// agent: "server" (a [SessionStore] is configured and snapshots are +// persisted server-side) or "client" (no store; state flows through +// invocation init / output). +type AgentStateManagement string + +const ( + // AgentStateManagementServer indicates the agent is wired with + // a [SessionStore] and persists snapshots server-side. + AgentStateManagementServer AgentStateManagement = "server" + // AgentStateManagementClient indicates the agent has no store; + // session state is client-managed and round-trips through invocation + // init and output. + AgentStateManagementClient AgentStateManagement = "client" +) + // AgentStreamChunk represents a single item in the agent's output stream. // Multiple fields can be populated in a single chunk. type AgentStreamChunk[Stream any] struct { @@ -133,6 +151,51 @@ type AgentStreamChunk[Stream any] struct { TurnEnd *TurnEnd `json:"turnEnd,omitempty"` } +// Artifact represents a named collection of parts produced during a session. +// Examples: generated files, images, code snippets, diagrams, etc. +type Artifact struct { + // Metadata contains additional artifact-specific data. + Metadata map[string]any `json:"metadata,omitempty"` + // Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). + Name string `json:"name,omitempty"` + // Parts contains the artifact content (text, media, etc.). + Parts []*ai.Part `json:"parts"` +} + +// GetSnapshotRequest is the input for an agent's getSnapshot companion +// action. The action is registered at `{agentName}/getSnapshot` when the +// agent is defined and is intended for Dev UI and client-side reconnect +// flows. +type GetSnapshotRequest struct { + // SnapshotID identifies the snapshot to fetch. + SnapshotID string `json:"snapshotId"` +} + +// GetSnapshotResponse is the output of the getSnapshot companion action. It +// is a client-facing view of the stored snapshot: identifying metadata plus +// the session state, with [WithStateTransform] applied if configured. +// +// Unlike the raw [SessionSnapshot], this response intentionally omits +// internal fields (parent ID, event) and does not leak the snapshot +// envelope beyond what callers need to repopulate a UI. +type GetSnapshotResponse[State any] struct { + // CreatedAt is when the snapshot record was first written. + CreatedAt time.Time `json:"createdAt,omitempty"` + // Error is the structured failure information; populated when Status + // is [SnapshotStatusFailed]. + Error *core.GenkitError `json:"error,omitempty"` + // SnapshotID echoes the requested snapshot ID. + SnapshotID string `json:"snapshotId"` + // State is the session state captured by the snapshot, after any + // configured transform. Empty when Status is pending or error. + State *SessionState[State] `json:"state,omitempty"` + // Status is the lifecycle state of the snapshot. See [SnapshotStatus]. + Status SnapshotStatus `json:"status,omitempty"` + // UpdatedAt is when the snapshot record was last written. Equals + // CreatedAt for snapshots that have not been rewritten. + UpdatedAt time.Time `json:"updatedAt,omitempty"` +} + // SessionState is the portable conversation state that flows between client // and server. It contains only the data needed for conversation continuity. type SessionState[State any] struct { @@ -141,7 +204,7 @@ type SessionState[State any] struct { // Custom is the user-defined state associated with this conversation. Custom State `json:"custom,omitempty"` // Messages is the conversation history (user/model exchanges). - // Does NOT include prompt-rendered messages, those are rendered fresh each turn. + // Does NOT include prompt-rendered messages — those are rendered fresh each turn. Messages []*ai.Message `json:"messages,omitempty"` } @@ -160,6 +223,58 @@ const ( SnapshotEventDetach SnapshotEvent = "detach" ) +// SnapshotMetadata is the metadata-only projection of a [SessionSnapshot]: +// identifying fields, lifecycle timestamps, and status. Returned by store +// operations that surface a snapshot's lifecycle state without paying for +// a full state read. +type SnapshotMetadata struct { + // CreatedAt is when the snapshot was first written. + CreatedAt time.Time `json:"createdAt"` + // Error is the structured failure information for a snapshot in + // [SnapshotStatusFailed]. + Error *core.GenkitError `json:"error,omitempty"` + // Event is what triggered this snapshot. + Event SnapshotEvent `json:"event"` + // ParentID is the ID of the previous snapshot in this timeline. + ParentID string `json:"parentId,omitempty"` + // SnapshotID is the unique identifier for this snapshot. + SnapshotID string `json:"snapshotId"` + // Status is the lifecycle state of this snapshot. + Status SnapshotStatus `json:"status,omitempty"` + // UpdatedAt is when the snapshot was last written. + UpdatedAt time.Time `json:"updatedAt,omitempty"` +} + +// SnapshotStatus describes the lifecycle state of a snapshot. Snapshots +// written for synchronous turns or invocations are always +// [SnapshotStatusSucceeded] (an empty value is also treated as succeeded +// for backwards compatibility). +// +// When a client sets [AgentInput.Detach], the server writes a single +// snapshot with [SnapshotStatusPending] (and empty state) and returns its +// ID immediately. Background processing then either rewrites that snapshot +// with the cumulative final state and [SnapshotStatusSucceeded] / +// [SnapshotStatusFailed] when the agent finishes, or with +// [SnapshotStatusAborted] if the client called abortSnapshot in the +// meantime. +type SnapshotStatus string + +const ( + // SnapshotStatusPending indicates a detached invocation is still + // processing the queued inputs. The snapshot will be rewritten with a + // terminal status once the flow exits. + SnapshotStatusPending SnapshotStatus = "pending" + // SnapshotStatusSucceeded indicates the snapshot captures a settled state. + SnapshotStatusSucceeded SnapshotStatus = "succeeded" + // SnapshotStatusAborted indicates the snapshot's invocation was + // aborted via the abortSnapshot companion action while detached. + SnapshotStatusAborted SnapshotStatus = "aborted" + // SnapshotStatusFailed indicates the invocation terminated with an error. + // The snapshot's Error field describes the failure and resume is + // rejected with that same error. + SnapshotStatusFailed SnapshotStatus = "failed" +) + // TurnEnd groups the signals emitted when an agent turn finishes. // A TurnEnd value is emitted exactly once per turn, regardless of whether a // snapshot was persisted. diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 0fecacec0c..27d30f7361 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -33,35 +33,6 @@ import ( // --- Snapshot --- -// SnapshotStatus describes the lifecycle state of a snapshot. Snapshots -// written for synchronous turns or invocations are always [SnapshotStatusComplete] -// (an empty value is also treated as complete for backwards compatibility). -// -// When a client sets [AgentInput.Detach], the server writes a single -// snapshot with [SnapshotStatusPending] (and empty state) and returns its -// ID immediately. Background processing then either rewrites that snapshot -// with the cumulative final state and [SnapshotStatusComplete] / -// [SnapshotStatusError] when the agent finishes, or with -// [SnapshotStatusCanceled] if the client called abortSnapshot in the -// meantime. -type SnapshotStatus string - -const ( - // SnapshotStatusPending indicates a detached invocation is still - // processing the queued inputs. The snapshot will be rewritten with a - // terminal status once the flow exits. - SnapshotStatusPending SnapshotStatus = "pending" - // SnapshotStatusComplete indicates the snapshot captures a settled state. - SnapshotStatusComplete SnapshotStatus = "complete" - // SnapshotStatusCanceled indicates the snapshot's invocation was - // aborted via the abortSnapshot companion action while detached. - SnapshotStatusCanceled SnapshotStatus = "canceled" - // SnapshotStatusError indicates the invocation terminated with an error. - // The snapshot's Error field describes the failure and resume is - // rejected with that same error. - SnapshotStatusError SnapshotStatus = "error" -) - // SessionSnapshot is a persisted point-in-time capture of session state. type SessionSnapshot[State any] struct { // SnapshotID is the unique identifier for this snapshot (UUID). @@ -70,23 +41,23 @@ type SessionSnapshot[State any] struct { ParentID string `json:"parentId,omitempty"` // CreatedAt is when the snapshot was created. CreatedAt time.Time `json:"createdAt"` - // UpdatedAt is when the snapshot was last written. For pending snapshots - // it equals CreatedAt; once the snapshot is finalized it reflects the - // terminal write. + // UpdatedAt is when the snapshot was last written. For pending + // snapshots it equals CreatedAt; once the snapshot is finalized it + // reflects the terminal write. UpdatedAt time.Time `json:"updatedAt,omitempty"` // Event is what triggered this snapshot. Event SnapshotEvent `json:"event"` // Status is the lifecycle state of this snapshot. Empty is treated as - // [SnapshotStatusComplete] for backwards compatibility. + // [SnapshotStatusSucceeded] for backwards compatibility. Status SnapshotStatus `json:"status,omitempty"` - // Error is the failure message for a snapshot in [SnapshotStatusError]. - // Empty otherwise. - Error string `json:"error,omitempty"` - // State is the actual conversation state. Empty on a pending snapshot - // (the live state is not yet committed; the background invocation is - // still processing queued inputs); populated on terminal snapshots - // with the cumulative final state. - State SessionState[State] `json:"state"` + // Error is the structured failure information for a snapshot in + // [SnapshotStatusFailed]. Nil otherwise. + Error *core.GenkitError `json:"error,omitempty"` + // State is the conversation state captured at this point. Nil on a + // pending snapshot (the live state is not yet committed; the + // background invocation is still processing queued inputs); populated + // on terminal snapshots with the cumulative final state. + State *SessionState[State] `json:"state,omitempty"` } // SnapshotContext provides context for snapshot decision callbacks. @@ -117,27 +88,6 @@ func applyTransform[State any](ctx context.Context, t StateTransform[State], sta // --- Session store --- -// SnapshotMetadata is the metadata-only projection of a [SessionSnapshot]: -// identifying fields, lifecycle timestamps, and status. Returned by store -// operations that surface a snapshot's lifecycle state without paying for -// a full state read. -type SnapshotMetadata struct { - // SnapshotID is the unique identifier for this snapshot. - SnapshotID string `json:"snapshotId"` - // ParentID is the ID of the previous snapshot in this timeline. - ParentID string `json:"parentId,omitempty"` - // CreatedAt is when the snapshot was first written. - CreatedAt time.Time `json:"createdAt"` - // UpdatedAt is when the snapshot was last written. - UpdatedAt time.Time `json:"updatedAt,omitempty"` - // Event is what triggered this snapshot. - Event SnapshotEvent `json:"event"` - // Status is the lifecycle state of this snapshot. - Status SnapshotStatus `json:"status,omitempty"` - // Error is the failure message for a snapshot in [SnapshotStatusError]. - Error string `json:"error,omitempty"` -} - // SnapshotReader retrieves snapshots. The minimum any session store must // implement to be used with [WithSessionStore]. type SnapshotReader[State any] interface { @@ -159,7 +109,7 @@ type SnapshotWriter[State any] interface { // from the existing row on update. // - UpdatedAt: stamped to the wall clock on every commit. // - Status: if the snapshot returned by fn has Status="", it is - // defaulted to [SnapshotStatusComplete] (the common case for + // defaulted to [SnapshotStatusSucceeded] (the common case for // synchronous turn-end and invocation-end writes). Callers // writing a pending row must set Status explicitly. // @@ -176,7 +126,7 @@ type SnapshotWriter[State any] interface { // populated), or nil if fn declined to write. SaveSnapshot( ctx context.Context, - id string, + snapshotID string, fn func(existing *SessionSnapshot[State]) (*SessionSnapshot[State], error), ) (*SessionSnapshot[State], error) } @@ -187,7 +137,7 @@ type SnapshotWriter[State any] interface { // function: // // - [SnapshotAborter.AbortSnapshot] flips a pending snapshot's status -// to canceled (typically called by the abortSnapshot companion +// to aborted (typically called by the abortSnapshot companion // action or directly by a Go caller holding the store). // // - [SnapshotAborter.OnSnapshotStatusChange] lets the agent runtime @@ -201,7 +151,7 @@ type SnapshotWriter[State any] interface { // "implemented one, not the other" footgun too easy to hit. type SnapshotAborter interface { // AbortSnapshot atomically transitions a snapshot from - // [SnapshotStatusPending] to [SnapshotStatusCanceled] and returns the + // [SnapshotStatusPending] to [SnapshotStatusAborted] and returns the // resulting metadata. If the snapshot is in any other status the // operation is a no-op and the existing metadata is returned. Returns // nil if the snapshot is not found. @@ -265,7 +215,7 @@ func (s *InMemorySessionStore[State]) GetSnapshot(_ context.Context, snapshotID return copySnapshot(snap) } -// AbortSnapshot atomically flips a pending snapshot to canceled. If the +// AbortSnapshot atomically flips a pending snapshot to aborted. If the // snapshot is already terminal the existing metadata is returned unchanged. // Returns nil if the snapshot is not found. func (s *InMemorySessionStore[State]) AbortSnapshot(_ context.Context, snapshotID string) (*SnapshotMetadata, error) { @@ -276,7 +226,7 @@ func (s *InMemorySessionStore[State]) AbortSnapshot(_ context.Context, snapshotI return nil, nil } if snap.Status == SnapshotStatusPending { - snap.Status = SnapshotStatusCanceled + snap.Status = SnapshotStatusAborted snap.UpdatedAt = time.Now() s.notifyLocked(snapshotID, snap.Status) } @@ -325,7 +275,7 @@ func (s *InMemorySessionStore[State]) SaveSnapshot( } next.UpdatedAt = now if next.Status == "" { - next.Status = SnapshotStatusComplete + next.Status = SnapshotStatusSucceeded } copied, err := copySnapshot(next) @@ -427,56 +377,6 @@ func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[Sta // --- Snapshot companion actions --- -// GetSnapshotRequest is the input for an agent's getSnapshot companion -// action. The action is registered at `{agentName}/getSnapshot` when the -// agent is defined and is intended for Dev UI and client-side reconnect -// flows. -type GetSnapshotRequest struct { - // SnapshotID identifies the snapshot to fetch. - SnapshotID string `json:"snapshotId"` -} - -// GetSnapshotResponse is the output of the getSnapshot companion action. It -// is a client-facing view of the stored snapshot: identifying metadata plus -// the session state, with [WithStateTransform] applied if configured. -// -// Unlike the raw [SessionSnapshot], this response intentionally omits -// internal fields (parent ID, event) and does not leak the snapshot -// envelope beyond what callers need to repopulate a UI. -type GetSnapshotResponse[State any] struct { - // SnapshotID echoes the requested snapshot ID. - SnapshotID string `json:"snapshotId"` - // CreatedAt is when the snapshot record was first written. - CreatedAt time.Time `json:"createdAt,omitempty"` - // UpdatedAt is when the snapshot record was last written. Equals - // CreatedAt for snapshots that have not been rewritten. - UpdatedAt time.Time `json:"updatedAt,omitempty"` - // Status is the lifecycle state of the snapshot. See [SnapshotStatus]. - Status SnapshotStatus `json:"status,omitempty"` - // Error is populated when Status is [SnapshotStatusError]. - Error string `json:"error,omitempty"` - // State is the session state captured by the snapshot, after any - // configured transform. Empty when Status is pending or error. - State *SessionState[State] `json:"state,omitempty"` -} - -// AbortSnapshotRequest is the input for the abortSnapshot companion action. -type AbortSnapshotRequest struct { - // SnapshotID identifies the snapshot whose invocation should be aborted. - SnapshotID string `json:"snapshotId"` -} - -// AbortSnapshotResponse is the output of the abortSnapshot companion action. -type AbortSnapshotResponse struct { - // SnapshotID echoes the requested snapshot ID. - SnapshotID string `json:"snapshotId"` - // Status is the snapshot's status after the abort attempt. For a - // pending snapshot this is [SnapshotStatusCanceled]. For an - // already-terminal snapshot this is the existing terminal status (the - // abort is a no-op). - Status SnapshotStatus `json:"status,omitempty"` -} - // registerSnapshotActions registers the agent's companion actions: // // - The agent's name under [api.ActionTypeAgentSnapshot] — getSnapshot, @@ -515,7 +415,7 @@ func registerSnapshotActions[State any]( status := snap.Status if status == "" { - status = SnapshotStatusComplete + status = SnapshotStatusSucceeded } updatedAt := snap.UpdatedAt if updatedAt.IsZero() { @@ -529,8 +429,8 @@ func registerSnapshotActions[State any]( Status: status, Error: snap.Error, } - if status != SnapshotStatusError && status != SnapshotStatusPending { - resp.State = applyTransform(ctx, transform, &snap.State) + if status != SnapshotStatusFailed && status != SnapshotStatusPending { + resp.State = applyTransform(ctx, transform, snap.State) } return resp, nil }) diff --git a/go/core/error.go b/go/core/error.go index 4482fc21fd..38504998c3 100644 --- a/go/core/error.go +++ b/go/core/error.go @@ -18,6 +18,7 @@ package core import ( + "encoding/json" "errors" "fmt" "runtime/debug" @@ -36,13 +37,69 @@ type ReflectionError struct { } // GenkitError is the base error type for Genkit errors. +// +// On the wire, GenkitError marshals to and from the canonical Genkit +// error shape {status, message, details}, which mirrors the +// `RuntimeError` definition in the JSON schema. Fields that exist for +// in-process use (HTTPCode, Source, the wrapped error) are not +// serialized. type GenkitError struct { - Message string `json:"message"` // Exclude from default JSON if embedded elsewhere - Status StatusName `json:"status"` - HTTPCode int `json:"-"` // Exclude from default JSON - Details map[string]any `json:"details"` // Use map for arbitrary details - Source *string `json:"source,omitempty"` // Pointer for optional - originalError error // The wrapped error, if any + Message string // Wire field "message". + Status StatusName // Wire field "status". + HTTPCode int // Derived from Status; not serialized. + Details map[string]any // Wire field "details" (omitted when empty). + Source *string // In-process annotation; not serialized. + originalError error // The wrapped error, if any. +} + +// genkitErrorWire is the on-the-wire shape of a [GenkitError]; it +// matches the `RuntimeError` definition in the JSON schema. +type genkitErrorWire struct { + Status StatusName `json:"status"` + Message string `json:"message"` + Details map[string]any `json:"details,omitempty"` +} + +// MarshalJSON encodes a GenkitError in the canonical Genkit error wire +// format: {status, message, details}. +func (e *GenkitError) MarshalJSON() ([]byte, error) { + return json.Marshal(genkitErrorWire{ + Status: e.Status, + Message: e.Message, + Details: e.Details, + }) +} + +// UnmarshalJSON decodes a GenkitError from the canonical wire format +// and re-derives HTTPCode from Status. +func (e *GenkitError) UnmarshalJSON(data []byte) error { + var w genkitErrorWire + if err := json.Unmarshal(data, &w); err != nil { + return err + } + e.Status = w.Status + e.Message = w.Message + e.Details = w.Details + e.HTTPCode = HTTPStatusCode(w.Status) + return nil +} + +// AsGenkitError returns err as a *GenkitError, wrapping it in a fresh +// one with status INTERNAL if it isn't one already. Returns nil for a +// nil input. +func AsGenkitError(err error) *GenkitError { + if err == nil { + return nil + } + var ge *GenkitError + if errors.As(err, &ge) { + return ge + } + return &GenkitError{ + Status: INTERNAL, + Message: err.Error(), + HTTPCode: HTTPStatusCode(INTERNAL), + } } // UserFacingError is the base error type for user facing errors. diff --git a/go/core/error_test.go b/go/core/error_test.go index a2e26a25a8..498803f750 100644 --- a/go/core/error_test.go +++ b/go/core/error_test.go @@ -17,6 +17,7 @@ package core import ( + "encoding/json" "errors" "fmt" "net/http" @@ -163,6 +164,95 @@ func TestGenkitErrorToReflectionError(t *testing.T) { }) } +func TestGenkitErrorJSONRoundtrip(t *testing.T) { + t.Run("marshals canonical wire shape", func(t *testing.T) { + ge := &GenkitError{ + Status: NOT_FOUND, + Message: "missing", + Details: map[string]any{"id": "abc"}, + HTTPCode: 999, // not on the wire + Source: func() *string { s := "x"; return &s }(), // not on the wire + } + got, err := json.Marshal(ge) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + want := `{"status":"NOT_FOUND","message":"missing","details":{"id":"abc"}}` + if string(got) != want { + t.Errorf("Marshal = %s, want %s", got, want) + } + }) + + t.Run("omits empty details", func(t *testing.T) { + ge := &GenkitError{Status: NOT_FOUND, Message: "missing"} + got, err := json.Marshal(ge) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + want := `{"status":"NOT_FOUND","message":"missing"}` + if string(got) != want { + t.Errorf("Marshal = %s, want %s", got, want) + } + }) + + t.Run("unmarshals and derives HTTPCode", func(t *testing.T) { + raw := `{"status":"NOT_FOUND","message":"missing","details":{"id":"abc"}}` + var ge GenkitError + if err := json.Unmarshal([]byte(raw), &ge); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if ge.Status != NOT_FOUND { + t.Errorf("Status = %q, want %q", ge.Status, NOT_FOUND) + } + if ge.Message != "missing" { + t.Errorf("Message = %q, want %q", ge.Message, "missing") + } + if ge.HTTPCode != http.StatusNotFound { + t.Errorf("HTTPCode = %d, want %d", ge.HTTPCode, http.StatusNotFound) + } + if ge.Details["id"] != "abc" { + t.Errorf("Details[id] = %v, want %q", ge.Details["id"], "abc") + } + }) +} + +func TestAsGenkitError(t *testing.T) { + t.Run("nil returns nil", func(t *testing.T) { + if got := AsGenkitError(nil); got != nil { + t.Errorf("AsGenkitError(nil) = %+v, want nil", got) + } + }) + + t.Run("returns existing GenkitError unchanged", func(t *testing.T) { + orig := &GenkitError{Status: NOT_FOUND, Message: "missing"} + if got := AsGenkitError(orig); got != orig { + t.Errorf("expected same pointer, got %+v", got) + } + }) + + t.Run("unwraps nested GenkitError", func(t *testing.T) { + orig := &GenkitError{Status: NOT_FOUND, Message: "missing"} + wrapped := fmt.Errorf("wrap: %w", orig) + got := AsGenkitError(wrapped) + if got != orig { + t.Errorf("expected unwrapped pointer, got %+v", got) + } + }) + + t.Run("wraps plain error with INTERNAL", func(t *testing.T) { + got := AsGenkitError(errors.New("boom")) + if got.Status != INTERNAL { + t.Errorf("Status = %q, want INTERNAL", got.Status) + } + if got.Message != "boom" { + t.Errorf("Message = %q, want boom", got.Message) + } + if got.HTTPCode != http.StatusInternalServerError { + t.Errorf("HTTPCode = %d, want %d", got.HTTPCode, http.StatusInternalServerError) + } + }) +} + // testCustomError is a helper type for the errors.As subtest. type testCustomError struct { code int diff --git a/go/core/schemas.config b/go/core/schemas.config index be8e797a53..97fa9d7252 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1102,10 +1102,12 @@ GenerateActionOptions.output type *GenerateActionOutputConfig GenerateActionOptions.returnToolRequests type bool GenerateActionOptions.maxTurns type int GenerateActionOptions.use type []*MiddlewareRef + GenerateActionOptionsResume name GenerateActionResume GenerateActionOptionsResume.respond type []*Part GenerateActionOptionsResume.restart type []*Part + # GenerateActionOutputConfig GenerateActionOutputConfig.instructions type *string GenerateActionOutputConfig.jsonSchema type map[string]any @@ -1167,7 +1169,9 @@ GenkitErrorDataGenkitErrorDetails omit # Package configuration: ai/exp directory uses "exp" as Go package name. ai/exp name exp +exp import time exp import github.com/firebase/genkit/go/ai +exp import github.com/firebase/genkit/go/core # ---------------------------------------------------------------------------- # Artifact @@ -1220,12 +1224,33 @@ AgentInput.messages doc Messages contains the user's input for this turn. . -AgentInput.toolRestarts type []*ai.Part -AgentInput.toolRestarts doc -ToolRestarts contains tool request parts to re-execute interrupted tools. -Use [ai.ToolDef.RestartWith] to create these parts from an interrupted -tool request. When set, the generate call resumes with these restarts -instead of treating Messages as tool responses. +AgentInput.resume doc +Resume provides options for resuming an interrupted generation. +Construct using [ai.ToolDef.RestartWith] / [ai.ToolDef.RespondWith] +parts. When set, the generate call resumes with these parts instead +of treating Messages as tool responses. +. + +# AgentInputResume is the inline resume payload hoisted out of +# AgentInput.resume. Renamed to ToolResume so the type reads like a +# proper noun on agent inputs. +AgentInputResume pkg ai/exp +AgentInputResume name ToolResume +AgentInputResume.respond type []*ai.Part +AgentInputResume.restart type []*ai.Part + +AgentInputResume doc +ToolResume holds the parts that resume an interrupted agent turn. +Mirrors [ai.GenerateActionResume] but is named for the tool-call +callsite where it is set on an [AgentInput]. +. + +AgentInputResume.respond doc +Respond contains tool response parts to send to the model when resuming. +. + +AgentInputResume.restart doc +Restart contains tool request parts to restart when resuming. . # ---------------------------------------------------------------------------- @@ -1245,7 +1270,6 @@ SnapshotID loads state from a persisted snapshot. Mutually exclusive with State. . -AgentInit.state type *SessionState[State] AgentInit.state doc State provides direct state for the invocation. Mutually exclusive with SnapshotID. @@ -1288,7 +1312,6 @@ SnapshotID is the ID of the snapshot created at the end of this invocation. Empty if no snapshot was created (callback returned false or no store configured). . -AgentOutput.state type *SessionState[State] AgentOutput.state doc State contains the final conversation state. Only populated when state is client-managed (no store configured). @@ -1409,20 +1432,183 @@ terminal status once the background work finishes. . # ---------------------------------------------------------------------------- -# Snapshot lifecycle types (hand-written in go/ai/exp/session.go) +# Snapshot lifecycle types # ---------------------------------------------------------------------------- -# SnapshotStatus enum and the persisted snapshot envelope are hand-written -# alongside the Session type. The companion-action request/response types -# are also hand-written because they reference SessionState[State] with a -# Go type parameter, which the generator does not support. -SnapshotStatus omit -SessionSnapshot omit -SnapshotMetadata omit -GetSnapshotRequest omit -GetSnapshotResponse omit -AbortSnapshotRequest omit -AbortSnapshotResponse omit +# SnapshotStatus +SnapshotStatus pkg ai/exp + +SnapshotStatus doc +SnapshotStatus describes the lifecycle state of a snapshot. Snapshots +written for synchronous turns or invocations are always +[SnapshotStatusSucceeded] (an empty value is also treated as succeeded +for backwards compatibility). + +When a client sets [AgentInput.Detach], the server writes a single +snapshot with [SnapshotStatusPending] (and empty state) and returns its +ID immediately. Background processing then either rewrites that snapshot +with the cumulative final state and [SnapshotStatusSucceeded] / +[SnapshotStatusFailed] when the agent finishes, or with +[SnapshotStatusAborted] if the client called abortSnapshot in the +meantime. +. + +SnapshotStatusPending doc +SnapshotStatusPending indicates a detached invocation is still +processing the queued inputs. The snapshot will be rewritten with a +terminal status once the flow exits. +. + +SnapshotStatusSucceeded doc +SnapshotStatusSucceeded indicates the snapshot captures a settled state. +. + +SnapshotStatusAborted doc +SnapshotStatusAborted indicates the snapshot's invocation was +aborted via the abortSnapshot companion action while detached. +. + +SnapshotStatusFailed doc +SnapshotStatusFailed indicates the invocation terminated with an error. +The snapshot's Error field describes the failure and resume is +rejected with that same error. +. + +# SnapshotMetadata +SnapshotMetadata pkg ai/exp + +SnapshotMetadata doc +SnapshotMetadata is the metadata-only projection of a [SessionSnapshot]: +identifying fields, lifecycle timestamps, and status. Returned by store +operations that surface a snapshot's lifecycle state without paying for +a full state read. +. + +SnapshotMetadata.snapshotId noomitempty +SnapshotMetadata.snapshotId doc +SnapshotID is the unique identifier for this snapshot. +. + +SnapshotMetadata.parentId doc +ParentID is the ID of the previous snapshot in this timeline. +. + +SnapshotMetadata.createdAt type time.Time +SnapshotMetadata.createdAt noomitempty +SnapshotMetadata.createdAt doc +CreatedAt is when the snapshot was first written. +. + +SnapshotMetadata.updatedAt type time.Time +SnapshotMetadata.updatedAt doc +UpdatedAt is when the snapshot was last written. +. + +SnapshotMetadata.event noomitempty +SnapshotMetadata.event doc +Event is what triggered this snapshot. +. + +SnapshotMetadata.status doc +Status is the lifecycle state of this snapshot. +. + +SnapshotMetadata.error type *core.GenkitError +SnapshotMetadata.error doc +Error is the structured failure information for a snapshot in +[SnapshotStatusFailed]. +. + +# GetSnapshotRequest +GetSnapshotRequest pkg ai/exp + +GetSnapshotRequest doc +GetSnapshotRequest is the input for an agent's getSnapshot companion +action. The action is registered at `{agentName}/getSnapshot` when the +agent is defined and is intended for Dev UI and client-side reconnect +flows. +. + +GetSnapshotRequest.snapshotId noomitempty +GetSnapshotRequest.snapshotId doc +SnapshotID identifies the snapshot to fetch. +. + +# GetSnapshotResponse +GetSnapshotResponse pkg ai/exp +GetSnapshotResponse typeparams [State any] + +GetSnapshotResponse doc +GetSnapshotResponse is the output of the getSnapshot companion action. It +is a client-facing view of the stored snapshot: identifying metadata plus +the session state, with [WithStateTransform] applied if configured. + +Unlike the raw [SessionSnapshot], this response intentionally omits +internal fields (parent ID, event) and does not leak the snapshot +envelope beyond what callers need to repopulate a UI. +. + +GetSnapshotResponse.snapshotId noomitempty +GetSnapshotResponse.snapshotId doc +SnapshotID echoes the requested snapshot ID. +. + +GetSnapshotResponse.createdAt type time.Time +GetSnapshotResponse.createdAt doc +CreatedAt is when the snapshot record was first written. +. + +GetSnapshotResponse.updatedAt type time.Time +GetSnapshotResponse.updatedAt doc +UpdatedAt is when the snapshot record was last written. Equals +CreatedAt for snapshots that have not been rewritten. +. + +GetSnapshotResponse.status doc +Status is the lifecycle state of the snapshot. See [SnapshotStatus]. +. + +GetSnapshotResponse.error type *core.GenkitError +GetSnapshotResponse.error doc +Error is the structured failure information; populated when Status +is [SnapshotStatusFailed]. +. + +GetSnapshotResponse.state doc +State is the session state captured by the snapshot, after any +configured transform. Empty when Status is pending or error. +. + +# AbortSnapshotRequest +AbortSnapshotRequest pkg ai/exp + +AbortSnapshotRequest doc +AbortSnapshotRequest is the input for the abortSnapshot companion action. +. + +AbortSnapshotRequest.snapshotId noomitempty +AbortSnapshotRequest.snapshotId doc +SnapshotID identifies the snapshot whose invocation should be aborted. +. + +# AbortSnapshotResponse +AbortSnapshotResponse pkg ai/exp + +AbortSnapshotResponse doc +AbortSnapshotResponse is the output of the abortSnapshot companion action. +. + +AbortSnapshotResponse.snapshotId noomitempty +AbortSnapshotResponse.snapshotId doc +SnapshotID echoes the requested snapshot ID. +. + +AbortSnapshotResponse.status doc +Status is the snapshot's status after the abort attempt. For a +pending snapshot this is [SnapshotStatusAborted]. For an +already-terminal snapshot this is the existing terminal status (the +abort is a no-op). +. # ---------------------------------------------------------------------------- # AgentMetadata @@ -1447,22 +1633,22 @@ Abortable reports whether the agent's invocations can be aborted (true when the store implements [SnapshotAborter]). . -AgentMetadataStateManagement pkg ai/exp +AgentStateManagement pkg ai/exp -AgentMetadataStateManagement doc -AgentMetadataStateManagement enumerates who owns session state for an +AgentStateManagement doc +AgentStateManagement enumerates who owns session state for an agent: "server" (a [SessionStore] is configured and snapshots are persisted server-side) or "client" (no store; state flows through invocation init / output). . -AgentMetadataStateManagementServer doc -AgentMetadataStateManagementServer indicates the agent is wired with +AgentStateManagementServer doc +AgentStateManagementServer indicates the agent is wired with a [SessionStore] and persists snapshots server-side. . -AgentMetadataStateManagementClient doc -AgentMetadataStateManagementClient indicates the agent has no store; +AgentStateManagementClient doc +AgentStateManagementClient indicates the agent has no store; session state is client-managed and round-trips through invocation init and output. . diff --git a/go/core/status_types.go b/go/core/status_types.go index 0c74a14f8e..2e8aa5c803 100644 --- a/go/core/status_types.go +++ b/go/core/status_types.go @@ -39,7 +39,7 @@ const ( ABORTED StatusName = "ABORTED" OUT_OF_RANGE StatusName = "OUT_OF_RANGE" UNIMPLEMENTED StatusName = "UNIMPLEMENTED" - INTERNAL StatusName = "INTERNAL_SERVER_ERROR" + INTERNAL StatusName = "INTERNAL" UNAVAILABLE StatusName = "UNAVAILABLE" DATA_LOSS StatusName = "DATA_LOSS" ) diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 0916ab0225..bc002bf410 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -531,8 +531,8 @@ func DefinePromptAgent[State, PromptIn any]( // minor version release. // // The provided function fn receives a [aix.Responder] for streaming output -// to the client and an [aix.AgentSession] for accessing conversation state. -// Call [aix.AgentSession.Run] to enter the turn loop, which blocks until the +// to the client and an [aix.SessionRunner] for accessing conversation state. +// Call [aix.SessionRunner.Run] to enter the turn loop, which blocks until the // client sends the next message. // // For agents backed by a prompt, use [DefineAgent] (inline) or @@ -551,7 +551,7 @@ func DefinePromptAgent[State, PromptIn any]( // Example: // // chatAgent := genkit.DefineCustomAgent(g, "chat", -// func(ctx context.Context, resp aix.Responder[any], sess *aix.AgentSession[any]) (*aix.AgentResult, error) { +// func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { // var lastMessage *ai.Message // err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) error { // for result, err := range genkit.GenerateStream(ctx, g, diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index 068dc49dbf..9f8eb94fa2 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -422,7 +422,7 @@ data: {"result":"hello-end"} resp := w.Result() body, _ := io.ReadAll(resp.Body) - expected := `data: {"error":{"status":"INTERNAL_SERVER_ERROR","message":"stream flow error","details":"streaming error"}} + expected := `data: {"error":{"status":"INTERNAL","message":"stream flow error","details":"streaming error"}} ` if string(body) != expected { diff --git a/go/internal/cmd/jsonschemagen/jsonschemagen.go b/go/internal/cmd/jsonschemagen/jsonschemagen.go index ca32ecc572..8d14dd4801 100644 --- a/go/internal/cmd/jsonschemagen/jsonschemagen.go +++ b/go/internal/cmd/jsonschemagen/jsonschemagen.go @@ -544,11 +544,20 @@ func (g *generator) typeExpr(s *Schema) (string, error) { if ic != nil && ic.name != "" { name = ic.name } + // If the target type is generic, append its type-args so the + // reference is well-formed (e.g. `*SessionState[State]`). The + // referencing type is responsible for declaring matching + // typeparams; otherwise the generated code won't compile. + // Callers can override this with an explicit `type` directive. + var typeArgs string + if ic != nil { + typeArgs = typeParamArgs(ic.typeparams) + } if s2.Enum != nil { - return name, nil + return name + typeArgs, nil } // If it's not an enum, it's a struct. Use a pointer to it. - return "*" + name, nil + return "*" + name + typeArgs, nil } // If there is no specified type, assume the schema represents any type. if s.Type.Any() == nil { @@ -665,6 +674,33 @@ func sortedKeys[K cmp.Ordered, V any](m map[K]V) []K { return keys } +// typeParamArgs converts a Go type-parameter list like "[State any]" or +// "[A any, B comparable]" into the matching type-argument list "[State]" +// or "[A, B]". Returns "" for an empty input. Used to forward the +// type-args of a generic target type onto a reference (e.g. turn a ref +// to `SessionState` into `SessionState[State]` when SessionState has +// `typeparams [State any]`). +func typeParamArgs(typeparams string) string { + inner := strings.TrimSpace(typeparams) + if inner == "" { + return "" + } + inner = strings.TrimPrefix(inner, "[") + inner = strings.TrimSuffix(inner, "]") + var names []string + for _, clause := range strings.Split(inner, ",") { + fields := strings.Fields(strings.TrimSpace(clause)) + if len(fields) == 0 { + continue + } + names = append(names, fields[0]) + } + if len(names) == 0 { + return "" + } + return "[" + strings.Join(names, ", ") + "]" +} + // config is the configuration for a schema file. // It describes modifications to the defaults of the code generator. type config struct { @@ -721,7 +757,11 @@ type extraField struct { // import // path of package to import (for packages only, may be repeated) // typeparams PARAMS -// Go type parameters to add to the type declaration (e.g., "[State any]") +// Go type parameters to add to the type declaration (e.g., "[State any]"). +// References to this type from other generated fields are +// automatically rewritten to include the matching type-args +// (e.g., "*SessionState[State]"). The referencing type must +// declare matching typeparams. // noomitempty // don't add omitempty to this field's json tag // field NAME TYPE diff --git a/go/internal/cmd/jsonschemagen/jsonschemagen_test.go b/go/internal/cmd/jsonschemagen/jsonschemagen_test.go index d903481fa2..c2fa0cc5f3 100644 --- a/go/internal/cmd/jsonschemagen/jsonschemagen_test.go +++ b/go/internal/cmd/jsonschemagen/jsonschemagen_test.go @@ -58,6 +58,27 @@ func Test(t *testing.T) { } } +func TestTypeParamArgs(t *testing.T) { + tests := []struct { + in string + want string + }{ + {"", ""}, + {"[]", ""}, + {"[State any]", "[State]"}, + {" [State any] ", "[State]"}, + {"[A any, B comparable]", "[A, B]"}, + {"[K, V any]", "[K, V]"}, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + if got := typeParamArgs(tt.in); got != tt.want { + t.Errorf("typeParamArgs(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + func TestSkipOmitEmpty(t *testing.T) { tests := []struct { name string diff --git a/go/plugins/googlegenai/errors.go b/go/plugins/googlegenai/errors.go index b77ccd0587..b086fb3d03 100644 --- a/go/plugins/googlegenai/errors.go +++ b/go/plugins/googlegenai/errors.go @@ -27,9 +27,8 @@ import ( // matches the one the server reported so status-aware middleware (retry, // fallback, ...) can reason about it. Non-APIError values pass through. // -// The SDK's Status string is a canonical Google / gRPC status name which, -// by design, already matches the string value of every [core.StatusName] -// constant except INTERNAL (our constant spells it "INTERNAL_SERVER_ERROR"). +// The SDK's Status string is a canonical Google / gRPC status name and so +// matches the string value of each [core.StatusName] constant directly. // When Status is missing or unrecognised the HTTP Code is the fallback. func wrapAPIError(err error) error { if err == nil { @@ -43,9 +42,6 @@ func wrapAPIError(err error) error { } func statusForAPIError(e genai.APIError) core.StatusName { - if e.Status == "INTERNAL" { - return core.INTERNAL - } s := core.StatusName(e.Status) if _, ok := core.StatusNameToCode[s]; ok { return s diff --git a/go/plugins/middleware/retry.go b/go/plugins/middleware/retry.go index a288a46ed1..9080c4b6b3 100644 --- a/go/plugins/middleware/retry.go +++ b/go/plugins/middleware/retry.go @@ -59,7 +59,7 @@ var sleepFunc = func(ctx context.Context, d time.Duration) error { // // By default, retries occur for non-[core.GenkitError] errors (e.g. network failures) // and for [core.GenkitError] errors whose status is one of UNAVAILABLE, DEADLINE_EXCEEDED, -// RESOURCE_EXHAUSTED, ABORTED, or INTERNAL_SERVER_ERROR. +// RESOURCE_EXHAUSTED, ABORTED, or INTERNAL. // // Usage: // diff --git a/go/samples/custom-agent/main.go b/go/samples/custom-agent/main.go index 0e647499d3..55668cc277 100644 --- a/go/samples/custom-agent/main.go +++ b/go/samples/custom-agent/main.go @@ -36,7 +36,7 @@ func main() { g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) chatAgent := genkit.DefineCustomAgent(g, "chat", - func(ctx context.Context, resp aix.Responder[any], sess *aix.AgentSession[any]) (*aix.AgentResult, error) { + func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) error { for chunk, err := range genkit.GenerateStream(ctx, g, ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ From 13dbbe4a104028a5212ff4adffdb36b1e93a6b0e Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 11 May 2026 18:46:49 -0700 Subject: [PATCH 064/141] chore: appease gofmt and prettier Auto-formats from `go fmt ./...` and `prettier --write` after the previous commit; no behavior changes. --- genkit-tools/common/src/types/agent.ts | 4 +--- go/core/error_test.go | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 63b649b8c7..ce694ca428 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -277,9 +277,7 @@ export type AbortSnapshotResponse = z.infer; * and output payloads. */ export const AgentStateManagementSchema = z.enum(['server', 'client']); -export type AgentStateManagement = z.infer< - typeof AgentStateManagementSchema ->; +export type AgentStateManagement = z.infer; /** * Zod schema for the agent capability metadata placed under diff --git a/go/core/error_test.go b/go/core/error_test.go index 498803f750..f1f2bdefce 100644 --- a/go/core/error_test.go +++ b/go/core/error_test.go @@ -170,7 +170,7 @@ func TestGenkitErrorJSONRoundtrip(t *testing.T) { Status: NOT_FOUND, Message: "missing", Details: map[string]any{"id": "abc"}, - HTTPCode: 999, // not on the wire + HTTPCode: 999, // not on the wire Source: func() *string { s := "x"; return &s }(), // not on the wire } got, err := json.Marshal(ge) From fae13cb311a9ce5f97d519031616fb0228a93402 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 11 May 2026 18:49:21 -0700 Subject: [PATCH 065/141] chore(py): regenerate _typing.py for snapshot/resume schema changes Picks up the recent JSON-schema renames (AgentStateManagement, SnapshotStatus succeeded/aborted/failed), the inlined AgentInput.resume shape, and the removal of the SessionSnapshot wire schema. --- .../genkit/src/genkit/_core/_typing.py | 51 +++++++------------ 1 file changed, 18 insertions(+), 33 deletions(-) diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index e234053d42..9db57288cc 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -34,8 +34,8 @@ ) -class AgentMetadataStateManagement(StrEnum): - """AgentMetadataStateManagement data type class.""" +class AgentStateManagement(StrEnum): + """AgentStateManagement data type class.""" SERVER = 'server' CLIENT = 'client' @@ -53,9 +53,9 @@ class SnapshotStatus(StrEnum): """SnapshotStatus data type class.""" PENDING = 'pending' - COMPLETE = 'complete' - CANCELED = 'canceled' - ERROR = 'error' + SUCCEEDED = 'succeeded' + ABORTED = 'aborted' + FAILED = 'failed' class EvalStatusEnum(StrEnum): @@ -132,14 +132,14 @@ class AgentInput(GenkitModel): model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) detach: bool | None = None messages: list[MessageData] | None = None - tool_restarts: list[Part] | None = None + resume: Resume | None = None class AgentMetadata(GenkitModel): """Model for agentmetadata data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - state_management: AgentMetadataStateManagement = Field(...) + state_management: AgentStateManagement = Field(...) abortable: bool = Field(...) @@ -195,24 +195,10 @@ class GetSnapshotResponse(GenkitModel): created_at: str | None = None updated_at: str | None = None status: SnapshotStatus | None = None - error: str | None = None + error: Any | None = Field(default=None) state: SessionState | None = None -class SessionSnapshot(GenkitModel): - """Model for sessionsnapshot data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - snapshot_id: str = Field(...) - parent_id: str | None = None - created_at: str = Field(...) - updated_at: str | None = None - event: SnapshotEvent = Field(...) - status: SnapshotStatus | None = None - error: str | None = None - state: SessionState = Field(...) - - class SessionState(GenkitModel): """Model for sessionstate data.""" @@ -231,8 +217,8 @@ class SnapshotMetadata(GenkitModel): created_at: str = Field(...) updated_at: str | None = None event: SnapshotEvent = Field(...) - status: Any | None = Field(default=None) - error: str | None = None + status: SnapshotStatus | None = None + error: Any | None = Field(default=None) class TurnEnd(GenkitModel): @@ -940,6 +926,14 @@ class TraceMetadata(GenkitModel): timestamp: float = Field(...) +class Resume(GenkitModel): + """Model for resume data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + respond: list[ToolResponsePart] | None = None + restart: list[ToolRequestPart] | None = None + + class Details(GenkitModel): """Model for details data.""" @@ -963,15 +957,6 @@ class GenkitErrorDetails(GenkitModel): trace_id: str = Field(...) -class Resume(GenkitModel): - """Model for resume data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - respond: list[ToolResponsePart] | None = None - restart: list[ToolRequestPart] | None = None - metadata: Metadata | None = None - - class Supports(GenkitModel): """Model for supports data.""" From fb948a4f06c57eb0491073dad4f9ce1748c9fb13 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 11 May 2026 18:54:50 -0700 Subject: [PATCH 066/141] fix(py): prefer the superset when two inline schemas share a name Both `AgentInput.resume` and `GenerateActionOptions.resume` are inline object schemas that the Python generator collapses into a single `Resume` class. The previous "first one wins" rule picked the AgentInput shape (no `metadata`) and broke `_ai/_generate.py` and `_ai/_prompt.py`, which rely on `Resume.metadata` for the GenerateActionOptions case. Keep the schema with the larger property set so the generated dataclass captures the superset. --- py/packages/genkit/src/genkit/_core/_typing.py | 1 + py/tools/schema_to_typing/schema_to_typing.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index 9db57288cc..1aec9195fe 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -932,6 +932,7 @@ class Resume(GenkitModel): model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) respond: list[ToolResponsePart] | None = None restart: list[ToolRequestPart] | None = None + metadata: Metadata | None = None class Details(GenkitModel): diff --git a/py/tools/schema_to_typing/schema_to_typing.py b/py/tools/schema_to_typing/schema_to_typing.py index 554f6e6c30..28cb83ebfa 100644 --- a/py/tools/schema_to_typing/schema_to_typing.py +++ b/py/tools/schema_to_typing/schema_to_typing.py @@ -156,7 +156,13 @@ def _typed_map_aliases(defs: dict) -> dict[str, str]: def _extract_inline_classes(schema: dict) -> dict[str, dict]: - """Extract inline object schemas to named classes (e.g. Score.details -> Details).""" + """Extract inline object schemas to named classes (e.g. Score.details -> Details). + + When two inline schemas across different parents share a derived class + name (e.g. ``resume`` on both ``AgentInput`` and ``GenerateActionOptions``), + keep the one with the larger property set so the generated dataclass + captures the superset of fields. + """ result = {} defs = schema.get('$defs') or {} @@ -164,8 +170,11 @@ def walk(props: dict) -> None: for prop_name, prop_schema in (props or {}).items(): if isinstance(prop_schema, dict) and prop_schema.get('type') == 'object' and '$ref' not in prop_schema: class_name = _pascal(prop_name) - if class_name not in defs and class_name not in result: - result[class_name] = prop_schema + if class_name not in defs: + existing = result.get(class_name) + new_props = prop_schema.get('properties') or {} + if existing is None or len(existing.get('properties') or {}) < len(new_props): + result[class_name] = prop_schema walk(prop_schema.get('properties', {})) for defn in defs.values(): From 0cdb4d3f4be68362ddda1b586aaaecdf6a8b31eb Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 12 May 2026 07:47:00 -0700 Subject: [PATCH 067/141] refactor(go/exp): tighten AgentInit validation, share SessionSnapshot, slim AbortSnapshot - AgentInit: reject state-with-server-managed-agent and snapshotId-with-client-managed-agent in addition to both-set; add coverage in TestLoadSession_AgentInitValidation. - SessionSnapshot promoted to a shared schema (Zod + JSON schema + Python + Go gen); the hand-written Go struct in session.go is gone, and the error field validates the canonical {status, message, details} wire shape. - AbortSnapshot now returns (SnapshotStatus, error). SnapshotMetadata and the snapshotMetadata projection helper are removed (Zod, JSON schema, Python, Go). Empty status + nil error continues to signal "not found"; the abort companion action maps that to core.NOT_FOUND, now exercised by TestAgent_AbortAction_NotFound. - Companion actions (getSnapshot, abortSnapshot) are no longer registered when the agent has no SessionStore configured. Updated TestAgent_GetSnapshotAction_NoStore to verify neither action is registered in the client-managed case. --- genkit-tools/common/src/types/agent.ts | 33 +++- genkit-tools/genkit-schema.json | 77 +++++--- go/ai/exp/agent.go | 6 +- go/ai/exp/agent_test.go | 185 ++++++++++++++---- go/ai/exp/gen.go | 69 ++++--- go/ai/exp/session.go | 94 +++------ go/core/schemas.config | 126 +++++++----- .../genkit/src/genkit/_core/_typing.py | 41 ++-- 8 files changed, 396 insertions(+), 235 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index ce694ca428..3361d94566 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -187,27 +187,42 @@ export const AgentStreamChunkSchema = z.object({ export type AgentStreamChunk = z.infer; /** - * Zod schema for the metadata projection of a session snapshot. It exists - * so callers can identify a snapshot and check its lifecycle status without - * paying for a full state read. + * Zod schema for a persisted point-in-time capture of session state. It is + * the canonical record written to and read from a session store; the wire + * representation is shared across language runtimes and the Dev UI. */ -export const SnapshotMetadataSchema = z.object({ +export const SessionSnapshotSchema = z.object({ /** Unique identifier for this snapshot (UUID). */ snapshotId: z.string(), /** ID of the previous snapshot in this timeline. */ parentId: z.string().optional(), /** When the snapshot was first written (RFC 3339). */ createdAt: z.string(), - /** When the snapshot was last written (RFC 3339). */ + /** When the snapshot was last written (RFC 3339). Equals `createdAt` until rewritten. */ updatedAt: z.string().optional(), /** What triggered this snapshot. */ event: SnapshotEventSchema, - /** Lifecycle state of this snapshot. Empty is treated as `complete`. */ + /** Lifecycle state of this snapshot. Empty is treated as `succeeded`. */ status: SnapshotStatusSchema.optional(), - /** Structured failure information for a snapshot in `error` status. */ - error: z.any().optional(), + /** Structured failure information for a snapshot in `failed` status. */ + error: z + .object({ + /** Canonical status name (e.g. `INTERNAL`, `FAILED_PRECONDITION`). */ + status: z.string(), + /** Human-readable error message. */ + message: z.string(), + /** Optional structured details describing the failure. */ + details: z.any().optional(), + }) + .optional(), + /** + * Conversation state captured at this point. Empty on a pending snapshot + * (the live state is not yet committed); populated on terminal snapshots + * with the cumulative final state. + */ + state: SessionStateSchema.optional(), }); -export type SnapshotMetadata = z.infer; +export type SessionSnapshot = z.infer; /** * Zod schema for the input of an agent's `getSnapshot` companion action. diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index a2d4fe9c33..1d41d390a7 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -207,34 +207,7 @@ ], "additionalProperties": false }, - "SessionState": { - "type": "object", - "properties": { - "messages": { - "type": "array", - "items": { - "$ref": "#/$defs/Message" - } - }, - "custom": {}, - "artifacts": { - "type": "array", - "items": { - "$ref": "#/$defs/Artifact" - } - } - }, - "additionalProperties": false - }, - "SnapshotEvent": { - "type": "string", - "enum": [ - "turnEnd", - "invocationEnd", - "detach" - ] - }, - "SnapshotMetadata": { + "SessionSnapshot": { "type": "object", "properties": { "snapshotId": { @@ -255,7 +228,26 @@ "status": { "$ref": "#/$defs/SnapshotStatus" }, - "error": {} + "error": { + "type": "object", + "properties": { + "status": { + "type": "string" + }, + "message": { + "type": "string" + }, + "details": {} + }, + "required": [ + "status", + "message" + ], + "additionalProperties": false + }, + "state": { + "$ref": "#/$defs/SessionState" + } }, "required": [ "snapshotId", @@ -264,6 +256,33 @@ ], "additionalProperties": false }, + "SessionState": { + "type": "object", + "properties": { + "messages": { + "type": "array", + "items": { + "$ref": "#/$defs/Message" + } + }, + "custom": {}, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/$defs/Artifact" + } + } + }, + "additionalProperties": false + }, + "SnapshotEvent": { + "type": "string", + "enum": [ + "turnEnd", + "invocationEnd", + "detach" + ] + }, "SnapshotStatus": { "type": "string", "enum": [ diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 902d2a1880..f067d3b4ef 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -706,6 +706,10 @@ func loadSession[State any]( if init.SnapshotID == "" { if init.State != nil { + if store != nil { + return nil, nil, core.NewError(core.FAILED_PRECONDITION, + "state provided but agent has a session store configured (server-managed state); use snapshot ID instead") + } s.state = *init.State } return s, nil, nil @@ -713,7 +717,7 @@ func loadSession[State any]( if store == nil { return nil, nil, core.NewError(core.FAILED_PRECONDITION, - "snapshot ID %q provided but no session store configured", init.SnapshotID) + "snapshot ID %q provided but agent has no session store configured (client-managed state); use state instead", init.SnapshotID) } snap, err := store.GetSnapshot(ctx, init.SnapshotID) if err != nil { diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 32bfb39e48..665297c090 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -2383,12 +2383,12 @@ func TestAgent_Detach_AbortSnapshotStopsFlow(t *testing.T) { // Abort via the store. The local caller already has the store // reference from WithSessionStore. - meta, err := store.AbortSnapshot(context.Background(), out.SnapshotID) + status, err := store.AbortSnapshot(context.Background(), out.SnapshotID) if err != nil { t.Fatalf("AbortSnapshot: %v", err) } - if meta.Status != SnapshotStatusAborted { - t.Errorf("AbortSnapshot status = %q, want aborted", meta.Status) + if status != SnapshotStatusAborted { + t.Errorf("AbortSnapshot status = %q, want aborted", status) } // The subscriber wakes the runtime, cancels work, and the finalizer @@ -2643,6 +2643,8 @@ func TestInMemorySessionStore_GetSnapshot_NotFound(t *testing.T) { } func TestAgent_GetSnapshotAction_NoStore(t *testing.T) { + // With no SessionStore configured, neither companion action should + // be registered: there is no server-side snapshot to fetch or abort. reg := newTestRegistry(t) DefineCustomAgent(reg, "noStoreFlow", @@ -2651,20 +2653,96 @@ func TestAgent_GetSnapshotAction_NoStore(t *testing.T) { }, ) - // Action remains registered even without a store; it returns - // FAILED_PRECONDITION when invoked. - action := core.ResolveActionFor[*GetSnapshotRequest, *GetSnapshotResponse[testState], struct{}, struct{}]( + getAction := core.ResolveActionFor[*GetSnapshotRequest, *GetSnapshotResponse[testState], struct{}, struct{}]( reg, api.ActionTypeAgentSnapshot, "noStoreFlow") - if action == nil { - t.Fatal("getSnapshot action should be registered even without a store") + if getAction != nil { + t.Error("getSnapshot action should NOT be registered without a store") } - _, err := action.Run(context.Background(), &GetSnapshotRequest{SnapshotID: "any"}, nil) - if err == nil { - t.Fatal("expected error when store is not configured") + abortAction := core.ResolveActionFor[*AbortSnapshotRequest, *AbortSnapshotResponse, struct{}, struct{}]( + reg, api.ActionTypeAgentAbort, "noStoreFlow") + if abortAction != nil { + t.Error("abortSnapshot action should NOT be registered without a store") } - if !strings.Contains(err.Error(), "no session store configured") { - t.Errorf("unexpected error: %v", err) +} + +func TestLoadSession_AgentInitValidation(t *testing.T) { + // loadSession enforces the AgentInit invariants: + // - snapshotId and state are mutually exclusive, + // - snapshotId requires a store (server-managed state), + // - state requires the absence of a store (client-managed state). + ctx := context.Background() + store := NewInMemorySessionStore[testState]() + state := &SessionState[testState]{Custom: testState{Counter: 1}} + + cases := []struct { + name string + init *AgentInit[testState] + store SessionStore[testState] + wantErr string + }{ + { + name: "both snapshotId and state set", + init: &AgentInit[testState]{SnapshotID: "snap-1", State: state}, + store: store, + wantErr: "mutually exclusive", + }, + { + name: "both set, no store", + init: &AgentInit[testState]{SnapshotID: "snap-1", State: state}, + store: nil, + wantErr: "mutually exclusive", + }, + { + name: "state with server-managed agent", + init: &AgentInit[testState]{State: state}, + store: store, + wantErr: "server-managed state", + }, + { + name: "snapshotId with client-managed agent", + init: &AgentInit[testState]{SnapshotID: "snap-1"}, + store: nil, + wantErr: "client-managed state", + }, } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := loadSession(ctx, tc.init, tc.store) + if err == nil { + t.Fatalf("expected error, got nil") + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Errorf("error %q does not contain %q", err.Error(), tc.wantErr) + } + }) + } + + t.Run("empty init with server store is allowed", func(t *testing.T) { + sess, snap, err := loadSession(ctx, &AgentInit[testState]{}, store) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if sess == nil { + t.Fatal("expected session, got nil") + } + if snap != nil { + t.Errorf("expected no snapshot, got %+v", snap) + } + }) + + t.Run("empty init with no store is allowed", func(t *testing.T) { + sess, snap, err := loadSession(ctx, &AgentInit[testState]{}, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if sess == nil { + t.Fatal("expected session, got nil") + } + if snap != nil { + t.Errorf("expected no snapshot, got %+v", snap) + } + }) } // minimalStore is a SessionStore that does NOT implement SnapshotAborter. @@ -2800,6 +2878,37 @@ func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { }) } +func TestAgent_AbortAction_NotFound(t *testing.T) { + // The store's "not found" sentinel (empty status, nil error) must + // surface as a core.NOT_FOUND GenkitError on the abort companion + // action so callers (Dev UI, remote clients) see a proper status. + reg := newTestRegistry(t) + DefineCustomAgent(reg, "missingFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, nil + }, + WithSessionStore(NewInMemorySessionStore[testState]()), + ) + + abortAction := core.ResolveActionFor[*AbortSnapshotRequest, *AbortSnapshotResponse, struct{}, struct{}]( + reg, api.ActionTypeAgentAbort, "missingFlow") + if abortAction == nil { + t.Fatal("abortSnapshot action should be registered") + } + + _, err := abortAction.Run(context.Background(), &AbortSnapshotRequest{SnapshotID: "no-such-snap"}, nil) + if err == nil { + t.Fatal("expected error for missing snapshot, got nil") + } + var ge *core.GenkitError + if !errors.As(err, &ge) { + t.Fatalf("expected *core.GenkitError, got %T: %v", err, err) + } + if ge.Status != core.NOT_FOUND { + t.Errorf("status = %q, want %q", ge.Status, core.NOT_FOUND) + } +} + func TestAgent_StateTransform_ClientManagedState(t *testing.T) { reg := newTestRegistry(t) @@ -2906,12 +3015,12 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { ctx := context.Background() store := NewInMemorySessionStore[testState]() - // Abort on missing snapshot returns nil metadata, no error. - if meta, err := store.AbortSnapshot(ctx, "nope"); err != nil || meta != nil { - t.Fatalf("AbortSnapshot(missing) = %+v, %v; want nil, nil", meta, err) + // Abort on missing snapshot returns empty status, no error. + if status, err := store.AbortSnapshot(ctx, "nope"); err != nil || status != "" { + t.Fatalf("AbortSnapshot(missing) = %q, %v; want \"\", nil", status, err) } - // Pending → aborted, UpdatedAt advances. + // Pending → aborted, UpdatedAt advances (verified via GetSnapshot). pending, err := store.SaveSnapshot(ctx, "snap-cas", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ @@ -2923,28 +3032,36 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { t.Fatalf("SaveSnapshot: %v", err) } time.Sleep(time.Millisecond) // ensure measurable UpdatedAt delta - meta, err := store.AbortSnapshot(ctx, "snap-cas") + status, err := store.AbortSnapshot(ctx, "snap-cas") if err != nil { t.Fatalf("AbortSnapshot: %v", err) } - if meta.Status != SnapshotStatusAborted { - t.Errorf("status after first abort = %q, want aborted", meta.Status) + if status != SnapshotStatusAborted { + t.Errorf("status after first abort = %q, want aborted", status) } - if !meta.UpdatedAt.After(pending.UpdatedAt) { - t.Errorf("UpdatedAt did not advance: %v vs %v", meta.UpdatedAt, pending.UpdatedAt) + afterFirst, err := store.GetSnapshot(ctx, "snap-cas") + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if !afterFirst.UpdatedAt.After(pending.UpdatedAt) { + t.Errorf("UpdatedAt did not advance: %v vs %v", afterFirst.UpdatedAt, pending.UpdatedAt) } // Idempotent: second abort returns aborted, no error, no further mutation. - firstUpdate := meta.UpdatedAt - meta2, err := store.AbortSnapshot(ctx, "snap-cas") + firstUpdate := afterFirst.UpdatedAt + status2, err := store.AbortSnapshot(ctx, "snap-cas") if err != nil { t.Fatalf("AbortSnapshot (second): %v", err) } - if meta2.Status != SnapshotStatusAborted { - t.Errorf("status after second abort = %q, want aborted", meta2.Status) + if status2 != SnapshotStatusAborted { + t.Errorf("status after second abort = %q, want aborted", status2) } - if !meta2.UpdatedAt.Equal(firstUpdate) { - t.Errorf("UpdatedAt advanced on idempotent abort: %v vs %v", meta2.UpdatedAt, firstUpdate) + afterSecond, err := store.GetSnapshot(ctx, "snap-cas") + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if !afterSecond.UpdatedAt.Equal(firstUpdate) { + t.Errorf("UpdatedAt advanced on idempotent abort: %v vs %v", afterSecond.UpdatedAt, firstUpdate) } // Abort on terminal status is a no-op that returns the existing status. @@ -2957,12 +3074,12 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { }); err != nil { t.Fatalf("SaveSnapshot: %v", err) } - meta3, err := store.AbortSnapshot(ctx, "snap-complete") + status3, err := store.AbortSnapshot(ctx, "snap-complete") if err != nil { t.Fatalf("AbortSnapshot on complete: %v", err) } - if meta3.Status != SnapshotStatusSucceeded { - t.Errorf("abort on complete returned status=%q, want succeeded", meta3.Status) + if status3 != SnapshotStatusSucceeded { + t.Errorf("abort on complete returned status=%q, want succeeded", status3) } } @@ -3126,12 +3243,12 @@ func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { t.Fatalf("RunText: %v", err) } - resp, err := store.AbortSnapshot(ctx, out.SnapshotID) + status, err := store.AbortSnapshot(ctx, out.SnapshotID) if err != nil { t.Fatalf("AbortSnapshot: %v", err) } - if resp.Status != SnapshotStatusSucceeded { - t.Errorf("expected status=%q (existing terminal), got %q", SnapshotStatusSucceeded, resp.Status) + if status != SnapshotStatusSucceeded { + t.Errorf("expected status=%q (existing terminal), got %q", SnapshotStatusSucceeded, status) } // Confirm the store snapshot was not flipped. diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index e73c310b18..20c78a5f26 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -42,13 +42,22 @@ type AbortSnapshotResponse struct { } // AgentInit is the input for starting an agent invocation. -// Provide either SnapshotID (to load from store) or State (direct state). +// Exactly one of SnapshotID or State may be set, and the choice must match +// the agent's state management: +// - Server-managed state (a session store is configured): callers must +// use SnapshotID; sending State is rejected. +// - Client-managed state (no session store): callers must use State; +// sending SnapshotID is rejected. +// Sending both fields is always rejected. Sending neither starts a fresh +// invocation with empty state. type AgentInit[State any] struct { - // SnapshotID loads state from a persisted snapshot. - // Mutually exclusive with State. + // SnapshotID loads state from a persisted snapshot. Only valid when the + // agent is server-managed (a session store is configured). Mutually + // exclusive with State. SnapshotID string `json:"snapshotId,omitempty"` - // State provides direct state for the invocation. - // Mutually exclusive with SnapshotID. + // State provides direct state for the invocation. Only valid when the + // agent is client-managed (no session store). Mutually exclusive with + // SnapshotID. State *SessionState[State] `json:"state,omitempty"` } @@ -196,6 +205,34 @@ type GetSnapshotResponse[State any] struct { UpdatedAt time.Time `json:"updatedAt,omitempty"` } +// SessionSnapshot is a persisted point-in-time capture of session state. It +// is the canonical record written to and read from a [SessionStore]. +type SessionSnapshot[State any] struct { + // CreatedAt is when the snapshot was created. + CreatedAt time.Time `json:"createdAt"` + // Error is the structured failure information for a snapshot in + // [SnapshotStatusFailed]. Nil otherwise. + Error *core.GenkitError `json:"error,omitempty"` + // Event is what triggered this snapshot. + Event SnapshotEvent `json:"event"` + // ParentID is the ID of the previous snapshot in this timeline. + ParentID string `json:"parentId,omitempty"` + // SnapshotID is the unique identifier for this snapshot (UUID). + SnapshotID string `json:"snapshotId"` + // State is the conversation state captured at this point. Nil on a + // pending snapshot (the live state is not yet committed; the background + // invocation is still processing queued inputs); populated on terminal + // snapshots with the cumulative final state. + State *SessionState[State] `json:"state,omitempty"` + // Status is the lifecycle state of this snapshot. Empty is treated as + // [SnapshotStatusSucceeded] for backwards compatibility. + Status SnapshotStatus `json:"status,omitempty"` + // UpdatedAt is when the snapshot was last written. For pending snapshots + // it equals CreatedAt; once the snapshot is finalized it reflects the + // terminal write. + UpdatedAt time.Time `json:"updatedAt,omitempty"` +} + // SessionState is the portable conversation state that flows between client // and server. It contains only the data needed for conversation continuity. type SessionState[State any] struct { @@ -223,28 +260,6 @@ const ( SnapshotEventDetach SnapshotEvent = "detach" ) -// SnapshotMetadata is the metadata-only projection of a [SessionSnapshot]: -// identifying fields, lifecycle timestamps, and status. Returned by store -// operations that surface a snapshot's lifecycle state without paying for -// a full state read. -type SnapshotMetadata struct { - // CreatedAt is when the snapshot was first written. - CreatedAt time.Time `json:"createdAt"` - // Error is the structured failure information for a snapshot in - // [SnapshotStatusFailed]. - Error *core.GenkitError `json:"error,omitempty"` - // Event is what triggered this snapshot. - Event SnapshotEvent `json:"event"` - // ParentID is the ID of the previous snapshot in this timeline. - ParentID string `json:"parentId,omitempty"` - // SnapshotID is the unique identifier for this snapshot. - SnapshotID string `json:"snapshotId"` - // Status is the lifecycle state of this snapshot. - Status SnapshotStatus `json:"status,omitempty"` - // UpdatedAt is when the snapshot was last written. - UpdatedAt time.Time `json:"updatedAt,omitempty"` -} - // SnapshotStatus describes the lifecycle state of a snapshot. Snapshots // written for synchronous turns or invocations are always // [SnapshotStatusSucceeded] (an empty value is also treated as succeeded diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 27d30f7361..72362ad50c 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -33,33 +33,6 @@ import ( // --- Snapshot --- -// SessionSnapshot is a persisted point-in-time capture of session state. -type SessionSnapshot[State any] struct { - // SnapshotID is the unique identifier for this snapshot (UUID). - SnapshotID string `json:"snapshotId"` - // ParentID is the ID of the previous snapshot in this timeline. - ParentID string `json:"parentId,omitempty"` - // CreatedAt is when the snapshot was created. - CreatedAt time.Time `json:"createdAt"` - // UpdatedAt is when the snapshot was last written. For pending - // snapshots it equals CreatedAt; once the snapshot is finalized it - // reflects the terminal write. - UpdatedAt time.Time `json:"updatedAt,omitempty"` - // Event is what triggered this snapshot. - Event SnapshotEvent `json:"event"` - // Status is the lifecycle state of this snapshot. Empty is treated as - // [SnapshotStatusSucceeded] for backwards compatibility. - Status SnapshotStatus `json:"status,omitempty"` - // Error is the structured failure information for a snapshot in - // [SnapshotStatusFailed]. Nil otherwise. - Error *core.GenkitError `json:"error,omitempty"` - // State is the conversation state captured at this point. Nil on a - // pending snapshot (the live state is not yet committed; the - // background invocation is still processing queued inputs); populated - // on terminal snapshots with the cumulative final state. - State *SessionState[State] `json:"state,omitempty"` -} - // SnapshotContext provides context for snapshot decision callbacks. type SnapshotContext[State any] struct { // State is the current state that will be snapshotted if the callback returns true. @@ -147,20 +120,20 @@ type SnapshotWriter[State any] interface { // They are bundled because neither is useful alone: flipping status // with no observer means the running fn never learns it was aborted; // observing without a way to trigger the flip means no abort can -// happen. Splitting them into separate interfaces made the -// "implemented one, not the other" footgun too easy to hit. +// happen. type SnapshotAborter interface { // AbortSnapshot atomically transitions a snapshot from // [SnapshotStatusPending] to [SnapshotStatusAborted] and returns the - // resulting metadata. If the snapshot is in any other status the - // operation is a no-op and the existing metadata is returned. Returns - // nil if the snapshot is not found. + // resulting status. If the snapshot is in any other status the + // operation is a no-op and the existing status is returned. Returns + // an empty status with a nil error if the snapshot is not found, so + // callers can distinguish "not found" from a real error. // // Implementations must perform the read-and-write atomically (e.g., a // transaction or a compare-and-swap). The agent's abortSnapshot // action and finalizer rely on this to avoid a pending row being // clobbered by a racing terminal write. - AbortSnapshot(ctx context.Context, snapshotID string) (*SnapshotMetadata, error) + AbortSnapshot(ctx context.Context, snapshotID string) (SnapshotStatus, error) // OnSnapshotStatusChange returns a channel that yields the snapshot's // status whenever it changes. The first value (if any) reflects the @@ -216,21 +189,21 @@ func (s *InMemorySessionStore[State]) GetSnapshot(_ context.Context, snapshotID } // AbortSnapshot atomically flips a pending snapshot to aborted. If the -// snapshot is already terminal the existing metadata is returned unchanged. -// Returns nil if the snapshot is not found. -func (s *InMemorySessionStore[State]) AbortSnapshot(_ context.Context, snapshotID string) (*SnapshotMetadata, error) { +// snapshot is already terminal the existing status is returned unchanged. +// Returns an empty status if the snapshot is not found. +func (s *InMemorySessionStore[State]) AbortSnapshot(_ context.Context, snapshotID string) (SnapshotStatus, error) { s.mu.Lock() defer s.mu.Unlock() snap, ok := s.snapshots[snapshotID] if !ok { - return nil, nil + return "", nil } if snap.Status == SnapshotStatusPending { snap.Status = SnapshotStatusAborted snap.UpdatedAt = time.Now() s.notifyLocked(snapshotID, snap.Status) } - return snapshotMetadata(snap), nil + return snap.Status, nil } // SaveSnapshot atomically reads, applies fn, and persists. See the @@ -346,19 +319,6 @@ func (s *InMemorySessionStore[State]) notifyLocked(snapshotID string, status Sna } } -// snapshotMetadata projects the metadata fields of a snapshot. -func snapshotMetadata[State any](snap *SessionSnapshot[State]) *SnapshotMetadata { - return &SnapshotMetadata{ - SnapshotID: snap.SnapshotID, - ParentID: snap.ParentID, - CreatedAt: snap.CreatedAt, - UpdatedAt: snap.UpdatedAt, - Event: snap.Event, - Status: snap.Status, - Error: snap.Error, - } -} - // copySnapshot creates a deep copy of a snapshot using JSON marshaling. func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[State], error) { if snap == nil { @@ -377,31 +337,33 @@ func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[Sta // --- Snapshot companion actions --- -// registerSnapshotActions registers the agent's companion actions: +// registerSnapshotActions registers the agent's companion actions when +// the agent has a [SessionStore] configured: // // - The agent's name under [api.ActionTypeAgentSnapshot] — getSnapshot, -// registered whenever a [SessionStore] is configured. The action is // the remote counterpart to [SessionStore.GetSnapshot] for Dev UI and -// non-Go clients; local Go callers use the store reference directly. +// non-Go clients. Local Go callers use the store reference directly. // // - The agent's name under [api.ActionTypeAgentAbort] — abortSnapshot, -// registered only when the store implements [SnapshotAborter] (which -// bundles both the abort trigger and the status-change subscription -// needed for the runtime to react). Surfacing the action only when -// the capability is present keeps the reflected API aligned with -// what the store can actually do. +// registered only when the store also implements [SnapshotAborter] +// (which bundles both the abort trigger and the status-change +// subscription needed for the runtime to react). +// +// When the agent is client-managed (no store configured), neither action +// is registered: there is no server-side snapshot to fetch or abort. +// Surfacing actions only when the underlying capabilities exist keeps the +// reflected API aligned with what the agent can actually do. func registerSnapshotActions[State any]( r api.Registry, agentName string, store SessionStore[State], transform StateTransform[State], ) { + if store == nil { + return + } core.DefineAction(r, agentName, api.ActionTypeAgentSnapshot, nil, nil, func(ctx context.Context, req *GetSnapshotRequest) (*GetSnapshotResponse[State], error) { - if store == nil { - return nil, core.NewError(core.FAILED_PRECONDITION, - "getSnapshot: agent %q has no session store configured", agentName) - } if req == nil || req.SnapshotID == "" { return nil, core.NewError(core.INVALID_ARGUMENT, "getSnapshot: snapshotId is required") } @@ -446,14 +408,14 @@ func registerSnapshotActions[State any]( if req == nil || req.SnapshotID == "" { return nil, core.NewError(core.INVALID_ARGUMENT, "abortSnapshot: snapshotId is required") } - meta, err := aborter.AbortSnapshot(ctx, req.SnapshotID) + status, err := aborter.AbortSnapshot(ctx, req.SnapshotID) if err != nil { return nil, core.NewError(core.INTERNAL, "abortSnapshot: %v", err) } - if meta == nil { + if status == "" { return nil, core.NewError(core.NOT_FOUND, "abortSnapshot: snapshot %q not found", req.SnapshotID) } - return &AbortSnapshotResponse{SnapshotID: meta.SnapshotID, Status: meta.Status}, nil + return &AbortSnapshotResponse{SnapshotID: req.SnapshotID, Status: status}, nil }) } diff --git a/go/core/schemas.config b/go/core/schemas.config index 97fa9d7252..b55332ad0d 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1262,17 +1262,26 @@ AgentInit typeparams [State any] AgentInit doc AgentInit is the input for starting an agent invocation. -Provide either SnapshotID (to load from store) or State (direct state). +Exactly one of SnapshotID or State may be set, and the choice must match +the agent's state management: + - Server-managed state (a session store is configured): callers must + use SnapshotID; sending State is rejected. + - Client-managed state (no session store): callers must use State; + sending SnapshotID is rejected. +Sending both fields is always rejected. Sending neither starts a fresh +invocation with empty state. . AgentInit.snapshotId doc -SnapshotID loads state from a persisted snapshot. -Mutually exclusive with State. +SnapshotID loads state from a persisted snapshot. Only valid when the +agent is server-managed (a session store is configured). Mutually +exclusive with State. . AgentInit.state doc -State provides direct state for the invocation. -Mutually exclusive with SnapshotID. +State provides direct state for the invocation. Only valid when the +agent is client-managed (no session store). Mutually exclusive with +SnapshotID. . # ---------------------------------------------------------------------------- @@ -1406,6 +1415,68 @@ SessionState.artifacts doc Artifacts are named collections of parts produced during the conversation. . +# ---------------------------------------------------------------------------- +# SessionSnapshot +# ---------------------------------------------------------------------------- + +SessionSnapshot pkg ai/exp +SessionSnapshot typeparams [State any] + +SessionSnapshot doc +SessionSnapshot is a persisted point-in-time capture of session state. It +is the canonical record written to and read from a [SessionStore]. +. + +SessionSnapshot.snapshotId noomitempty +SessionSnapshot.snapshotId doc +SnapshotID is the unique identifier for this snapshot (UUID). +. + +SessionSnapshot.parentId doc +ParentID is the ID of the previous snapshot in this timeline. +. + +SessionSnapshot.createdAt type time.Time +SessionSnapshot.createdAt noomitempty +SessionSnapshot.createdAt doc +CreatedAt is when the snapshot was created. +. + +SessionSnapshot.updatedAt type time.Time +SessionSnapshot.updatedAt doc +UpdatedAt is when the snapshot was last written. For pending snapshots +it equals CreatedAt; once the snapshot is finalized it reflects the +terminal write. +. + +SessionSnapshot.event noomitempty +SessionSnapshot.event doc +Event is what triggered this snapshot. +. + +SessionSnapshot.status doc +Status is the lifecycle state of this snapshot. Empty is treated as +[SnapshotStatusSucceeded] for backwards compatibility. +. + +SessionSnapshot.error type *core.GenkitError +SessionSnapshot.error doc +Error is the structured failure information for a snapshot in +[SnapshotStatusFailed]. Nil otherwise. +. + +# The synthesized SessionSnapshotError type mirrors the GenkitError wire +# shape inline in the JSON schema. SessionSnapshot.error is overridden to +# *core.GenkitError above, so the synthesized stub is unused. +SessionSnapshotError omit + +SessionSnapshot.state doc +State is the conversation state captured at this point. Nil on a +pending snapshot (the live state is not yet committed; the background +invocation is still processing queued inputs); populated on terminal +snapshots with the cumulative final state. +. + # ---------------------------------------------------------------------------- # SnapshotEvent # ---------------------------------------------------------------------------- @@ -1474,51 +1545,6 @@ The snapshot's Error field describes the failure and resume is rejected with that same error. . -# SnapshotMetadata -SnapshotMetadata pkg ai/exp - -SnapshotMetadata doc -SnapshotMetadata is the metadata-only projection of a [SessionSnapshot]: -identifying fields, lifecycle timestamps, and status. Returned by store -operations that surface a snapshot's lifecycle state without paying for -a full state read. -. - -SnapshotMetadata.snapshotId noomitempty -SnapshotMetadata.snapshotId doc -SnapshotID is the unique identifier for this snapshot. -. - -SnapshotMetadata.parentId doc -ParentID is the ID of the previous snapshot in this timeline. -. - -SnapshotMetadata.createdAt type time.Time -SnapshotMetadata.createdAt noomitempty -SnapshotMetadata.createdAt doc -CreatedAt is when the snapshot was first written. -. - -SnapshotMetadata.updatedAt type time.Time -SnapshotMetadata.updatedAt doc -UpdatedAt is when the snapshot was last written. -. - -SnapshotMetadata.event noomitempty -SnapshotMetadata.event doc -Event is what triggered this snapshot. -. - -SnapshotMetadata.status doc -Status is the lifecycle state of this snapshot. -. - -SnapshotMetadata.error type *core.GenkitError -SnapshotMetadata.error doc -Error is the structured failure information for a snapshot in -[SnapshotStatusFailed]. -. - # GetSnapshotRequest GetSnapshotRequest pkg ai/exp diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index 1aec9195fe..0c233aed7d 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -199,17 +199,8 @@ class GetSnapshotResponse(GenkitModel): state: SessionState | None = None -class SessionState(GenkitModel): - """Model for sessionstate data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - messages: list[MessageData] | None = None - custom: Any | None = Field(default=None) - artifacts: list[Artifact] | None = None - - -class SnapshotMetadata(GenkitModel): - """Model for snapshotmetadata data.""" +class SessionSnapshot(GenkitModel): + """Model for sessionsnapshot data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) snapshot_id: str = Field(...) @@ -218,7 +209,17 @@ class SnapshotMetadata(GenkitModel): updated_at: str | None = None event: SnapshotEvent = Field(...) status: SnapshotStatus | None = None - error: Any | None = Field(default=None) + error: Error | None = None + state: SessionState | None = None + + +class SessionState(GenkitModel): + """Model for sessionstate data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + messages: list[MessageData] | None = None + custom: Any | None = Field(default=None) + artifacts: list[Artifact] | None = None class TurnEnd(GenkitModel): @@ -935,6 +936,15 @@ class Resume(GenkitModel): metadata: Metadata | None = None +class Error(GenkitModel): + """Model for error data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='allow', populate_by_name=True) + status: str = Field(...) + message: str = Field(...) + details: Any | None = Field(default=None) + + class Details(GenkitModel): """Model for details data.""" @@ -974,13 +984,6 @@ class Supports(GenkitModel): long_running: bool | None = None -class Error(GenkitModel): - """Model for error data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='allow', populate_by_name=True) - message: str = Field(...) - - class Resource(GenkitModel): """Model for resource data.""" From a02e991d32e5f660caa4a6382d83c612cb4edfb2 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 12 May 2026 07:52:24 -0700 Subject: [PATCH 068/141] fix(schema): make status optional on snapshot error to unblock Pydantic Error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The inline SessionSnapshot.error shape (status, message, details) collides with OperationSchema.error's inline object in Python codegen — both synthesize to the same `Error` Pydantic class, and the codegen picks the superset. Marking status required broke veo.py, which constructs Error from a third-party API response that has no canonical status. Go side still gets *core.GenkitError via the schemas.config override, so the Genkit wire format is unaffected. The user-facing schema will be named and tightened in a follow-up; relaxing status now matches the real-world construction sites until then. --- genkit-tools/common/src/types/agent.ts | 2 +- genkit-tools/genkit-schema.json | 1 - py/packages/genkit/src/genkit/_core/_typing.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 3361d94566..8753873815 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -208,7 +208,7 @@ export const SessionSnapshotSchema = z.object({ error: z .object({ /** Canonical status name (e.g. `INTERNAL`, `FAILED_PRECONDITION`). */ - status: z.string(), + status: z.string().optional(), /** Human-readable error message. */ message: z.string(), /** Optional structured details describing the failure. */ diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 1d41d390a7..fda6d1eabc 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -240,7 +240,6 @@ "details": {} }, "required": [ - "status", "message" ], "additionalProperties": false diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index 0c233aed7d..c652a5f563 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -940,7 +940,7 @@ class Error(GenkitModel): """Model for error data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='allow', populate_by_name=True) - status: str = Field(...) + status: str | None = None message: str = Field(...) details: Any | None = Field(default=None) From 12e3990698eaaca7d52df9e4df54737d346810f9 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 12 May 2026 08:14:16 -0700 Subject: [PATCH 069/141] refactor(go/exp): drop unused PromptIn type param from DefinePromptAgent MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PromptIn provided no real type safety: ai.Prompt.Render takes `any`, so the constrained defaultInput erased to any immediately and the prompt did its own schema validation at runtime. Every call site wrote `DefinePromptAgent[State, any]` and PromptIn was pure noise. Replace it with a Define-time smoke render so a defaultInput that fails the prompt's input schema panics here rather than failing on the first invocation. Same safety guarantee at a meaningful boundary, simpler API. Call sites lose one type argument: - DefinePromptAgent[testState, any](...) → DefinePromptAgent[testState](...) - DefinePromptAgent[any](...) → DefinePromptAgent(...) (State inferred) --- go/ai/exp/agent.go | 14 +++++++++----- go/ai/exp/agent_test.go | 10 +++++----- go/genkit/genkit.go | 11 ++++++----- go/samples/prompt-agent/main.go | 2 +- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index f067d3b4ef..7b69e0aaba 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -287,22 +287,26 @@ func DefineAgent[State any]( // The agent is registered under the same name as the prompt, sharing its // namespace. // -// defaultInput is used to render the prompt on every turn. PromptIn is -// captured for compile-time type checking on defaultInput; it is not -// propagated through the [Agent] type. +// defaultInput is used to render the prompt on every turn. DefinePromptAgent +// invokes the prompt's Render once at definition time as a smoke check, so +// a defaultInput that fails the prompt's input schema panics here rather +// than failing on the first invocation. // // For an agent that defines its prompt inline, use [DefineAgent]. For full // control over the per-turn loop, use [DefineCustomAgent]. -func DefinePromptAgent[State, PromptIn any]( +func DefinePromptAgent[State any]( r api.Registry, promptName string, - defaultInput PromptIn, + defaultInput any, opts ...AgentOption[State], ) *Agent[any, State] { prompt := ai.LookupPrompt(r, promptName) if prompt == nil { panic(fmt.Sprintf("DefinePromptAgent: prompt %q not found", promptName)) } + if _, err := prompt.Render(context.Background(), defaultInput); err != nil { + panic(fmt.Sprintf("DefinePromptAgent %q: defaultInput does not satisfy prompt schema: %v", promptName, err)) + } return DefineCustomAgent(r, promptName, agentLoop[State](r, prompt, defaultInput), opts...) } diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 665297c090..99f41b6d0c 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -950,7 +950,7 @@ func TestPromptAgent_Basic(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefinePromptAgent[testState, any](reg, "testPrompt", nil) + af := DefinePromptAgent[testState](reg, "testPrompt", nil) conn, err := af.StreamBidi(ctx) if err != nil { @@ -1037,7 +1037,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { ai.WithSystem("system prompt"), ) - af := DefinePromptAgent[testState, any](reg, "historyPrompt", nil) + af := DefinePromptAgent[testState](reg, "historyPrompt", nil) conn, err := af.StreamBidi(ctx) if err != nil { @@ -1111,7 +1111,7 @@ func TestPromptAgent_SnapshotResumePreservesHistory(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefinePromptAgent[testState, any](reg, "snapPrompt", nil, + af := DefinePromptAgent[testState](reg, "snapPrompt", nil, WithSessionStore(store), ) @@ -1252,7 +1252,7 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { ai.WithTools(ai.ToolName("greet"), ai.ToolName("farewell")), ) - af := DefinePromptAgent[testState, any](reg, "toolPrompt", nil) + af := DefinePromptAgent[testState](reg, "toolPrompt", nil) conn, err := af.StreamBidi(ctx) if err != nil { @@ -1491,7 +1491,7 @@ func TestPromptAgent_RunText(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefinePromptAgent[testState, any](reg, "runTextPrompt", nil) + af := DefinePromptAgent[testState](reg, "runTextPrompt", nil) response, err := af.RunText(ctx, "hello") if err != nil { diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index bc002bf410..60dc9d7251 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -494,15 +494,16 @@ func DefineAgent[State any]( // Experimental: This API is under active development and may change in any // minor version release. // -// defaultInput is used to render the prompt on every turn. PromptIn is -// captured for compile-time type checking on defaultInput. +// defaultInput is used to render the prompt on every turn. The prompt's +// Render is invoked once at definition time as a smoke check, so a +// defaultInput that fails the prompt's input schema panics here rather +// than failing on the first invocation. // // For an agent that defines its prompt inline, use [DefineAgent]. For full // control over the per-turn loop, use [DefineCustomAgent]. // // Type parameters: // - State: Type for user-defined state persisted in snapshots -// - PromptIn: The prompt input type (inferred from defaultInput) // // Example: // @@ -514,10 +515,10 @@ func DefineAgent[State any]( // ChatInput{Personality: "a sarcastic pirate"}, // aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), // ) -func DefinePromptAgent[State, PromptIn any]( +func DefinePromptAgent[State any]( g *Genkit, promptName string, - defaultInput PromptIn, + defaultInput any, opts ...aix.AgentOption[State], ) *aix.Agent[any, State] { return aix.DefinePromptAgent(g.reg, promptName, defaultInput, opts...) diff --git a/go/samples/prompt-agent/main.go b/go/samples/prompt-agent/main.go index 6d0e2f8697..24b291567d 100644 --- a/go/samples/prompt-agent/main.go +++ b/go/samples/prompt-agent/main.go @@ -39,7 +39,7 @@ func main() { ctx := context.Background() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - chatAgent := genkit.DefinePromptAgent[any](g, "chat", + chatAgent := genkit.DefinePromptAgent(g, "chat", ChatPromptInput{Personality: "a sarcastic pirate"}, aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[any]) bool { From 081405d8c8d92ce602be91281e217f194c7c898c Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 12 May 2026 08:33:23 -0700 Subject: [PATCH 070/141] refactor(go/exp): rename samples, add agent-inline, fix AgentDefineOption marker MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sample reorg: - custom-agent → agent-custom; prompt-agent → agent-prompt. - New agent-inline sample uses DefineAgent with an inline prompt (mixed variadic of ai prompt options + aix agent options). Marker fix: the documented "DefineAgent accepts ai.PromptOption + aix.AgentOption" mixed variadic never compiled. The marker method isAgentDefineOption() was unexported and declared separately in both ai and ai/exp packages — Go's package-scoping rule treated them as different methods, so ai option types never actually satisfied exp.AgentDefineOption[State]. The bug went unnoticed because no test ever called DefineAgent with an ai.WithX option (the only call site passed aix.WithSessionStore). Fix: introduce ai.AgentDefineOption (exported marker, package ai) embedded in each prompt-definition option interface (ConfigOption, CommonGenOption, InputOption, PromptOption, PromptingOption, OutputOption). exp.AgentDefineOption[State] now embeds ai.AgentDefineOption; aix.AgentOption[State] inherits the marker through the chain. Method renamed isAgentDefineOption → IsAgentDefineOption on every implementing struct (ai prompt option structs and exp.agentOptions[State]). --- go/ai/exp/option.go | 8 +- go/ai/option.go | 33 ++++-- .../{custom-agent => agent-custom}/main.go | 9 +- go/samples/agent-inline/main.go | 101 ++++++++++++++++++ .../{prompt-agent => agent-prompt}/main.go | 5 +- .../prompts/chat.prompt | 0 6 files changed, 139 insertions(+), 17 deletions(-) rename go/samples/{custom-agent => agent-custom}/main.go (87%) create mode 100644 go/samples/agent-inline/main.go rename go/samples/{prompt-agent => agent-prompt}/main.go (92%) rename go/samples/{prompt-agent => agent-prompt}/prompts/chat.prompt (100%) diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index 69c9fc7e65..a360949f79 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -19,6 +19,8 @@ package exp import ( "context" "errors" + + "github.com/firebase/genkit/go/ai" ) // --- AgentDefineOption --- @@ -26,14 +28,14 @@ import ( // AgentDefineOption is the marker interface for any option that can be passed // to [DefineAgent]. It is satisfied by every [github.com/firebase/genkit/go/ai.PromptOption] // (which configures the underlying prompt) and by every [AgentOption] (which -// configures the agent itself). +// configures the agent itself), both of which embed [ai.AgentDefineOption]. // // The State type parameter is phantom: a single concrete option satisfies // [AgentDefineOption] for any State, so type inference cannot pick a State // from the variadic. Callers of [DefineAgent] must specify [State] explicitly // (use [any] when no typed Custom state is needed). type AgentDefineOption[State any] interface { - isAgentDefineOption() + ai.AgentDefineOption } // --- AgentOption --- @@ -66,7 +68,7 @@ type agentOptions[State any] struct { transform StateTransform[State] } -func (*agentOptions[State]) isAgentDefineOption() {} +func (*agentOptions[State]) IsAgentDefineOption() {} func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { if o.store != nil { diff --git a/go/ai/option.go b/go/ai/option.go index fce215d45a..25afad78b3 100644 --- a/go/ai/option.go +++ b/go/ai/option.go @@ -31,15 +31,24 @@ type PromptFn = func(context.Context, any) (string, error) // MessagesFn is a function that generates messages. type MessagesFn = func(context.Context, any) ([]*Message, error) -// No-op marker so the exp package's AgentDefineOption marker interface is -// satisfied by every PromptOption. Lets DefineAgent accept a mixed variadic -// of prompt options and agent-only options. -func (*configOptions) isAgentDefineOption() {} -func (*commonGenOptions) isAgentDefineOption() {} -func (*inputOptions) isAgentDefineOption() {} -func (*promptOptions) isAgentDefineOption() {} -func (*promptingOptions) isAgentDefineOption() {} -func (*outputOptions) isAgentDefineOption() {} +// AgentDefineOption is the marker satisfied by every option that can be +// passed to [github.com/firebase/genkit/go/ai/exp.DefineAgent]. Embedded +// in each prompt-definition option interface in this package and in +// [github.com/firebase/genkit/go/ai/exp.AgentOption] so DefineAgent can +// accept a mixed variadic of prompt options and agent-only options +// through a single static type. +// +// IsAgentDefineOption is an inert marker; callers never invoke it. +type AgentDefineOption interface { + IsAgentDefineOption() +} + +func (*configOptions) IsAgentDefineOption() {} +func (*commonGenOptions) IsAgentDefineOption() {} +func (*inputOptions) IsAgentDefineOption() {} +func (*promptOptions) IsAgentDefineOption() {} +func (*promptingOptions) IsAgentDefineOption() {} +func (*outputOptions) IsAgentDefineOption() {} // configOptions holds configuration options. type configOptions struct { @@ -48,6 +57,7 @@ type configOptions struct { // ConfigOption is an option for model configuration. type ConfigOption interface { + AgentDefineOption applyConfig(*configOptions) error applyCommonGen(*commonGenOptions) error applyPrompt(*promptOptions) error @@ -124,6 +134,7 @@ type commonGenOptions struct { } type CommonGenOption interface { + AgentDefineOption applyCommonGen(*commonGenOptions) error applyPrompt(*promptOptions) error applyGenerate(*generateOptions) error @@ -305,6 +316,7 @@ type inputOptions struct { // InputOption is an option for the input of a prompt. // It applies only to DefinePrompt(). type InputOption interface { + AgentDefineOption applyInput(*inputOptions) error applyPrompt(*promptOptions) error applyTool(*toolOptions) error @@ -390,6 +402,7 @@ type promptOptions struct { // PromptOption is an option for defining a prompt. // It applies only to DefinePrompt(). type PromptOption interface { + AgentDefineOption applyPrompt(*promptOptions) error } @@ -447,6 +460,7 @@ type promptingOptions struct { // PromptingOption is an option for the system and user prompts of a prompt or generate request. // It applies only to DefinePrompt() and Generate(). type PromptingOption interface { + AgentDefineOption applyPrompting(*promptingOptions) error applyPrompt(*promptOptions) error applyGenerate(*generateOptions) error @@ -528,6 +542,7 @@ type outputOptions struct { // OutputOption is an option for the output of a prompt or generate request. // It applies only to DefinePrompt() and Generate(). type OutputOption interface { + AgentDefineOption applyOutput(*outputOptions) error applyPrompt(*promptOptions) error applyGenerate(*generateOptions) error diff --git a/go/samples/custom-agent/main.go b/go/samples/agent-custom/main.go similarity index 87% rename from go/samples/custom-agent/main.go rename to go/samples/agent-custom/main.go index 55668cc277..9b0d10ea94 100644 --- a/go/samples/custom-agent/main.go +++ b/go/samples/agent-custom/main.go @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This sample demonstrates the custom Agent API for multi-turn conversation -// with token-level streaming. It runs a CLI REPL where conversation history -// is managed automatically by the session. +// This sample demonstrates DefineCustomAgent, which gives the caller full +// control over the per-turn loop (model selection, history management, +// streaming chunks). It runs a CLI REPL where conversation history is +// maintained by the session. Compare with agent-prompt (DefinePromptAgent) +// and agent-inline (DefineAgent), which both auto-wire the loop and +// differ only in where the prompt is defined. package main import ( diff --git a/go/samples/agent-inline/main.go b/go/samples/agent-inline/main.go new file mode 100644 index 0000000000..d33845a4b1 --- /dev/null +++ b/go/samples/agent-inline/main.go @@ -0,0 +1,101 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This sample demonstrates DefineAgent, which creates a multi-turn +// conversational agent backed by a prompt defined inline alongside the +// agent. The conversation loop (render prompt, call model, stream chunks, +// update history) is handled automatically. Compare with agent-prompt +// (DefinePromptAgent), which sources its prompt from a .prompt file, and +// agent-custom (DefineCustomAgent), which wires the same loop manually. +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "google.golang.org/genai" +) + +func main() { + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + chatAgent := genkit.DefineAgent[any](g, "chat", + ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), + aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), + aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), + ) + + fmt.Println("Agent Chat (type 'quit' to exit)") + fmt.Println() + + conn, err := chatAgent.StreamBidi(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + reader := bufio.NewReader(os.Stdin) + for { + fmt.Print("> ") + input, _ := reader.ReadString('\n') + input = strings.TrimSpace(input) + + if input == "quit" || input == "exit" { + break + } + if input == "" { + continue + } + + if err := conn.SendText(input); err != nil { + fmt.Fprintf(os.Stderr, "Send error: %v\n", err) + break + } + + fmt.Println() + + for chunk, err := range conn.Receive() { + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + break + } + if chunk.ModelChunk != nil { + fmt.Print(chunk.ModelChunk.Text()) + } + if chunk.TurnEnd != nil { + if chunk.TurnEnd.SnapshotID != "" { + fmt.Printf("\n[snapshot: %s]", chunk.TurnEnd.SnapshotID) + } + fmt.Println() + fmt.Println() + break + } + } + } + + conn.Close() +} diff --git a/go/samples/prompt-agent/main.go b/go/samples/agent-prompt/main.go similarity index 92% rename from go/samples/prompt-agent/main.go rename to go/samples/agent-prompt/main.go index 24b291567d..2a001df874 100644 --- a/go/samples/prompt-agent/main.go +++ b/go/samples/agent-prompt/main.go @@ -15,8 +15,9 @@ // This sample demonstrates DefinePromptAgent, which creates a multi-turn // conversational agent backed by a .prompt file. The conversation loop // (render prompt, call model, stream chunks, update history) is handled -// automatically. Compare with custom-agent which wires the same loop -// manually. +// automatically. Compare with agent-custom (DefineCustomAgent), which +// wires the same loop manually, and agent-inline (DefineAgent), which +// defines the prompt inline alongside the agent. package main import ( diff --git a/go/samples/prompt-agent/prompts/chat.prompt b/go/samples/agent-prompt/prompts/chat.prompt similarity index 100% rename from go/samples/prompt-agent/prompts/chat.prompt rename to go/samples/agent-prompt/prompts/chat.prompt From 9b290ea3c33c33a5ec86dd5bafcb66edcdc623b5 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 12 May 2026 08:37:45 -0700 Subject: [PATCH 071/141] test(go/exp): cover State mismatch for every typed AgentOption variant DefineAgent's mixed variadic accepts AgentDefineOption[State] where State is phantom on the marker, so a typed AgentOption[Wrong] passed to DefineAgent[Right] satisfies the variadic at compile time and is only caught at runtime via type-assertion-then-panic. Existing test covered WithSessionStore; the same runtime path applies to WithSnapshotCallback, WithSnapshotOn, and WithStateTransform. Table-drive the test across all four so the runtime backstop is exercised wherever it matters. DefineCustomAgent and DefinePromptAgent both declare opts ...AgentOption[State] directly, so they reject State mismatches at compile time and need no runtime coverage. --- go/ai/exp/agent_test.go | 61 +++++++++++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 99f41b6d0c..092616c855 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -1764,26 +1764,57 @@ func TestAgent_CancelDuringStreamReleasesGoroutine(t *testing.T) { conn.Close() } -// TestAgent_DefineAgent_StateMismatchPanics verifies that passing a typed -// AgentOption (e.g., a session store) with a different State than the -// declared one on DefineAgent panics with a clear message. +// TestAgent_DefineAgent_StateMismatchPanics verifies that every typed +// AgentOption variant panics at definition time when its State type +// parameter does not match the State declared on DefineAgent. This is +// the runtime backstop for DefineAgent's mixed variadic (which cannot +// enforce State at compile time because AgentDefineOption's State is +// phantom). DefineCustomAgent and DefinePromptAgent both declare +// `opts ...AgentOption[State]`, so they catch the same mismatch at +// compile time and need no runtime test. func TestAgent_DefineAgent_StateMismatchPanics(t *testing.T) { type otherState struct{ X int } reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() - defer func() { - r := recover() - if r == nil { - t.Fatal("expected panic on State mismatch") - } - msg := fmt.Sprintf("%v", r) - if !strings.Contains(msg, "does not match agent State") { - t.Errorf("panic message missing expected substring: %s", msg) - } - }() + cases := []struct { + name string + opt AgentDefineOption[otherState] + }{ + { + name: "WithSessionStore", + opt: WithSessionStore[testState](NewInMemorySessionStore[testState]()), + }, + { + name: "WithSnapshotCallback", + opt: WithSnapshotCallback[testState](func(context.Context, *SnapshotContext[testState]) bool { + return true + }), + }, + { + name: "WithSnapshotOn", + opt: WithSnapshotOn[testState](SnapshotEventTurnEnd), + }, + { + name: "WithStateTransform", + opt: WithStateTransform[testState](func(_ context.Context, s SessionState[testState]) SessionState[testState] { return s }), + }, + } - _ = DefineAgent[otherState](reg, "mismatch", WithSessionStore(store)) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + t.Fatalf("expected panic on State mismatch via %s", tc.name) + } + msg := fmt.Sprintf("%v", r) + if !strings.Contains(msg, "does not match agent State") { + t.Errorf("panic message missing expected substring: %s", msg) + } + }() + _ = DefineAgent[otherState](reg, "mismatch-"+tc.name, tc.opt) + }) + } } // --- Detach, transform, and getSnapshot tests --- From 335c6df09ccddda32a1198a7b40c588993152bb8 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 12 May 2026 12:09:15 -0700 Subject: [PATCH 072/141] refactor(go/exp): switch StateTransform to pointer signature StateTransform now takes and returns *SessionState[State], matching the rest of the API (Session.State, SnapshotContext, AgentOutput, and SessionSnapshot all use the pointer). This enables nil-return for "omit state from the response", drops a redundant struct copy in applyTransform, and lets the doc describe ownership plainly. Also drops a now-unnecessary explicit [State] arg on WithSnapshotCallback inside WithSnapshotOn, since the callback parameter pins down State on its own. --- go/ai/exp/agent_test.go | 6 +++--- go/ai/exp/option.go | 10 ++++++---- go/ai/exp/session.go | 5 ++--- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 092616c855..87e39fafb0 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -1796,7 +1796,7 @@ func TestAgent_DefineAgent_StateMismatchPanics(t *testing.T) { }, { name: "WithStateTransform", - opt: WithStateTransform[testState](func(_ context.Context, s SessionState[testState]) SessionState[testState] { return s }), + opt: WithStateTransform[testState](func(_ context.Context, s *SessionState[testState]) *SessionState[testState] { return s }), }, } @@ -2585,7 +2585,7 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { store := NewInMemorySessionStore[testState]() // Transform that scrubs a specific word from all messages. - transform := func(_ context.Context, s SessionState[testState]) SessionState[testState] { + transform := func(_ context.Context, s *SessionState[testState]) *SessionState[testState] { for _, msg := range s.Messages { for _, p := range msg.Content { if p.Text != "" { @@ -2944,7 +2944,7 @@ func TestAgent_StateTransform_ClientManagedState(t *testing.T) { reg := newTestRegistry(t) // Client-managed state: transform should be applied to AgentOutput.State. - transform := func(_ context.Context, s SessionState[testState]) SessionState[testState] { + transform := func(_ context.Context, s *SessionState[testState]) *SessionState[testState] { // Zero out the counter to demonstrate the transform is applied. s.Custom.Counter = -1 return s diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index a360949f79..d090c4cd49 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -58,9 +58,11 @@ type AgentOption[State any] interface { // and context-scoped values (e.g. the caller's identity for RBAC-aware // redaction) flow through here. // -// The state input is a deep copy owned by the caller; the transform -// may mutate and return it, or return a freshly-constructed value. -type StateTransform[State any] = func(ctx context.Context, state SessionState[State]) SessionState[State] +// state is a fresh deep copy made for this call: the transform owns it +// and may mutate in place, return a new pointer, or return nil to omit +// state from the response entirely. Do not retain the pointer past the +// call; the framework drops its reference after the transform returns. +type StateTransform[State any] = func(ctx context.Context, state *SessionState[State]) *SessionState[State] type agentOptions[State any] struct { store SessionStore[State] @@ -114,7 +116,7 @@ func WithSnapshotOn[State any](events ...SnapshotEvent) AgentOption[State] { for _, e := range events { set[e] = struct{}{} } - return WithSnapshotCallback[State](func(_ context.Context, sc *SnapshotContext[State]) bool { + return WithSnapshotCallback(func(_ context.Context, sc *SnapshotContext[State]) bool { _, ok := set[sc.Event] return ok }) diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 72362ad50c..6b1165dc68 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -49,14 +49,13 @@ type SnapshotContext[State any] struct { // If not provided and a store is configured, snapshots are always created. type SnapshotCallback[State any] = func(ctx context.Context, sc *SnapshotContext[State]) bool -// applyTransform returns the result of applying t to *state, or state +// applyTransform returns the result of applying t to state, or state // unchanged if t is nil. A nil state is returned as-is. func applyTransform[State any](ctx context.Context, t StateTransform[State], state *SessionState[State]) *SessionState[State] { if t == nil || state == nil { return state } - transformed := t(ctx, *state) - return &transformed + return t(ctx, state) } // --- Session store --- From 56c8114f851d964f4ac9d4481be5b1aa211241c5 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 12 May 2026 14:27:40 -0700 Subject: [PATCH 073/141] refactor(go/exp): unify DefineAgent with FromInline/FromPrompt sources DefineAgent now takes a typed AgentSource as its third arg (either FromInline for inline prompt options or FromPrompt for an existing prompt looked up by the agent's name) followed by a typed variadic of AgentOption[State]. This: * fixes the phantom-State hole: a State mismatch on any agent option now fails at compile time instead of panicking at registration; * subsumes DefinePromptAgent, which is deleted; * lets State be inferred from the typed agent options in the common case, so the explicit [State] type arg can be dropped from DefineAgent calls that pass a typed option; * lets the agent-vs-prompt namespace invariant (they share a name) become a structural property of the API rather than a convention. The mixed-variadic AgentDefineOption marker on every ai.PromptOption interface is now unused and removed (both the ai package marker and its exp package generic wrapper). Migrate the three samples and five DefinePromptAgent test call sites. The State-mismatch panic test is obsolete (the mismatch is a compile error now) and is deleted. --- go/ai/exp/agent.go | 85 +++++++++++------------------- go/ai/exp/agent_test.go | 63 ++-------------------- go/ai/exp/option.go | 26 ++-------- go/ai/exp/source.go | 71 +++++++++++++++++++++++++ go/ai/option.go | 25 --------- go/genkit/genkit.go | 92 +++++++++++---------------------- go/samples/agent-custom/main.go | 13 ++--- go/samples/agent-inline/main.go | 29 ++++++----- go/samples/agent-prompt/main.go | 17 +++--- 9 files changed, 171 insertions(+), 250 deletions(-) create mode 100644 go/ai/exp/source.go diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 7b69e0aaba..1bd008e671 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -243,71 +243,46 @@ type Agent[Stream, State any] struct { action *core.Action[*AgentInit[State], *AgentOutput[State], *AgentStreamChunk[Stream], *AgentInput] } -// DefineAgent defines an agent that wraps a prompt defined inline from the -// given options, and registers both under name. Each turn renders the prompt, -// appends conversation history, calls the model with streaming, and updates -// session state. +// DefineAgent defines a prompt-backed agent and registers it. Each turn +// renders the agent's prompt, appends conversation history, calls the +// model with streaming, and updates session state. // -// opts is a mixed list of [github.com/firebase/genkit/go/ai.PromptOption] -// values (which configure the prompt) and [AgentOption] values (which -// configure the agent itself, e.g., [WithSessionStore]). +// source selects how the prompt is backed: // -// State is phantom in the variadic, so it cannot be inferred. Specify [any] -// when no typed Custom state is needed; specify [Foo] when a -// [SessionStore[Foo]] is provided. A mismatch panics at definition time with -// a clear message. +// - [FromInline] defines the prompt inline from a set of +// [ai.PromptOption] values; the prompt is registered under name. +// - [FromPrompt] references an existing prompt registered with the +// registry under name (e.g. one defined via [ai.DefinePrompt] or +// loaded from a .prompt file). // -// For an agent backed by an existing prompt, use [DefinePromptAgent]. For -// full control over the per-turn loop, use [DefineCustomAgent]. +// State is inferred from the typed agent options (e.g. +// [WithSessionStore], [WithSnapshotOn]); pass an explicit [State] only +// when no typed option is provided. A typed option that disagrees with +// the inferred State fails at compile time. +// +// For full control over the per-turn loop, use [DefineCustomAgent]. func DefineAgent[State any]( r api.Registry, name string, - opts ...AgentDefineOption[State], + source AgentSource, + opts ...AgentOption[State], ) *Agent[any, State] { - var promptOpts []ai.PromptOption - var agentOpts []AgentOption[State] - for _, opt := range opts { - if ao, ok := opt.(AgentOption[State]); ok { - agentOpts = append(agentOpts, ao) - continue + switch s := source.(type) { + case inlineSource: + prompt := ai.DefinePrompt(r, name, s.opts...) + return DefineCustomAgent(r, name, agentLoop[State](r, prompt, nil), opts...) + case promptSource: + prompt := ai.LookupPrompt(r, name) + if prompt == nil { + panic(fmt.Sprintf("DefineAgent %q: prompt %q not found", name, name)) } - if po, ok := opt.(ai.PromptOption); ok { - promptOpts = append(promptOpts, po) - continue + if _, err := prompt.Render(context.Background(), s.defaultInput); err != nil { + panic(fmt.Sprintf("DefineAgent %q: defaultInput does not satisfy prompt schema: %v", name, err)) } - panic(fmt.Sprintf("DefineAgent %q: option of type %T does not match agent State %T (likely a typed AgentOption with a different State than the one declared on DefineAgent)", name, opt, *new(State))) - } - - prompt := ai.DefinePrompt(r, name, promptOpts...) - return DefineCustomAgent(r, name, agentLoop[State](r, prompt, nil), agentOpts...) -} - -// DefinePromptAgent defines an agent backed by a prompt already registered -// with the registry (via [ai.DefinePrompt] or loaded from a .prompt file). -// The agent is registered under the same name as the prompt, sharing its -// namespace. -// -// defaultInput is used to render the prompt on every turn. DefinePromptAgent -// invokes the prompt's Render once at definition time as a smoke check, so -// a defaultInput that fails the prompt's input schema panics here rather -// than failing on the first invocation. -// -// For an agent that defines its prompt inline, use [DefineAgent]. For full -// control over the per-turn loop, use [DefineCustomAgent]. -func DefinePromptAgent[State any]( - r api.Registry, - promptName string, - defaultInput any, - opts ...AgentOption[State], -) *Agent[any, State] { - prompt := ai.LookupPrompt(r, promptName) - if prompt == nil { - panic(fmt.Sprintf("DefinePromptAgent: prompt %q not found", promptName)) - } - if _, err := prompt.Render(context.Background(), defaultInput); err != nil { - panic(fmt.Sprintf("DefinePromptAgent %q: defaultInput does not satisfy prompt schema: %v", promptName, err)) + return DefineCustomAgent(r, name, agentLoop[State](r, prompt, s.defaultInput), opts...) + default: + panic(fmt.Sprintf("DefineAgent %q: unknown source type %T", name, source)) } - return DefineCustomAgent(r, promptName, agentLoop[State](r, prompt, defaultInput), opts...) } // DefineCustomAgent defines an agent with full control over the conversation diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 87e39fafb0..2712f572b4 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -950,7 +950,7 @@ func TestPromptAgent_Basic(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefinePromptAgent[testState](reg, "testPrompt", nil) + af := DefineAgent[testState](reg, "testPrompt", FromPrompt()) conn, err := af.StreamBidi(ctx) if err != nil { @@ -1037,7 +1037,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { ai.WithSystem("system prompt"), ) - af := DefinePromptAgent[testState](reg, "historyPrompt", nil) + af := DefineAgent[testState](reg, "historyPrompt", FromPrompt()) conn, err := af.StreamBidi(ctx) if err != nil { @@ -1111,7 +1111,7 @@ func TestPromptAgent_SnapshotResumePreservesHistory(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefinePromptAgent[testState](reg, "snapPrompt", nil, + af := DefineAgent[testState](reg, "snapPrompt", FromPrompt(), WithSessionStore(store), ) @@ -1252,7 +1252,7 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { ai.WithTools(ai.ToolName("greet"), ai.ToolName("farewell")), ) - af := DefinePromptAgent[testState](reg, "toolPrompt", nil) + af := DefineAgent[testState](reg, "toolPrompt", FromPrompt()) conn, err := af.StreamBidi(ctx) if err != nil { @@ -1491,7 +1491,7 @@ func TestPromptAgent_RunText(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefinePromptAgent[testState](reg, "runTextPrompt", nil) + af := DefineAgent[testState](reg, "runTextPrompt", FromPrompt()) response, err := af.RunText(ctx, "hello") if err != nil { @@ -1764,59 +1764,6 @@ func TestAgent_CancelDuringStreamReleasesGoroutine(t *testing.T) { conn.Close() } -// TestAgent_DefineAgent_StateMismatchPanics verifies that every typed -// AgentOption variant panics at definition time when its State type -// parameter does not match the State declared on DefineAgent. This is -// the runtime backstop for DefineAgent's mixed variadic (which cannot -// enforce State at compile time because AgentDefineOption's State is -// phantom). DefineCustomAgent and DefinePromptAgent both declare -// `opts ...AgentOption[State]`, so they catch the same mismatch at -// compile time and need no runtime test. -func TestAgent_DefineAgent_StateMismatchPanics(t *testing.T) { - type otherState struct{ X int } - reg := newTestRegistry(t) - - cases := []struct { - name string - opt AgentDefineOption[otherState] - }{ - { - name: "WithSessionStore", - opt: WithSessionStore[testState](NewInMemorySessionStore[testState]()), - }, - { - name: "WithSnapshotCallback", - opt: WithSnapshotCallback[testState](func(context.Context, *SnapshotContext[testState]) bool { - return true - }), - }, - { - name: "WithSnapshotOn", - opt: WithSnapshotOn[testState](SnapshotEventTurnEnd), - }, - { - name: "WithStateTransform", - opt: WithStateTransform[testState](func(_ context.Context, s *SessionState[testState]) *SessionState[testState] { return s }), - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - defer func() { - r := recover() - if r == nil { - t.Fatalf("expected panic on State mismatch via %s", tc.name) - } - msg := fmt.Sprintf("%v", r) - if !strings.Contains(msg, "does not match agent State") { - t.Errorf("panic message missing expected substring: %s", msg) - } - }() - _ = DefineAgent[otherState](reg, "mismatch-"+tc.name, tc.opt) - }) - } -} - // --- Detach, transform, and getSnapshot tests --- // waitForSnapshot polls the store for a snapshot matching the predicate, diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index d090c4cd49..1cecbcb39f 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -19,32 +19,14 @@ package exp import ( "context" "errors" - - "github.com/firebase/genkit/go/ai" ) -// --- AgentDefineOption --- - -// AgentDefineOption is the marker interface for any option that can be passed -// to [DefineAgent]. It is satisfied by every [github.com/firebase/genkit/go/ai.PromptOption] -// (which configures the underlying prompt) and by every [AgentOption] (which -// configures the agent itself), both of which embed [ai.AgentDefineOption]. -// -// The State type parameter is phantom: a single concrete option satisfies -// [AgentDefineOption] for any State, so type inference cannot pick a State -// from the variadic. Callers of [DefineAgent] must specify [State] explicitly -// (use [any] when no typed Custom state is needed). -type AgentDefineOption[State any] interface { - ai.AgentDefineOption -} - // --- AgentOption --- -// AgentOption configures an agent at definition time. It also satisfies -// [AgentDefineOption] so it can be passed to [DefineAgent] alongside -// [github.com/firebase/genkit/go/ai.PromptOption] values. +// AgentOption configures an agent at definition time. It is accepted by +// [DefineAgent] and [DefineCustomAgent] as a typed variadic, so a State +// mismatch fails at compile time. type AgentOption[State any] interface { - AgentDefineOption[State] applyAgent(*agentOptions[State]) error } @@ -70,8 +52,6 @@ type agentOptions[State any] struct { transform StateTransform[State] } -func (*agentOptions[State]) IsAgentDefineOption() {} - func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { if o.store != nil { if opts.store != nil { diff --git a/go/ai/exp/source.go b/go/ai/exp/source.go new file mode 100644 index 0000000000..d30475f9ca --- /dev/null +++ b/go/ai/exp/source.go @@ -0,0 +1,71 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import "github.com/firebase/genkit/go/ai" + +// AgentSource selects the prompt backing a prompt-based agent. Pass an +// AgentSource as the third argument to [DefineAgent]. There are two +// forms: +// +// - [FromInline] defines the prompt inline from a set of +// [ai.PromptOption] values; the prompt is registered with the +// registry under the agent's name. +// - [FromPrompt] references an existing prompt registered with the +// registry under the same name as the agent (e.g. one defined via +// [ai.DefinePrompt] or loaded from a .prompt file). +// +// The agent and its backing prompt always share a name; if you need +// the lookup name to differ from the agent name, define a custom agent +// via [DefineCustomAgent] instead. +type AgentSource interface { + isAgentSource() +} + +type inlineSource struct { + opts []ai.PromptOption +} + +func (inlineSource) isAgentSource() {} + +// FromInline defines the agent's prompt inline from the given options. +// The prompt is registered with the registry under the agent's name. +func FromInline(opts ...ai.PromptOption) AgentSource { + return inlineSource{opts: opts} +} + +type promptSource struct { + defaultInput any +} + +func (promptSource) isAgentSource() {} + +// FromPrompt references an existing prompt registered with the +// registry under the same name as the agent (e.g. one defined via +// [ai.DefinePrompt] or loaded from a .prompt file). +// +// defaultInput, if provided, is the input passed to the prompt's +// Render on every turn. Call FromPrompt() with no arguments when the +// prompt takes no input. Only the first argument is used; any +// additional arguments are ignored. +func FromPrompt(defaultInput ...any) AgentSource { + var input any + if len(defaultInput) > 0 { + input = defaultInput[0] + } + return promptSource{defaultInput: input} +} diff --git a/go/ai/option.go b/go/ai/option.go index 25afad78b3..d51ca85a4a 100644 --- a/go/ai/option.go +++ b/go/ai/option.go @@ -31,25 +31,6 @@ type PromptFn = func(context.Context, any) (string, error) // MessagesFn is a function that generates messages. type MessagesFn = func(context.Context, any) ([]*Message, error) -// AgentDefineOption is the marker satisfied by every option that can be -// passed to [github.com/firebase/genkit/go/ai/exp.DefineAgent]. Embedded -// in each prompt-definition option interface in this package and in -// [github.com/firebase/genkit/go/ai/exp.AgentOption] so DefineAgent can -// accept a mixed variadic of prompt options and agent-only options -// through a single static type. -// -// IsAgentDefineOption is an inert marker; callers never invoke it. -type AgentDefineOption interface { - IsAgentDefineOption() -} - -func (*configOptions) IsAgentDefineOption() {} -func (*commonGenOptions) IsAgentDefineOption() {} -func (*inputOptions) IsAgentDefineOption() {} -func (*promptOptions) IsAgentDefineOption() {} -func (*promptingOptions) IsAgentDefineOption() {} -func (*outputOptions) IsAgentDefineOption() {} - // configOptions holds configuration options. type configOptions struct { Config any // Primitive (model, embedder, retriever, etc) configuration. @@ -57,7 +38,6 @@ type configOptions struct { // ConfigOption is an option for model configuration. type ConfigOption interface { - AgentDefineOption applyConfig(*configOptions) error applyCommonGen(*commonGenOptions) error applyPrompt(*promptOptions) error @@ -134,7 +114,6 @@ type commonGenOptions struct { } type CommonGenOption interface { - AgentDefineOption applyCommonGen(*commonGenOptions) error applyPrompt(*promptOptions) error applyGenerate(*generateOptions) error @@ -316,7 +295,6 @@ type inputOptions struct { // InputOption is an option for the input of a prompt. // It applies only to DefinePrompt(). type InputOption interface { - AgentDefineOption applyInput(*inputOptions) error applyPrompt(*promptOptions) error applyTool(*toolOptions) error @@ -402,7 +380,6 @@ type promptOptions struct { // PromptOption is an option for defining a prompt. // It applies only to DefinePrompt(). type PromptOption interface { - AgentDefineOption applyPrompt(*promptOptions) error } @@ -460,7 +437,6 @@ type promptingOptions struct { // PromptingOption is an option for the system and user prompts of a prompt or generate request. // It applies only to DefinePrompt() and Generate(). type PromptingOption interface { - AgentDefineOption applyPrompting(*promptingOptions) error applyPrompt(*promptOptions) error applyGenerate(*generateOptions) error @@ -542,7 +518,6 @@ type outputOptions struct { // OutputOption is an option for the output of a prompt or generate request. // It applies only to DefinePrompt() and Generate(). type OutputOption interface { - AgentDefineOption applyOutput(*outputOptions) error applyPrompt(*promptOptions) error applyGenerate(*generateOptions) error diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 60dc9d7251..e4c30981b9 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -430,9 +430,8 @@ func DefineBidiFlow[In, Out, StreamOut, StreamIn any](g *Genkit, name string, fn return core.DefineBidiFlow(g.reg, name, fn) } -// DefineAgent defines an agent that wraps a prompt defined inline from the -// given options, registers both under name as actions on the registry, and -// returns an [aix.Agent]. +// DefineAgent defines a prompt-backed agent and registers it as an +// action on the registry. Returns an [aix.Agent]. // // Experimental: This API is under active development and may change in any // minor version release. @@ -443,85 +442,54 @@ func DefineBidiFlow[In, Out, StreamOut, StreamIn any](g *Genkit, name string, fn // handles session state, conversation history, and optional snapshot // persistence automatically. // -// opts is a mixed list of [ai.PromptOption] values (which configure the -// underlying prompt) and [aix.AgentOption] values (which configure the agent -// itself). The State type parameter must be specified explicitly: use [any] -// when no typed Custom state is needed; use [Foo] when an -// [aix.SessionStore[Foo]] is provided. Mismatches panic at definition time. +// source selects how the prompt is backed: // -// For an agent backed by an existing prompt, use [DefinePromptAgent]. For -// full control over the per-turn loop, use [DefineCustomAgent]. +// - [aix.FromInline] defines the prompt inline from a set of +// [ai.PromptOption] values; the prompt is registered under name. +// - [aix.FromPrompt] references an existing prompt registered with +// the registry under name (e.g. one defined via [DefinePrompt] or +// loaded from a .prompt file). +// +// The State type parameter is inferred from the typed agent options +// (e.g. [aix.WithSessionStore], [aix.WithSnapshotOn]); pass an explicit +// [State] only when no typed option is provided. +// +// For full control over the per-turn loop, use [DefineCustomAgent]. // // # Options // -// - any [ai.PromptOption]: e.g., [ai.WithModel], [ai.WithSystem], [ai.WithTools] // - [aix.WithSessionStore]: Enable snapshot persistence // - [aix.WithSnapshotCallback]: Control when snapshots are created // - [aix.WithSnapshotOn]: Create snapshots only for specific [aix.SnapshotEvent] types +// - [aix.WithStateTransform]: Rewrite session state on its way out to the client // -// Example: +// Example (inline prompt): // -// chatAgent := genkit.DefineAgent[any](g, "chat", -// ai.WithModelName("googleai/gemini-3-flash-preview"), -// ai.WithSystem("You are a helpful assistant."), +// chatAgent := genkit.DefineAgent(g, "chat", +// aix.FromInline( +// ai.WithModelName("googleai/gemini-3-flash-preview"), +// ai.WithSystem("You are a helpful assistant."), +// ), // aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), // ) // -// conn, err := chatAgent.StreamBidi(ctx) -// if err != nil { -// // handle error -// } -// conn.SendText("Hello!") -// for chunk, err := range conn.Receive() { -// if chunk.TurnEnd != nil { -// break -// } -// fmt.Print(chunk.ModelChunk.Text()) -// } -// conn.Close() -func DefineAgent[State any]( - g *Genkit, - name string, - opts ...aix.AgentDefineOption[State], -) *aix.Agent[any, State] { - return aix.DefineAgent(g.reg, name, opts...) -} - -// DefinePromptAgent defines an agent backed by a prompt already registered -// with the registry (via [DefinePrompt] or loaded from a .prompt file). The -// agent is registered under the same name as the prompt. -// -// Experimental: This API is under active development and may change in any -// minor version release. -// -// defaultInput is used to render the prompt on every turn. The prompt's -// Render is invoked once at definition time as a smoke check, so a -// defaultInput that fails the prompt's input schema panics here rather -// than failing on the first invocation. -// -// For an agent that defines its prompt inline, use [DefineAgent]. For full -// control over the per-turn loop, use [DefineCustomAgent]. -// -// Type parameters: -// - State: Type for user-defined state persisted in snapshots -// -// Example: +// Example (existing prompt): // // type ChatInput struct { // Personality string `json:"personality"` // } // -// chatAgent := genkit.DefinePromptAgent[any](g, "chat", -// ChatInput{Personality: "a sarcastic pirate"}, +// chatAgent := genkit.DefineAgent(g, "chat", +// aix.FromPrompt(ChatInput{Personality: "a sarcastic pirate"}), // aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), // ) -func DefinePromptAgent[State any]( +func DefineAgent[State any]( g *Genkit, - promptName string, - defaultInput any, + name string, + source aix.AgentSource, opts ...aix.AgentOption[State], ) *aix.Agent[any, State] { - return aix.DefinePromptAgent(g.reg, promptName, defaultInput, opts...) + return aix.DefineAgent(g.reg, name, source, opts...) } // DefineCustomAgent defines an agent with full control over the conversation @@ -536,8 +504,8 @@ func DefinePromptAgent[State any]( // Call [aix.SessionRunner.Run] to enter the turn loop, which blocks until the // client sends the next message. // -// For agents backed by a prompt, use [DefineAgent] (inline) or -// [DefinePromptAgent] (existing prompt) instead. +// For agents backed by a prompt, use [DefineAgent] with [aix.FromInline] +// (inline prompt) or [aix.FromPrompt] (existing prompt) instead. // // # Options // diff --git a/go/samples/agent-custom/main.go b/go/samples/agent-custom/main.go index 9b0d10ea94..cc5488023b 100644 --- a/go/samples/agent-custom/main.go +++ b/go/samples/agent-custom/main.go @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This sample demonstrates DefineCustomAgent, which gives the caller full -// control over the per-turn loop (model selection, history management, -// streaming chunks). It runs a CLI REPL where conversation history is -// maintained by the session. Compare with agent-prompt (DefinePromptAgent) -// and agent-inline (DefineAgent), which both auto-wire the loop and -// differ only in where the prompt is defined. +// This sample demonstrates DefineCustomAgent, which gives the caller +// full control over the per-turn loop (model selection, history +// management, streaming chunks). It runs a CLI REPL where conversation +// history is maintained by the session. Compare with agent-prompt +// (DefineAgent + aix.FromPrompt) and agent-inline (DefineAgent + +// aix.FromInline), which both auto-wire the loop and differ only in +// where the prompt is defined. package main import ( diff --git a/go/samples/agent-inline/main.go b/go/samples/agent-inline/main.go index d33845a4b1..a570254d8d 100644 --- a/go/samples/agent-inline/main.go +++ b/go/samples/agent-inline/main.go @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This sample demonstrates DefineAgent, which creates a multi-turn -// conversational agent backed by a prompt defined inline alongside the -// agent. The conversation loop (render prompt, call model, stream chunks, -// update history) is handled automatically. Compare with agent-prompt -// (DefinePromptAgent), which sources its prompt from a .prompt file, and -// agent-custom (DefineCustomAgent), which wires the same loop manually. +// This sample demonstrates DefineAgent with aix.FromInline, which +// creates a multi-turn conversational agent backed by a prompt defined +// inline alongside the agent. The conversation loop (render prompt, +// call model, stream chunks, update history) is handled automatically. +// Compare with agent-prompt (DefineAgent + aix.FromPrompt), which +// sources its prompt from a .prompt file, and agent-custom +// (DefineCustomAgent), which wires the same loop manually. package main import ( @@ -38,13 +39,15 @@ func main() { ctx := context.Background() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - chatAgent := genkit.DefineAgent[any](g, "chat", - ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ - ThinkingConfig: &genai.ThinkingConfig{ - ThinkingBudget: genai.Ptr[int32](0), - }, - })), - ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), + chatAgent := genkit.DefineAgent(g, "chat", + aix.FromInline( + ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), + ), aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), ) diff --git a/go/samples/agent-prompt/main.go b/go/samples/agent-prompt/main.go index 2a001df874..9d66196f20 100644 --- a/go/samples/agent-prompt/main.go +++ b/go/samples/agent-prompt/main.go @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This sample demonstrates DefinePromptAgent, which creates a multi-turn -// conversational agent backed by a .prompt file. The conversation loop -// (render prompt, call model, stream chunks, update history) is handled -// automatically. Compare with agent-custom (DefineCustomAgent), which -// wires the same loop manually, and agent-inline (DefineAgent), which -// defines the prompt inline alongside the agent. +// This sample demonstrates DefineAgent with aix.FromPrompt, which +// creates a multi-turn conversational agent backed by a .prompt file. +// The conversation loop (render prompt, call model, stream chunks, +// update history) is handled automatically. Compare with agent-custom +// (DefineCustomAgent), which wires the same loop manually, and +// agent-inline (DefineAgent + aix.FromInline), which defines the +// prompt inline alongside the agent. package main import ( @@ -40,8 +41,8 @@ func main() { ctx := context.Background() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - chatAgent := genkit.DefinePromptAgent(g, "chat", - ChatPromptInput{Personality: "a sarcastic pirate"}, + chatAgent := genkit.DefineAgent(g, "chat", + aix.FromPrompt(ChatPromptInput{Personality: "a sarcastic pirate"}), aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[any]) bool { return sc.Event == aix.SnapshotEventInvocationEnd || sc.TurnIndex%5 == 0 From 43062f786ccba6d2f110becff9a43af85faec141 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 12 May 2026 17:07:25 -0700 Subject: [PATCH 074/141] refactor(go/exp): single-message AgentInput, validate user role MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AgentInput.messages was an array but per-turn input is conceptually one message. Renamed to AgentInput.message (single, optional) across the schema and the Go runtime; AgentConnection.SendMessages becomes SendMessage. The DefineAgent loop now rejects inputs whose role is not "user" or whose content carries tool request / response parts — those belong on AgentInput.Resume, not on a turn message. --- genkit-tools/common/src/types/agent.ts | 4 +- genkit-tools/genkit-schema.json | 7 +- go/ai/exp/agent.go | 52 ++++++++--- go/ai/exp/agent_test.go | 116 +++++++++++++++++++------ go/ai/exp/gen.go | 6 +- go/core/schemas.config | 8 +- 6 files changed, 141 insertions(+), 52 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 8753873815..0164f4b1db 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -100,8 +100,8 @@ export const AgentInputSchema = z.object({ * finalized (or the snapshot is aborted via `abortSnapshot`). */ detach: z.boolean().optional(), - /** User's input messages for this turn. */ - messages: z.array(MessageSchema).optional(), + /** User's input message for this turn. */ + message: MessageSchema.optional(), /** Options for resuming an interrupted generation. */ resume: z .object({ diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index fda6d1eabc..e77e5cf63a 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -46,11 +46,8 @@ "detach": { "type": "boolean" }, - "messages": { - "type": "array", - "items": { - "$ref": "#/$defs/Message" - } + "message": { + "$ref": "#/$defs/Message" }, "resume": { "type": "object", diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 1bd008e671..7767b925d8 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -89,7 +89,9 @@ func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont } _, err := tracing.RunInNewSpan(ctx, spanMeta, input, func(ctx context.Context, input *AgentInput) (any, error) { - s.AddMessages(input.Messages...) + if input.Message != nil { + s.AddMessages(input.Message) + } if err := fn(ctx, input); err != nil { return nil, err } @@ -977,7 +979,7 @@ func (i *detachIntake) enqueue(input *AgentInput) { // rather than enqueued: it carries no payload to process, so it would // just trigger a no-op turn. Callers that want to ride a final input // on the detach signal can do so by calling -// Send(&AgentInput{Detach: true, Messages: ...}) explicitly. +// Send(&AgentInput{Detach: true, Message: ...}) explicitly. func (i *detachIntake) handleDetach(first *AgentInput) { var drained []*AgentInput if hasInputPayload(first) { @@ -1016,7 +1018,7 @@ func hasInputPayload(in *AgentInput) bool { if in == nil { return false } - if len(in.Messages) > 0 { + if in.Message != nil { return true } if in.Resume != nil && (len(in.Resume.Respond) > 0 || len(in.Resume.Restart) > 0) { @@ -1132,6 +1134,30 @@ func (i *detachIntake) stopAndWait() { // excluded from session history after generation. const promptMessageKey = "_genkit_prompt" +// validateUserMessage rejects inputs the prompt-backed agent loop can't +// safely consume: a non-user role would be appended to history under the +// wrong speaker, and tool request / response parts belong on the +// [AgentInput.Resume] payload, not on a turn message. +func validateUserMessage(m *ai.Message) error { + if m == nil { + return nil + } + if m.Role != "" && m.Role != ai.RoleUser { + return core.NewError(core.INVALID_ARGUMENT, + "agent input message must have role %q, got %q", ai.RoleUser, m.Role) + } + for _, p := range m.Content { + if p == nil { + continue + } + if p.IsToolRequest() || p.IsToolResponse() { + return core.NewError(core.INVALID_ARGUMENT, + "agent input message must not contain tool request or response parts; use AgentInput.Resume instead") + } + } + return nil +} + // agentLoop returns the per-turn function for a prompt-backed agent. Each // turn renders the prompt, appends conversation history, calls the model // with streaming, and updates the session. @@ -1142,6 +1168,10 @@ const promptMessageKey = "_genkit_prompt" func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) AgentFunc[any, State] { return func(ctx context.Context, resp Responder[any], sess *SessionRunner[State]) (*AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + if err := validateUserMessage(input.Message); err != nil { + return err + } + actionOpts, err := prompt.Render(ctx, defaultInput) if err != nil { return fmt.Errorf("prompt render: %w", err) @@ -1277,14 +1307,14 @@ func (a *Agent[Stream, State]) Run( // RunText is a convenience method that starts a single-turn agent invocation // with a user text message. It is equivalent to calling Run with an -// AgentInput containing a single user text message. +// AgentInput whose Message is a user text message. func (a *Agent[Stream, State]) RunText( ctx context.Context, text string, opts ...InvocationOption[State], ) (*AgentOutput[State], error) { return a.Run(ctx, &AgentInput{ - Messages: []*ai.Message{ai.NewUserTextMessage(text)}, + Message: ai.NewUserTextMessage(text), }, opts...) } @@ -1342,15 +1372,15 @@ func (c *AgentConnection[Stream, State]) Send(input *AgentInput) error { return c.conn.Send(input) } -// SendMessages sends messages to the agent. -func (c *AgentConnection[Stream, State]) SendMessages(messages ...*ai.Message) error { - return c.conn.Send(&AgentInput{Messages: messages}) +// SendMessage sends a message to the agent for one turn. +func (c *AgentConnection[Stream, State]) SendMessage(message *ai.Message) error { + return c.conn.Send(&AgentInput{Message: message}) } -// SendText sends a single user text message to the agent. +// SendText sends a user text message to the agent. func (c *AgentConnection[Stream, State]) SendText(text string) error { return c.conn.Send(&AgentInput{ - Messages: []*ai.Message{ai.NewUserTextMessage(text)}, + Message: ai.NewUserTextMessage(text), }) } @@ -1375,7 +1405,7 @@ func (c *AgentConnection[Stream, State]) SendResume(resume *ToolResume) error { // session and end up in the final snapshot's state. // // To send a final input as part of the same wire message, use -// Send(&AgentInput{Detach: true, Messages: ...}) directly. +// Send(&AgentInput{Detach: true, Message: ...}) directly. func (c *AgentConnection[Stream, State]) Detach() error { return c.conn.Send(&AgentInput{Detach: true}) } diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 2712f572b4..785208fcfe 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -53,8 +53,8 @@ func TestAgent_BasicMultiTurn(t *testing.T) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { resp.SendStatus(testStatus{Phase: "generating"}) // Echo back the user's message. - if len(input.Messages) > 0 { - reply := ai.NewModelTextMessage("echo: " + input.Messages[0].Content[0].Text) + if input.Message != nil { + reply := ai.NewModelTextMessage("echo: " + input.Message.Content[0].Text) sess.AddMessages(reply) } sess.UpdateCustom(func(s testState) testState { @@ -127,7 +127,7 @@ func TestAgent_WithSessionStore(t *testing.T) { af := DefineCustomAgent(reg, "snapshotFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { - if len(input.Messages) > 0 { + if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("reply")) } sess.UpdateCustom(func(s testState) testState { @@ -197,7 +197,7 @@ func TestAgent_ResumeFromSnapshot(t *testing.T) { af := DefineCustomAgent(reg, "resumeFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { - if len(input.Messages) > 0 { + if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("reply")) } sess.UpdateCustom(func(s testState) testState { @@ -286,7 +286,7 @@ func TestAgent_ClientManagedState(t *testing.T) { af := DefineCustomAgent(reg, "clientStateFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { - if len(input.Messages) > 0 { + if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("reply")) } sess.UpdateCustom(func(s testState) testState { @@ -471,11 +471,11 @@ func TestAgent_SnapshotCallback(t *testing.T) { } } -func TestAgent_SendMessages(t *testing.T) { +func TestAgent_SendMessage(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineCustomAgent(reg, "sendMsgsFlow", + af := DefineCustomAgent(reg, "sendMsgFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { return nil @@ -488,13 +488,10 @@ func TestAgent_SendMessages(t *testing.T) { t.Fatalf("StreamBidi failed: %v", err) } - // Send multiple messages at once. - err = conn.SendMessages( - ai.NewUserTextMessage("msg1"), - ai.NewUserTextMessage("msg2"), - ) + // Send a message via SendMessage. + err = conn.SendMessage(ai.NewUserTextMessage("msg1")) if err != nil { - t.Fatalf("SendMessages failed: %v", err) + t.Fatalf("SendMessage failed: %v", err) } for chunk, err := range conn.Receive() { if err != nil { @@ -511,9 +508,9 @@ func TestAgent_SendMessages(t *testing.T) { t.Fatalf("Output failed: %v", err) } - // Both messages should have been added. - if got := len(response.State.Messages); got != 2 { - t.Errorf("expected 2 messages, got %d", got) + // The message should have been added. + if got := len(response.State.Messages); got != 1 { + t.Errorf("expected 1 message, got %d", got) } } @@ -1336,8 +1333,8 @@ func TestAgent_RunText(t *testing.T) { af := DefineCustomAgent(reg, "runTextFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { - if len(input.Messages) > 0 { - sess.AddMessages(ai.NewModelTextMessage("echo: " + input.Messages[0].Content[0].Text)) + if input.Message != nil { + sess.AddMessages(ai.NewModelTextMessage("echo: " + input.Message.Content[0].Text)) } sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -1369,7 +1366,7 @@ func TestAgent_Run(t *testing.T) { af := DefineCustomAgent(reg, "runFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { - if len(input.Messages) > 0 { + if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("reply")) } return nil @@ -1378,10 +1375,7 @@ func TestAgent_Run(t *testing.T) { ) input := &AgentInput{ - Messages: []*ai.Message{ - ai.NewUserTextMessage("msg1"), - ai.NewUserTextMessage("msg2"), - }, + Message: ai.NewUserTextMessage("msg1"), } response, err := af.Run(ctx, input) @@ -1389,9 +1383,9 @@ func TestAgent_Run(t *testing.T) { t.Fatalf("Run failed: %v", err) } - // 2 user messages + 1 reply = 3. - if got := len(response.State.Messages); got != 3 { - t.Errorf("expected 3 messages, got %d", got) + // 1 user message + 1 reply = 2. + if got := len(response.State.Messages); got != 2 { + t.Errorf("expected 2 messages, got %d", got) } } @@ -1507,6 +1501,74 @@ func TestPromptAgent_RunText(t *testing.T) { } } +func TestPromptAgent_RejectsNonUserRole(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + ai.DefinePrompt(reg, "rejectRolePrompt", ai.WithModelName("test/echo")) + af := DefineAgent[testState](reg, "rejectRolePrompt", FromPrompt()) + + _, err := af.Run(ctx, &AgentInput{ + Message: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewTextPart("hi")}, + }, + }) + if err == nil { + t.Fatal("expected error for non-user role, got nil") + } + if !strings.Contains(err.Error(), "role") { + t.Errorf("expected role-related error, got %v", err) + } +} + +func TestPromptAgent_RejectsToolRequestPart(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + ai.DefinePrompt(reg, "rejectToolReqPrompt", ai.WithModelName("test/echo")) + af := DefineAgent[testState](reg, "rejectToolReqPrompt", FromPrompt()) + + _, err := af.Run(ctx, &AgentInput{ + Message: &ai.Message{ + Role: ai.RoleUser, + Content: []*ai.Part{ + ai.NewTextPart("hi"), + ai.NewToolRequestPart(&ai.ToolRequest{Name: "doThing", Ref: "1"}), + }, + }, + }) + if err == nil { + t.Fatal("expected error for tool request part, got nil") + } + if !strings.Contains(err.Error(), "tool request") { + t.Errorf("expected tool-request error, got %v", err) + } +} + +func TestPromptAgent_RejectsToolResponsePart(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + ai.DefinePrompt(reg, "rejectToolRespPrompt", ai.WithModelName("test/echo")) + af := DefineAgent[testState](reg, "rejectToolRespPrompt", FromPrompt()) + + _, err := af.Run(ctx, &AgentInput{ + Message: &ai.Message{ + Role: ai.RoleUser, + Content: []*ai.Part{ + ai.NewToolResponsePart(&ai.ToolResponse{Name: "doThing", Ref: "1"}), + }, + }, + }) + if err == nil { + t.Fatal("expected error for tool response part, got nil") + } + if !strings.Contains(err.Error(), "tool") { + t.Errorf("expected tool-related error, got %v", err) + } +} + func TestAgent_SingleTurnSnapshotDedup(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) @@ -1868,7 +1930,7 @@ func TestAgent_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { entered <- struct{}{} <-release - sess.AddMessages(ai.NewModelTextMessage("reply-" + input.Messages[0].Text())) + sess.AddMessages(ai.NewModelTextMessage("reply-" + input.Message.Text())) sess.UpdateCustom(func(s testState) testState { s.Counter++ return s diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index 20c78a5f26..59427361ba 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -71,12 +71,12 @@ type AgentInput struct { // queued inputs are processed (or the snapshot is cancelled via // cancelSnapshot). Detach bool `json:"detach,omitempty"` - // Messages contains the user's input for this turn. - Messages []*ai.Message `json:"messages,omitempty"` + // Message is the user's input for this turn. + Message *ai.Message `json:"message,omitempty"` // Resume provides options for resuming an interrupted generation. // Construct using [ai.ToolDef.RestartWith] / [ai.ToolDef.RespondWith] // parts. When set, the generate call resumes with these parts instead - // of treating Messages as tool responses. + // of treating Message as a tool response. Resume *ToolResume `json:"resume,omitempty"` } diff --git a/go/core/schemas.config b/go/core/schemas.config index b55332ad0d..8e82cd3331 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1219,16 +1219,16 @@ queued inputs are processed (or the snapshot is cancelled via cancelSnapshot). . -AgentInput.messages type []*ai.Message -AgentInput.messages doc -Messages contains the user's input for this turn. +AgentInput.message type *ai.Message +AgentInput.message doc +Message is the user's input for this turn. . AgentInput.resume doc Resume provides options for resuming an interrupted generation. Construct using [ai.ToolDef.RestartWith] / [ai.ToolDef.RespondWith] parts. When set, the generate call resumes with these parts instead -of treating Messages as tool responses. +of treating Message as a tool response. . # AgentInputResume is the inline resume payload hoisted out of From 6cf858cf0947ea0dbb0f7c521da2743c81425914 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 12 May 2026 17:09:52 -0700 Subject: [PATCH 075/141] chore(py): regenerate _typing.py for AgentInput.message rename --- py/packages/genkit/src/genkit/_core/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index c652a5f563..5e9bd96343 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -131,7 +131,7 @@ class AgentInput(GenkitModel): model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) detach: bool | None = None - messages: list[MessageData] | None = None + message: MessageData | None = None resume: Resume | None = None From acdb8e69ae45f7a3df6a9887b00f613b71056e08 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 12 May 2026 17:37:07 -0700 Subject: [PATCH 076/141] fix(go/exp): ctx-aware Responder.Send to decouple fn liveness from shutdown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Send methods on Responder previously did a bare channel send into the chunk router. If the router was pinned on a downstream send to a slow/gone consumer when workCtx cancelled, fn blocked on the send until the runtime's terminal path called router.stopAndWait() to flip the router into drain mode. It worked in practice — every terminal path calls stopAndWait before awaiting fnDone — but tied fn liveness to that invariant; a future terminal path that forgot stopAndWait would deadlock with no test to catch it. Plumb workCtx through Responder and select on it in every Send. A Send issued after cancellation drops the chunk and returns immediately regardless of router state. emitTurnEnd's internal send moves to the same shape via a new router.sendChunk helper. Companion change in handleFnDone: fn can now return with ctx.Err() while the router is still pinned on r.out, so close(r.in) alone wouldn't unblock it. Call router.stopAndWait() before close when res.err != nil. The natural-completion path deliberately skips this so a last in-flight chunk to a slow-but-alive consumer is never trashed. Public call sites are unchanged — resp.SendModelChunk(chunk) etc. behave identically when ctx is alive. --- go/ai/exp/agent.go | 67 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 53 insertions(+), 14 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 7767b925d8..23e4d1f593 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -211,16 +211,25 @@ func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event Snapshot // Responder is the output channel for an agent. Artifacts sent through // it are automatically added to the session before being forwarded to the // client. -type Responder[Stream any] chan<- *AgentStreamChunk[Stream] +// +// All Send methods are ctx-aware: if the agent's work context is +// cancelled (typically client disconnect, abort during detach, or fn +// completion), Send returns promptly with the chunk dropped. Send itself +// remains fire-and-forget and returns no error; the user fn is expected +// to observe cancellation through its own ctx check and stop producing. +type Responder[Stream any] struct { + in chan<- *AgentStreamChunk[Stream] + ctx context.Context +} // SendModelChunk sends a generation chunk (token-level streaming). func (r Responder[Stream]) SendModelChunk(chunk *ai.ModelResponseChunk) { - r <- &AgentStreamChunk[Stream]{ModelChunk: chunk} + r.send(&AgentStreamChunk[Stream]{ModelChunk: chunk}) } // SendStatus sends a user-defined status update. func (r Responder[Stream]) SendStatus(status Stream) { - r <- &AgentStreamChunk[Stream]{Status: status} + r.send(&AgentStreamChunk[Stream]{Status: status}) } // SendArtifact sends an artifact to the stream and adds it to the session. @@ -229,7 +238,19 @@ func (r Responder[Stream]) SendStatus(status Stream) { // has landed; only the wire forward to the client is suppressed // post-detach, when there is no longer a client to receive it. func (r Responder[Stream]) SendArtifact(artifact *Artifact) { - r <- &AgentStreamChunk[Stream]{Artifact: artifact} + r.send(&AgentStreamChunk[Stream]{Artifact: artifact}) +} + +// send delivers chunk to the router, returning promptly if r.ctx is +// cancelled. Dropping on cancel decouples fn liveness from the runtime's +// shutdown choreography: a Send issued after workCtx cancellation +// completes immediately rather than blocking on a router that has not +// yet been put into drain mode by a terminal path. +func (r Responder[Stream]) send(chunk *AgentStreamChunk[Stream]) { + select { + case r.in <- chunk: + case <-r.ctx.Done(): + } } // --- Agent --- @@ -411,9 +432,9 @@ func newAgentRuntime[Stream, State any]( // chunk through the router so clients see it on the output stream. func (rt *agentRuntime[Stream, State]) emitTurnEnd(ctx context.Context) { snapshotID := rt.sess.maybeSnapshot(ctx, SnapshotEventTurnEnd) - rt.router.send() <- &AgentStreamChunk[Stream]{TurnEnd: &TurnEnd{ + rt.router.sendChunk(ctx, &AgentStreamChunk[Stream]{TurnEnd: &TurnEnd{ SnapshotID: snapshotID, - }} + }}) } // run drives the user fn to completion and returns the agent output. @@ -456,7 +477,7 @@ func (rt *agentRuntime[Stream, State]) run( fnErr = core.NewError(core.INTERNAL, "agent fn panicked: %v", r) } }() - result, fnErr = fn(workCtx, rt.router.responder(), rt.sess) + result, fnErr = fn(workCtx, rt.router.responder(workCtx), rt.sess) }() rt.fnDone <- fnDoneResult[State]{result: result, err: fnErr} }() @@ -519,6 +540,15 @@ func (rt *agentRuntime[Stream, State]) drainAndWait(cancelWork context.CancelFun // handleFnDone is the synchronous-completion path: fn returned before any // detach signal. Capture an invocation-end snapshot if state advanced past // the last turn-end snapshot, then assemble the output. +// +// When fn returns with an error, the Responder's ctx-aware send may have +// dropped a chunk while the router was still pinned on a downstream send +// to a slow/gone consumer. router.close blocks on the router's forward +// goroutine exiting, which can't happen while it's stuck on that send; +// stopAndWait closes stopWriting first so the router breaks out and +// enters drain mode. The natural-completion path leaves the router idle +// (every Send was accepted before fn returned), so close alone is +// sufficient there and avoids trashing a last in-flight chunk. func (rt *agentRuntime[Stream, State]) handleFnDone( ctx context.Context, cancelWork context.CancelFunc, @@ -526,6 +556,9 @@ func (rt *agentRuntime[Stream, State]) handleFnDone( ) (*AgentOutput[State], error) { cancelWork() rt.intake.stopAndWait() + if res.err != nil { + rt.router.stopAndWait() + } rt.router.close() if res.err != nil { @@ -819,15 +852,21 @@ func (r *chunkRouter[Stream, State]) forward() bool { } } -// responder returns a [Responder] that sends chunks into the router. -func (r *chunkRouter[Stream, State]) responder() Responder[Stream] { - return Responder[Stream](r.in) +// responder returns a [Responder] that sends chunks into the router. The +// returned Responder's Send methods drop chunks (returning promptly) +// when ctx is cancelled. +func (r *chunkRouter[Stream, State]) responder(ctx context.Context) Responder[Stream] { + return Responder[Stream]{in: r.in, ctx: ctx} } -// send returns the internal chunk channel for producers other than the user -// agent function (e.g. the runtime's emitTurnEnd). -func (r *chunkRouter[Stream, State]) send() chan<- *AgentStreamChunk[Stream] { - return r.in +// sendChunk delivers chunk to the router for producers other than the +// user agent function (e.g. the runtime's emitTurnEnd). Returns +// promptly if ctx is cancelled, dropping the chunk. +func (r *chunkRouter[Stream, State]) sendChunk(ctx context.Context, chunk *AgentStreamChunk[Stream]) { + select { + case r.in <- chunk: + case <-ctx.Done(): + } } // collectTurnChunks returns and resets accumulated turn chunks. From 3888066caf66f3ba30e1cde653e8131af71f1ac3 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 13 May 2026 09:06:15 -0700 Subject: [PATCH 077/141] refactor(go/core,go/exp): break-out of BidiConnection.Receive no longer cancels Previously, breaking out of the BidiConnection.Receive iterator called c.cancel(), terminating the entire bidi action. This was modelled on Go's iter convention for "owning" iterators, but bidi streams are not owned by their receive iterator: send and receive are independent phases on a long-lived connection, and callers routinely break receive to switch to sending. This is the same shape as gRPC, WebSockets, NATS, and similar bidi protocols. Drop the c.cancel() so break is purely an iterator exit. Lifecycle is now controlled by ctx and Close, matching industry convention. This eliminates the workaround AgentConnection needed: a drainer goroutine that read the underlying iterator and buffered chunks into a private channel so the user could break between turns without killing the connection. With break-doesn't-cancel, AgentConnection.Receive collapses to a direct delegation to c.conn.Receive, and the chunks channel, chunkErr, initOnce, and initReceiver all disappear. One fewer goroutine per agent connection, less state, fewer moving parts. Breaking change for non-agent BidiConnection callers that relied on break-cancels-connection; they should switch to ctx cancellation or Close. No such callers exist in the tree (all existing bidi tests drain to completion). --- go/ai/exp/agent.go | 53 +++++++++------------------------------------- go/core/action.go | 8 +++++-- 2 files changed, 16 insertions(+), 45 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 23e4d1f593..1bc40d5885 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -1374,36 +1374,12 @@ func (a *Agent[Stream, State]) resolveOptions(opts []InvocationOption[State]) (* // --- AgentConnection --- -// AgentConnection wraps BidiConnection with agent-specific functionality. -// It provides a Receive() iterator that supports multi-turn patterns: breaking -// out of the iterator between turns does not cancel the underlying connection. +// AgentConnection wraps BidiConnection with agent-specific Send helpers +// (SendMessage / SendText / SendResume / Detach) and an Output that +// always waits for finalization (so detached invocations see the +// pending snapshot ID rather than a context-cancellation error). type AgentConnection[Stream, State any] struct { conn *core.BidiConnection[*AgentInput, *AgentStreamChunk[Stream], *AgentOutput[State]] - - // chunks buffers stream chunks from the underlying connection so that - // breaking from Receive() between turns doesn't cancel the context. - chunks chan *AgentStreamChunk[Stream] - chunkErr error - initOnce sync.Once -} - -// initReceiver starts a goroutine that drains the underlying BidiConnection's -// Receive into a channel. This goroutine never breaks from the underlying -// iterator, preventing context cancellation. -func (c *AgentConnection[Stream, State]) initReceiver() { - c.initOnce.Do(func() { - c.chunks = make(chan *AgentStreamChunk[Stream], 1) - go func() { - defer close(c.chunks) - for chunk, err := range c.conn.Receive() { - if err != nil { - c.chunkErr = err - return - } - c.chunks <- chunk - } - }() - }) } // Send sends an AgentInput to the agent. @@ -1454,22 +1430,13 @@ func (c *AgentConnection[Stream, State]) Close() error { return c.conn.Close() } -// Receive returns an iterator for receiving stream chunks. -// Unlike the underlying BidiConnection.Receive, breaking out of this iterator -// does not cancel the connection. This enables multi-turn patterns where the -// caller breaks on TurnEnd, sends the next input, then calls Receive again. +// Receive returns an iterator for receiving stream chunks. Breaking out +// of the iterator does not cancel the connection; multi-turn callers +// routinely break on [TurnEnd], send the next input, then call Receive +// again to consume the next batch. Use ctx cancellation or [Close] to +// terminate the connection. func (c *AgentConnection[Stream, State]) Receive() iter.Seq2[*AgentStreamChunk[Stream], error] { - c.initReceiver() - return func(yield func(*AgentStreamChunk[Stream], error) bool) { - for chunk := range c.chunks { - if !yield(chunk, nil) { - return - } - } - if err := c.chunkErr; err != nil { - yield(nil, err) - } - } + return c.conn.Receive() } // Output returns the final response after the agent completes. diff --git a/go/core/action.go b/go/core/action.go index 11b8336300..7b62c20b71 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -575,7 +575,12 @@ func (c *BidiConnection[StreamIn, StreamOut, Out]) Close() error { } // Receive returns an iterator for receiving streamed response chunks. -// The iterator completes when the action finishes. +// The iterator yields chunks until the action finishes, the context is +// cancelled, or the caller breaks out of the loop. Breaking out does NOT +// cancel the connection: bidi callers routinely break to switch to +// sending, then call Receive again to consume the next batch. Use ctx +// cancellation or [BidiConnection.Close] to terminate the connection +// (matching gRPC and similar bidi streaming conventions). func (c *BidiConnection[StreamIn, StreamOut, Out]) Receive() iter.Seq2[StreamOut, error] { return func(yield func(StreamOut, error) bool) { for { @@ -585,7 +590,6 @@ func (c *BidiConnection[StreamIn, StreamOut, Out]) Receive() iter.Seq2[StreamOut return } if !yield(chunk, nil) { - c.cancel() return } case <-c.ctx.Done(): From 934e36d1a64eefa16e467ebddbf6620427f4cd33 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 13 May 2026 09:19:27 -0700 Subject: [PATCH 078/141] chore(go/samples/agent-prompt): reference ChatPromptInput by name Register the input schema via DefineSchemaFor and reference it by name from chat.prompt instead of inlining the field list. Keeps the prompt file in sync with the Go type. --- go/samples/agent-prompt/main.go | 2 ++ go/samples/agent-prompt/prompts/chat.prompt | 5 +---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/go/samples/agent-prompt/main.go b/go/samples/agent-prompt/main.go index 9d66196f20..e271eb17f0 100644 --- a/go/samples/agent-prompt/main.go +++ b/go/samples/agent-prompt/main.go @@ -41,6 +41,8 @@ func main() { ctx := context.Background() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + genkit.DefineSchemaFor[ChatPromptInput](g) + chatAgent := genkit.DefineAgent(g, "chat", aix.FromPrompt(ChatPromptInput{Personality: "a sarcastic pirate"}), aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), diff --git a/go/samples/agent-prompt/prompts/chat.prompt b/go/samples/agent-prompt/prompts/chat.prompt index 6a78a99b07..4c11a18f72 100644 --- a/go/samples/agent-prompt/prompts/chat.prompt +++ b/go/samples/agent-prompt/prompts/chat.prompt @@ -4,9 +4,6 @@ config: thinkingConfig: thinkingBudget: 0 input: - schema: - personality: string - default: - personality: a helpful assistant + schema: ChatPromptInput --- You are {{personality}}. Keep responses concise. From 50ebd0faaf72e68fd327434006b5d98a3d87542e Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 13 May 2026 09:19:31 -0700 Subject: [PATCH 079/141] chore(go/samples/agent-inline): use gemini-flash-latest --- go/samples/agent-inline/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/samples/agent-inline/main.go b/go/samples/agent-inline/main.go index a570254d8d..12ccca30a8 100644 --- a/go/samples/agent-inline/main.go +++ b/go/samples/agent-inline/main.go @@ -41,7 +41,7 @@ func main() { chatAgent := genkit.DefineAgent(g, "chat", aix.FromInline( - ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ + ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ ThinkingBudget: genai.Ptr[int32](0), }, From 3e29774695f2fcbfb69ccc663f692cfbc2acc800 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 13 May 2026 09:38:24 -0700 Subject: [PATCH 080/141] fix(go/exp): deep-copy Result and AgentOutput to isolate from session SessionRunner.Result returned shallow-copied pointers: the slice header for Artifacts was fresh but the *ai.Message and *Artifact elements were shared with session state. A caller doing result.Message.Content[0].Text = "..." would mutate the session's actual message. handleFnDone then forwarded those same pointers into AgentOutput, exposing the same vector to the framework's caller. Add jsonClone and cloneArtifacts helpers in session.go (alongside the existing copySnapshot JSON-copy pattern) and use them in two places: * SessionRunner.Result deep-copies the last message and all artifacts on its way out. fn-side mutations of the returned result no longer reach session state. * handleFnDone deep-copies again when assembling AgentOutput. Defense in depth: even if a custom fn constructs AgentResult manually with raw session pointers (bypassing Result), the framework's caller still gets an isolated copy. Cost is small relative to the existing per-snapshot JSON copy: just the last message and the artifact list, not the whole history. Regression test asserts both invariants in one go. --- go/ai/exp/agent.go | 19 +++++---- go/ai/exp/agent_test.go | 89 +++++++++++++++++++++++++++++++++++++++++ go/ai/exp/session.go | 31 ++++++++++++++ 3 files changed, 132 insertions(+), 7 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 1bc40d5885..e1c56e8684 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -111,7 +111,10 @@ func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont } // Result returns an [AgentResult] populated from the current session state: -// the last message in the conversation history and all artifacts. +// the last message in the conversation history and all artifacts. The +// returned value is independent of the session; callers may mutate it +// without affecting session state. +// // It is a convenience for custom agents that don't need to construct the // result manually. func (s *SessionRunner[State]) Result() *AgentResult { @@ -120,12 +123,10 @@ func (s *SessionRunner[State]) Result() *AgentResult { result := &AgentResult{} if msgs := s.state.Messages; len(msgs) > 0 { - result.Message = msgs[len(msgs)-1] + result.Message = jsonClone(msgs[len(msgs)-1]) } if len(s.state.Artifacts) > 0 { - arts := make([]*Artifact, len(s.state.Artifacts)) - copy(arts, s.state.Artifacts) - result.Artifacts = arts + result.Artifacts = cloneArtifacts(s.state.Artifacts) } return result } @@ -574,8 +575,12 @@ func (rt *agentRuntime[Stream, State]) handleFnDone( out := &AgentOutput[State]{SnapshotID: snapshotID} if res.result != nil { - out.Message = res.result.Message - out.Artifacts = res.result.Artifacts + // Deep-copy at the framework boundary so the caller cannot + // mutate session contents through the returned output, even + // if a custom fn constructed AgentResult with raw session + // pointers rather than going through [SessionRunner.Result]. + out.Message = jsonClone(res.result.Message) + out.Artifacts = cloneArtifacts(res.result.Artifacts) } if rt.cfg.store == nil { out.State = applyTransform(ctx, rt.cfg.transform, rt.session.State()) diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 785208fcfe..d3c352c07d 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -3300,3 +3300,92 @@ func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { t.Errorf("snapshot status = %q after abort-on-terminal, want succeeded", snap.Status) } } + +func TestAgent_ResultAndOutput_IsolatedFromSession(t *testing.T) { + // Result() and AgentOutput must contain deep copies of session state so + // neither the fn (after calling Result) nor the caller (after receiving + // Output) can mutate session contents through them. Both layers + // deep-copy: Result for fn-side ergonomics, handleFnDone for defense + // in depth in case fn returns AgentResult built with raw session + // pointers instead of going through Result(). + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + var ( + sessionMsgAfterMutation string + sessionArtAfterMutation string + fnReturnedMessage *ai.Message + fnReturnedArtifact *Artifact + ) + + af := DefineCustomAgent(reg, "isolation", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + sess.AddMessages(ai.NewModelTextMessage("session-msg")) + sess.AddArtifacts(&Artifact{ + Name: "orig", + Parts: []*ai.Part{ai.NewTextPart("orig-part")}, + }) + return nil + }); err != nil { + return nil, err + } + + result := sess.Result() + // Mutate the returned result; must not touch session state. + result.Message.Content[0].Text = "fn-tainted-msg" + result.Artifacts[0].Name = "fn-tainted-art" + + // Capture the session view AFTER mutation so the outer test + // can verify the mutation didn't bleed through. + msgs := sess.Messages() + sessionMsgAfterMutation = msgs[len(msgs)-1].Content[0].Text + arts := sess.Artifacts() + sessionArtAfterMutation = arts[0].Name + + // Capture the pointers fn is returning so the outer test + // can verify handleFnDone copied them (i.e., out.Message + // is not the same pointer as what fn handed back). + fnReturnedMessage = result.Message + fnReturnedArtifact = result.Artifacts[0] + return result, nil + }, + WithSessionStore(store), + ) + + out, err := af.RunText(context.Background(), "go") + if err != nil { + t.Fatalf("RunText: %v", err) + } + + // Result() must have given fn an isolated copy. + if sessionMsgAfterMutation != "session-msg" { + t.Errorf("session message tainted by fn mutation of Result(): got %q, want %q", + sessionMsgAfterMutation, "session-msg") + } + if sessionArtAfterMutation != "orig" { + t.Errorf("session artifact tainted by fn mutation of Result(): got %q, want %q", + sessionArtAfterMutation, "orig") + } + + // handleFnDone must have copied fn's returned pointers at the framework + // boundary, so caller-side mutations cannot reach what fn handed back. + if out.Message == fnReturnedMessage { + t.Error("AgentOutput.Message shares pointer with fn's returned message; handleFnDone defensive copy missing") + } + if len(out.Artifacts) > 0 && out.Artifacts[0] == fnReturnedArtifact { + t.Error("AgentOutput.Artifacts[0] shares pointer with fn's returned artifact; handleFnDone defensive copy missing") + } + + // The persisted snapshot must reflect the un-tainted session state. + snap, err := store.GetSnapshot(context.Background(), out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if got := snap.State.Messages[len(snap.State.Messages)-1].Content[0].Text; got != "session-msg" { + t.Errorf("snapshot message tainted: got %q, want %q", got, "session-msg") + } + if snap.State.Artifacts[0].Name != "orig" { + t.Errorf("snapshot artifact tainted: got %q, want %q", snap.State.Artifacts[0].Name, "orig") + } +} diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 6b1165dc68..2870ee4a59 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -334,6 +334,37 @@ func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[Sta return &copied, nil } +// jsonClone deep-copies v via JSON marshal/unmarshal. Returns nil if v +// is nil. Panics on marshal/unmarshal failure: callers use this for +// types we control (messages, artifacts) where serialization failure +// indicates a programmer error, not a runtime condition. +func jsonClone[T any](v *T) *T { + if v == nil { + return nil + } + bytes, err := json.Marshal(v) + if err != nil { + panic(fmt.Sprintf("agent: jsonClone marshal: %v", err)) + } + var out T + if err := json.Unmarshal(bytes, &out); err != nil { + panic(fmt.Sprintf("agent: jsonClone unmarshal: %v", err)) + } + return &out +} + +// cloneArtifacts returns a deep copy of arts. Returns nil if arts is empty. +func cloneArtifacts(arts []*Artifact) []*Artifact { + if len(arts) == 0 { + return nil + } + out := make([]*Artifact, len(arts)) + for i, a := range arts { + out[i] = jsonClone(a) + } + return out +} + // --- Snapshot companion actions --- // registerSnapshotActions registers the agent's companion actions when From 669c723913c1058f553e25df900d4c436b38836a Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 13 May 2026 11:48:23 -0700 Subject: [PATCH 081/141] feat(go/exp/localstore): file-based session store; fix sample Ctrl+C - New go/ai/exp/localstore package with FileSessionStore and InMemorySessionStore (moved out of exp). Positioned for single-process / single-instance use (CLI tools, desktop apps, local services); production deployments should reach for a real database-backed store. Tests live alongside. - exp keeps a private testInMemStore fixture used only by the agent's internal tests so they don't introduce an exp -> localstore -> exp import cycle. - agent-custom, agent-prompt, and agent-inline now own SIGINT handling (signal.NotifyContext + ctx-aware select REPL with a readLines helper). genkit.Init's existing signal catch stays in place for dev-mode reflection-server cleanup; samples just need to propagate the cancellable ctx. Previously Ctrl+C was trapped but no one observed the cancellation, so the prompt spun in a tight loop. agent-inline switches to FileSessionStore as a demo. --- go/ai/exp/agent_test.go | 191 ++---------- go/ai/exp/localstore/file.go | 301 +++++++++++++++++++ go/ai/exp/localstore/file_test.go | 399 ++++++++++++++++++++++++++ go/ai/exp/localstore/inmemory.go | 213 ++++++++++++++ go/ai/exp/localstore/inmemory_test.go | 159 ++++++++++ go/ai/exp/localstore/store_test.go | 23 ++ go/ai/exp/session.go | 181 ------------ go/ai/exp/teststore_test.go | 187 ++++++++++++ go/genkit/genkit.go | 4 +- go/samples/agent-custom/main.go | 57 +++- go/samples/agent-inline/main.go | 62 +++- go/samples/agent-prompt/main.go | 56 +++- 12 files changed, 1469 insertions(+), 364 deletions(-) create mode 100644 go/ai/exp/localstore/file.go create mode 100644 go/ai/exp/localstore/file_test.go create mode 100644 go/ai/exp/localstore/inmemory.go create mode 100644 go/ai/exp/localstore/inmemory_test.go create mode 100644 go/ai/exp/localstore/store_test.go create mode 100644 go/ai/exp/teststore_test.go diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index d3c352c07d..40891911ab 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -122,7 +122,7 @@ func TestAgent_BasicMultiTurn(t *testing.T) { func TestAgent_WithSessionStore(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "snapshotFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { @@ -192,7 +192,7 @@ func TestAgent_WithSessionStore(t *testing.T) { func TestAgent_ResumeFromSnapshot(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "resumeFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { @@ -415,7 +415,7 @@ func TestAgent_Artifacts(t *testing.T) { func TestAgent_SnapshotCallback(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() // Only snapshot on even turns. callbackCalls := 0 @@ -627,135 +627,6 @@ func TestAgent_SetMessages(t *testing.T) { } } -func TestInMemorySessionStore(t *testing.T) { - t.Run("GetMissing", func(t *testing.T) { - store := NewInMemorySessionStore[testState]() - snap, err := store.GetSnapshot(context.Background(), "nonexistent") - if err != nil { - t.Fatalf("GetSnapshot failed: %v", err) - } - if snap != nil { - t.Errorf("expected nil, got %v", snap) - } - }) - - t.Run("SaveWithFixedID", func(t *testing.T) { - store := NewInMemorySessionStore[testState]() - saved, err := store.SaveSnapshot(context.Background(), "snap-1", - func(existing *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { - if existing != nil { - t.Errorf("expected nil existing on first save, got %+v", existing) - } - return &SessionSnapshot[testState]{ - Status: SnapshotStatusSucceeded, - State: &SessionState[testState]{Custom: testState{Counter: 1}}, - }, nil - }) - if err != nil { - t.Fatalf("SaveSnapshot failed: %v", err) - } - if saved.SnapshotID != "snap-1" { - t.Errorf("saved SnapshotID = %q, want %q", saved.SnapshotID, "snap-1") - } - if saved.CreatedAt.IsZero() || saved.UpdatedAt.IsZero() { - t.Errorf("expected CreatedAt/UpdatedAt stamped, got created=%v updated=%v", - saved.CreatedAt, saved.UpdatedAt) - } - }) - - t.Run("GetReturnsCopy", func(t *testing.T) { - store := NewInMemorySessionStore[testState]() - if _, err := store.SaveSnapshot(context.Background(), "snap-1", - func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { - return &SessionSnapshot[testState]{ - Status: SnapshotStatusSucceeded, - State: &SessionState[testState]{Custom: testState{Counter: 1}}, - }, nil - }); err != nil { - t.Fatalf("SaveSnapshot: %v", err) - } - retrieved, _ := store.GetSnapshot(context.Background(), "snap-1") - retrieved.State.Custom.Counter = 999 - retrieved2, _ := store.GetSnapshot(context.Background(), "snap-1") - if retrieved2.State.Custom.Counter != 1 { - t.Errorf("expected counter=1 (isolation), got %d", retrieved2.State.Custom.Counter) - } - }) - - t.Run("DefaultsEmptyStatusToComplete", func(t *testing.T) { - store := NewInMemorySessionStore[testState]() - saved, err := store.SaveSnapshot(context.Background(), "", - func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { - return &SessionSnapshot[testState]{}, nil - }) - if err != nil { - t.Fatalf("SaveSnapshot: %v", err) - } - if saved.SnapshotID == "" { - t.Error("expected store to generate SnapshotID") - } - if saved.Status != SnapshotStatusSucceeded { - t.Errorf("expected Status=complete by default, got %q", saved.Status) - } - }) - - t.Run("NoopFnSkipsWrite", func(t *testing.T) { - store := NewInMemorySessionStore[testState]() - if _, err := store.SaveSnapshot(context.Background(), "snap-1", - func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { - return &SessionSnapshot[testState]{Status: SnapshotStatusSucceeded}, nil - }); err != nil { - t.Fatalf("seed: %v", err) - } - before, _ := store.GetSnapshot(context.Background(), "snap-1") - noop, err := store.SaveSnapshot(context.Background(), "snap-1", - func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { - return nil, nil - }) - if err != nil { - t.Fatalf("noop SaveSnapshot: %v", err) - } - if noop != nil { - t.Errorf("expected nil return on noop, got %+v", noop) - } - after, _ := store.GetSnapshot(context.Background(), "snap-1") - if before.UpdatedAt != after.UpdatedAt { - t.Errorf("noop should not bump UpdatedAt: before=%v after=%v", before.UpdatedAt, after.UpdatedAt) - } - }) - - t.Run("PreservesCreatedAtOnUpdate", func(t *testing.T) { - store := NewInMemorySessionStore[testState]() - saved, err := store.SaveSnapshot(context.Background(), "snap-1", - func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { - return &SessionSnapshot[testState]{Status: SnapshotStatusSucceeded}, nil - }) - if err != nil { - t.Fatalf("seed: %v", err) - } - time.Sleep(time.Millisecond) // ensure measurable UpdatedAt delta - updated, err := store.SaveSnapshot(context.Background(), "snap-1", - func(existing *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { - if existing == nil { - t.Fatal("expected non-nil existing on update") - } - return &SessionSnapshot[testState]{ - Status: SnapshotStatusSucceeded, - State: &SessionState[testState]{Custom: testState{Counter: 2}}, - }, nil - }) - if err != nil { - t.Fatalf("update: %v", err) - } - if !updated.CreatedAt.Equal(saved.CreatedAt) { - t.Errorf("CreatedAt not preserved: before=%v after=%v", saved.CreatedAt, updated.CreatedAt) - } - if !updated.UpdatedAt.After(saved.UpdatedAt) { - t.Errorf("UpdatedAt did not advance: before=%v after=%v", saved.UpdatedAt, updated.UpdatedAt) - } - }) -} - func TestAgent_TurnSpanOutput(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) @@ -837,7 +708,7 @@ func TestAgent_TurnSpanOutput(t *testing.T) { func TestAgent_TurnSpanOutput_WithSnapshots(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() var capturedOutputs []any @@ -1101,7 +972,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { func TestPromptAgent_SnapshotResumePreservesHistory(t *testing.T) { ctx := context.Background() reg := setupPromptTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() ai.DefinePrompt(reg, "snapPrompt", ai.WithModelName("test/echo"), @@ -1432,7 +1303,7 @@ func TestAgent_RunText_WithState(t *testing.T) { func TestAgent_RunText_WithSnapshot(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "runSnapshotFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { @@ -1572,7 +1443,7 @@ func TestPromptAgent_RejectsToolResponsePart(t *testing.T) { func TestAgent_SingleTurnSnapshotDedup(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "dedupFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { @@ -1616,7 +1487,7 @@ func TestAgent_SingleTurnSnapshotDedup(t *testing.T) { func TestAgent_MultiTurnSnapshotDedup(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "multiDedupFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { @@ -1678,7 +1549,7 @@ func TestAgent_MultiTurnSnapshotDedup(t *testing.T) { func TestAgent_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "postRunMutateFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { @@ -1857,7 +1728,7 @@ func TestAgent_TurnEnd_CarriesSnapshotID(t *testing.T) { // Sanity: each TurnEnd chunk carries the snapshot ID of the turn-end // snapshot, and the snapshots themselves are persisted. reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "turnEndSnapshotFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { @@ -1920,7 +1791,7 @@ func TestAgent_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { // - NOT write a separate turn-end snapshot for A or D (suspended). // After release, the finalized snapshot has both A's and D's effects. reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() entered := make(chan struct{}, 4) release := make(chan struct{}) @@ -2019,7 +1890,7 @@ func TestAgent_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { // Run two normal turns first, then detach during a third (in-flight) // turn. The pending snapshot must chain off the second turn's snapshot. reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() enter := make(chan struct{}, 4) release := make(chan struct{}, 4) @@ -2132,7 +2003,7 @@ func TestAgent_Detach_PendingThenComplete(t *testing.T) { // Client detaches mid-flow; flow finishes naturally; pending snapshot // flips to status=succeeded with the full session state. reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() release := make(chan struct{}) entered := make(chan struct{}) @@ -2227,7 +2098,7 @@ func TestAgent_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { // the finalized snapshot's state. The wire forward is the only thing // detach suppresses, so flow authors don't need to branch on detach. reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() detached := make(chan struct{}) release := make(chan struct{}) @@ -2303,7 +2174,7 @@ func TestAgent_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { func TestAgent_Detach_FlowErrorsBecomesError(t *testing.T) { reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() release := make(chan struct{}) entered := make(chan struct{}) @@ -2375,7 +2246,7 @@ func TestAgent_Detach_AbortSnapshotStopsFlow(t *testing.T) { // subscriber notifies the runtime, which cancels the work context, and // the finalizer rewrites the snapshot with status=aborted. reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() entered := make(chan struct{}) @@ -2448,7 +2319,7 @@ func TestAgent_Detach_NormalCompletionStillEmitsTurnEnd(t *testing.T) { // Sanity: a non-detached invocation against a store-backed flow still // behaves like a synchronous flow (turn-end snapshots, no pending row). reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "syncStillWorks", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { @@ -2503,7 +2374,7 @@ func TestAgent_Detach_ClientDisconnectBeforeDetachCancels(t *testing.T) { // the regression guard for "until detach=true is called, this is a // normal HTTP/WS connection that cancels on close." reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() entered := make(chan struct{}) exited := make(chan error, 1) @@ -2555,7 +2426,7 @@ func TestAgent_Detach_ClientDisconnectBeforeDetachCancels(t *testing.T) { func TestAgent_ResumeFromErrorSnapshot_Rejected(t *testing.T) { reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() erroredID := "errored-456" if _, err := store.SaveSnapshot(context.Background(), erroredID, @@ -2591,7 +2462,7 @@ func TestAgent_ResumeFromErrorSnapshot_Rejected(t *testing.T) { func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() // Transform that scrubs a specific word from all messages. transform := func(_ context.Context, s *SessionState[testState]) *SessionState[testState] { @@ -2671,7 +2542,7 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { } func TestInMemorySessionStore_GetSnapshot_NotFound(t *testing.T) { - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() snap, err := store.GetSnapshot(context.Background(), "nope") if err != nil { @@ -2711,7 +2582,7 @@ func TestLoadSession_AgentInitValidation(t *testing.T) { // - snapshotId requires a store (server-managed state), // - state requires the absence of a store (client-managed state). ctx := context.Background() - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() state := &SessionState[testState]{Custom: testState{Counter: 1}} cases := []struct { @@ -2835,7 +2706,7 @@ func TestAgent_AgentMetadata(t *testing.T) { name: "store with full capabilities → server-managed, abortable", define: func(reg api.Registry, flowName string) { DefineCustomAgent(reg, flowName, noopFn, - WithSessionStore(NewInMemorySessionStore[testState]())) + WithSessionStore(newTestInMemStore[testState]())) }, wantMgmt: AgentStateManagementServer, wantAbortab: true, @@ -2878,7 +2749,7 @@ func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { // registered regardless. t.Run("aborter capability → both registered", func(t *testing.T) { reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() // implements SnapshotAborter + store := newTestInMemStore[testState]() // implements SnapshotAborter DefineCustomAgent(reg, "fullCaps", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil @@ -2927,7 +2798,7 @@ func TestAgent_AbortAction_NotFound(t *testing.T) { func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil }, - WithSessionStore(NewInMemorySessionStore[testState]()), + WithSessionStore(newTestInMemStore[testState]()), ) abortAction := core.ResolveActionFor[*AbortSnapshotRequest, *AbortSnapshotResponse, struct{}, struct{}]( @@ -2988,7 +2859,7 @@ func TestAgent_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { // End-to-end: run a flow that the client detaches from, let it // finalize, then resume from its snapshot as if reconnecting later. reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "resumeDetachedFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { @@ -3053,7 +2924,7 @@ func TestAgent_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { ctx := context.Background() - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() // Abort on missing snapshot returns empty status, no error. if status, err := store.AbortSnapshot(ctx, "nope"); err != nil || status != "" { @@ -3130,7 +3001,7 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { // complete. The subscriber observes the status flip and the finalizer // reads the resulting flag. reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() fnRelease := make(chan struct{}) entered := make(chan struct{}) @@ -3195,7 +3066,7 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { func TestInMemorySessionStore_OnSnapshotStatusChange(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() // Subscribe to a missing snapshot: channel returns immediately closed // without yielding a value. @@ -3265,7 +3136,7 @@ func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { // Calling AbortSnapshot on an already-terminal snapshot is a no-op // that returns the existing status. reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "abortNoop", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { @@ -3309,7 +3180,7 @@ func TestAgent_ResultAndOutput_IsolatedFromSession(t *testing.T) { // in depth in case fn returns AgentResult built with raw session // pointers instead of going through Result(). reg := newTestRegistry(t) - store := NewInMemorySessionStore[testState]() + store := newTestInMemStore[testState]() var ( sessionMsgAfterMutation string diff --git a/go/ai/exp/localstore/file.go b/go/ai/exp/localstore/file.go new file mode 100644 index 0000000000..0c7a26a8dc --- /dev/null +++ b/go/ai/exp/localstore/file.go @@ -0,0 +1,301 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package localstore + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "slices" + "strings" + "sync" + "time" + + "github.com/firebase/genkit/go/ai/exp" + "github.com/google/uuid" +) + +// FileSessionStore is a snapshot store that persists snapshots as JSON files +// on the local filesystem. Each snapshot is written to its own file named +// ".json" in the configured directory. +// +// The store is safe for concurrent use within a single process. It does NOT +// coordinate writes with other processes that may share the same directory: +// the only synchronization is the per-instance mutex. If multiple processes +// write to the same directory the last successful rename wins; readers may +// also observe a brief window during which a snapshot is still being written +// by another process (the rename itself is atomic, but cross-process +// linearization is not guaranteed). +// +// [FileSessionStore.OnSnapshotStatusChange] uses in-process channels and only +// reflects status transitions caused by calls on this store instance. +// External writes to the directory and writes from other processes are not +// observed. +type FileSessionStore[State any] struct { + // mu serializes the read-modify-write paths and the subscriber bookkeeping. + // File I/O happens under the lock; this matches the simplicity of + // [InMemorySessionStore] and is adequate when writes are infrequent + // (typically once per turn). + mu sync.Mutex + dir string + subs map[string][]chan exp.SnapshotStatus +} + +// NewFileSessionStore creates a file-based snapshot store rooted at dir. +// The directory is created (mode 0o700) if it does not already exist. +// Returns an error if dir is empty or cannot be created. +func NewFileSessionStore[State any](dir string) (*FileSessionStore[State], error) { + if dir == "" { + return nil, errors.New("FileSessionStore: dir is required") + } + if err := os.MkdirAll(dir, 0o700); err != nil { + return nil, fmt.Errorf("FileSessionStore: create dir %q: %w", dir, err) + } + return &FileSessionStore[State]{ + dir: dir, + subs: make(map[string][]chan exp.SnapshotStatus), + }, nil +} + +// GetSnapshot retrieves a snapshot by ID. Returns nil if not found. +func (s *FileSessionStore[State]) GetSnapshot(_ context.Context, snapshotID string) (*exp.SessionSnapshot[State], error) { + if err := validateSnapshotID(snapshotID); err != nil { + return nil, err + } + s.mu.Lock() + defer s.mu.Unlock() + return s.readLocked(snapshotID) +} + +// SaveSnapshot atomically reads, applies fn, and persists. See the +// [exp.SnapshotWriter] interface for the full contract; this implementation +// satisfies it by holding s.mu for the entire read-modify-write so fn is +// called exactly once per SaveSnapshot call. +func (s *FileSessionStore[State]) SaveSnapshot( + _ context.Context, + id string, + fn func(existing *exp.SessionSnapshot[State]) (*exp.SessionSnapshot[State], error), +) (*exp.SessionSnapshot[State], error) { + if id == "" { + id = uuid.New().String() + } else if err := validateSnapshotID(id); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + + existing, err := s.readLocked(id) + if err != nil { + return nil, err + } + + next, err := fn(existing) + if err != nil { + return nil, err + } + if next == nil { + return nil, nil + } + + next.SnapshotID = id + now := time.Now() + if existing != nil { + next.CreatedAt = existing.CreatedAt + } else { + next.CreatedAt = now + } + next.UpdatedAt = now + if next.Status == "" { + next.Status = exp.SnapshotStatusSucceeded + } + + if err := s.writeLocked(next); err != nil { + return nil, err + } + if existing == nil || existing.Status != next.Status { + s.notifyLocked(id, next.Status) + } + return next, nil +} + +// AbortSnapshot atomically flips a pending snapshot to aborted. If the +// snapshot is already terminal the existing status is returned unchanged. +// Returns an empty status if the snapshot is not found. +func (s *FileSessionStore[State]) AbortSnapshot(_ context.Context, snapshotID string) (exp.SnapshotStatus, error) { + if err := validateSnapshotID(snapshotID); err != nil { + return "", err + } + s.mu.Lock() + defer s.mu.Unlock() + + snap, err := s.readLocked(snapshotID) + if err != nil { + return "", err + } + if snap == nil { + return "", nil + } + if snap.Status == exp.SnapshotStatusPending { + snap.Status = exp.SnapshotStatusAborted + snap.UpdatedAt = time.Now() + if err := s.writeLocked(snap); err != nil { + return "", err + } + s.notifyLocked(snapshotID, snap.Status) + } + return snap.Status, nil +} + +// OnSnapshotStatusChange subscribes to status changes for a snapshot. The +// returned channel yields the current status (if any) and any subsequent +// changes triggered by calls on this store instance, until ctx is cancelled. +// Changes made by other processes writing to the same directory are not +// observed. +func (s *FileSessionStore[State]) OnSnapshotStatusChange(ctx context.Context, snapshotID string) <-chan exp.SnapshotStatus { + ch := make(chan exp.SnapshotStatus, 1) + if err := validateSnapshotID(snapshotID); err != nil { + close(ch) + return ch + } + + s.mu.Lock() + snap, err := s.readLocked(snapshotID) + if err != nil || snap == nil { + s.mu.Unlock() + close(ch) + return ch + } + ch <- snap.Status + s.subs[snapshotID] = append(s.subs[snapshotID], ch) + s.mu.Unlock() + + context.AfterFunc(ctx, func() { s.removeSub(snapshotID, ch) }) + return ch +} + +// readLocked reads and parses the snapshot file. Returns (nil, nil) if the +// file does not exist. Caller must hold s.mu. +func (s *FileSessionStore[State]) readLocked(snapshotID string) (*exp.SessionSnapshot[State], error) { + data, err := os.ReadFile(s.path(snapshotID)) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return nil, fmt.Errorf("FileSessionStore: read %s: %w", snapshotID, err) + } + var snap exp.SessionSnapshot[State] + if err := json.Unmarshal(data, &snap); err != nil { + return nil, fmt.Errorf("FileSessionStore: unmarshal %s: %w", snapshotID, err) + } + return &snap, nil +} + +// writeLocked atomically writes the snapshot to disk via a temp file + +// rename. Caller must hold s.mu. +func (s *FileSessionStore[State]) writeLocked(snap *exp.SessionSnapshot[State]) error { + data, err := json.MarshalIndent(snap, "", " ") + if err != nil { + return fmt.Errorf("FileSessionStore: marshal: %w", err) + } + f, err := os.CreateTemp(s.dir, snap.SnapshotID+".*.tmp") + if err != nil { + return fmt.Errorf("FileSessionStore: create temp: %w", err) + } + tmpName := f.Name() + // Best-effort cleanup if anything fails before the rename succeeds. + // Once renamed, the temp file no longer exists so Remove is a no-op. + defer os.Remove(tmpName) + + if _, err := f.Write(data); err != nil { + f.Close() + return fmt.Errorf("FileSessionStore: write: %w", err) + } + if err := f.Sync(); err != nil { + f.Close() + return fmt.Errorf("FileSessionStore: sync: %w", err) + } + if err := f.Close(); err != nil { + return fmt.Errorf("FileSessionStore: close: %w", err) + } + if err := os.Rename(tmpName, s.path(snap.SnapshotID)); err != nil { + return fmt.Errorf("FileSessionStore: rename: %w", err) + } + return nil +} + +// path returns the on-disk path for a snapshot ID. The ID is assumed to have +// been validated by validateSnapshotID. +func (s *FileSessionStore[State]) path(snapshotID string) string { + return filepath.Join(s.dir, snapshotID+".json") +} + +// removeSub detaches a subscriber and closes its channel. +func (s *FileSessionStore[State]) removeSub(snapshotID string, ch chan exp.SnapshotStatus) { + s.mu.Lock() + defer s.mu.Unlock() + subs := s.subs[snapshotID] + i := slices.Index(subs, ch) + if i < 0 { + return + } + subs = slices.Delete(subs, i, i+1) + if len(subs) == 0 { + delete(s.subs, snapshotID) + } else { + s.subs[snapshotID] = subs + } + close(ch) +} + +// notifyLocked publishes status to all live subscribers of snapshotID. +// Caller must hold s.mu. Sends are best-effort: a slow subscriber may miss +// intermediate values, but the latest value visible to the subscription is +// always one of the values persisted to disk. +func (s *FileSessionStore[State]) notifyLocked(snapshotID string, status exp.SnapshotStatus) { + for _, ch := range s.subs[snapshotID] { + select { + case ch <- status: + default: + } + } +} + +// validateSnapshotID rejects IDs that would escape the store directory or +// collide with reserved filenames. UUIDs (the default produced by an empty +// id) pass trivially. +func validateSnapshotID(id string) error { + if id == "" { + return errors.New("FileSessionStore: snapshot ID is empty") + } + if strings.ContainsAny(id, `/\`) || strings.Contains(id, "..") { + return fmt.Errorf("FileSessionStore: snapshot ID %q contains path separators", id) + } + if strings.HasPrefix(id, ".") { + return fmt.Errorf("FileSessionStore: snapshot ID %q must not start with '.'", id) + } + // Disallow NUL and control characters that some filesystems reject. + for _, r := range id { + if r < 0x20 { + return fmt.Errorf("FileSessionStore: snapshot ID %q contains control characters", id) + } + } + return nil +} diff --git a/go/ai/exp/localstore/file_test.go b/go/ai/exp/localstore/file_test.go new file mode 100644 index 0000000000..c992e4d2f5 --- /dev/null +++ b/go/ai/exp/localstore/file_test.go @@ -0,0 +1,399 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package localstore + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/firebase/genkit/go/ai/exp" +) + +func newFileStore(t *testing.T) *FileSessionStore[testState] { + t.Helper() + dir := t.TempDir() + store, err := NewFileSessionStore[testState](dir) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + return store +} + +func TestFileSessionStore(t *testing.T) { + t.Run("EmptyDirRejected", func(t *testing.T) { + if _, err := NewFileSessionStore[testState](""); err == nil { + t.Error("expected error for empty dir, got nil") + } + }) + + t.Run("CreatesMissingDir", func(t *testing.T) { + dir := filepath.Join(t.TempDir(), "nested", "subdir") + if _, err := NewFileSessionStore[testState](dir); err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + if _, err := os.Stat(dir); err != nil { + t.Errorf("expected dir to be created, stat: %v", err) + } + }) + + t.Run("GetMissing", func(t *testing.T) { + store := newFileStore(t) + snap, err := store.GetSnapshot(context.Background(), "nonexistent") + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap != nil { + t.Errorf("expected nil, got %v", snap) + } + }) + + t.Run("SaveWithFixedID", func(t *testing.T) { + store := newFileStore(t) + saved, err := store.SaveSnapshot(context.Background(), "snap-1", + func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + if existing != nil { + t.Errorf("expected nil existing on first save, got %+v", existing) + } + return &exp.SessionSnapshot[testState]{ + Status: exp.SnapshotStatusSucceeded, + State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, + }, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot failed: %v", err) + } + if saved.SnapshotID != "snap-1" { + t.Errorf("saved SnapshotID = %q, want %q", saved.SnapshotID, "snap-1") + } + if saved.CreatedAt.IsZero() || saved.UpdatedAt.IsZero() { + t.Errorf("expected CreatedAt/UpdatedAt stamped, got created=%v updated=%v", + saved.CreatedAt, saved.UpdatedAt) + } + }) + + t.Run("SaveWithEmptyIDGeneratesUUID", func(t *testing.T) { + store := newFileStore(t) + saved, err := store.SaveSnapshot(context.Background(), "", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + if saved.SnapshotID == "" { + t.Error("expected store to generate SnapshotID") + } + }) + + t.Run("GetReturnsCopy", func(t *testing.T) { + store := newFileStore(t) + if _, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{ + Status: exp.SnapshotStatusSucceeded, + State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, + }, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + retrieved, _ := store.GetSnapshot(context.Background(), "snap-1") + retrieved.State.Custom.Counter = 999 + retrieved2, _ := store.GetSnapshot(context.Background(), "snap-1") + if retrieved2.State.Custom.Counter != 1 { + t.Errorf("expected counter=1 (isolation), got %d", retrieved2.State.Custom.Counter) + } + }) + + t.Run("DefaultsEmptyStatusToSucceeded", func(t *testing.T) { + store := newFileStore(t) + saved, err := store.SaveSnapshot(context.Background(), "", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{}, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + if saved.Status != exp.SnapshotStatusSucceeded { + t.Errorf("expected Status=succeeded by default, got %q", saved.Status) + } + }) + + t.Run("NoopFnSkipsWrite", func(t *testing.T) { + store := newFileStore(t) + if _, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + }); err != nil { + t.Fatalf("seed: %v", err) + } + before, _ := store.GetSnapshot(context.Background(), "snap-1") + noop, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return nil, nil + }) + if err != nil { + t.Fatalf("noop SaveSnapshot: %v", err) + } + if noop != nil { + t.Errorf("expected nil return on noop, got %+v", noop) + } + after, _ := store.GetSnapshot(context.Background(), "snap-1") + if !before.UpdatedAt.Equal(after.UpdatedAt) { + t.Errorf("noop should not bump UpdatedAt: before=%v after=%v", before.UpdatedAt, after.UpdatedAt) + } + }) + + t.Run("PreservesCreatedAtOnUpdate", func(t *testing.T) { + store := newFileStore(t) + saved, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + }) + if err != nil { + t.Fatalf("seed: %v", err) + } + time.Sleep(time.Millisecond) + updated, err := store.SaveSnapshot(context.Background(), "snap-1", + func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + if existing == nil { + t.Fatal("expected non-nil existing on update") + } + return &exp.SessionSnapshot[testState]{ + Status: exp.SnapshotStatusSucceeded, + State: &exp.SessionState[testState]{Custom: testState{Counter: 2}}, + }, nil + }) + if err != nil { + t.Fatalf("update: %v", err) + } + if !updated.CreatedAt.Equal(saved.CreatedAt) { + t.Errorf("CreatedAt not preserved: before=%v after=%v", saved.CreatedAt, updated.CreatedAt) + } + if !updated.UpdatedAt.After(saved.UpdatedAt) { + t.Errorf("UpdatedAt did not advance: before=%v after=%v", saved.UpdatedAt, updated.UpdatedAt) + } + }) + + t.Run("AbortPendingFlipsToAborted", func(t *testing.T) { + store := newFileStore(t) + if _, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusPending}, nil + }); err != nil { + t.Fatalf("seed: %v", err) + } + status, err := store.AbortSnapshot(context.Background(), "snap-1") + if err != nil { + t.Fatalf("AbortSnapshot: %v", err) + } + if status != exp.SnapshotStatusAborted { + t.Errorf("status = %q, want %q", status, exp.SnapshotStatusAborted) + } + snap, _ := store.GetSnapshot(context.Background(), "snap-1") + if snap.Status != exp.SnapshotStatusAborted { + t.Errorf("persisted status = %q, want %q", snap.Status, exp.SnapshotStatusAborted) + } + }) + + t.Run("AbortTerminalIsNoop", func(t *testing.T) { + store := newFileStore(t) + if _, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + }); err != nil { + t.Fatalf("seed: %v", err) + } + status, err := store.AbortSnapshot(context.Background(), "snap-1") + if err != nil { + t.Fatalf("AbortSnapshot: %v", err) + } + if status != exp.SnapshotStatusSucceeded { + t.Errorf("status = %q, want %q (no-op on terminal)", status, exp.SnapshotStatusSucceeded) + } + }) + + t.Run("AbortMissingReturnsEmpty", func(t *testing.T) { + store := newFileStore(t) + status, err := store.AbortSnapshot(context.Background(), "nonexistent") + if err != nil { + t.Fatalf("AbortSnapshot: %v", err) + } + if status != "" { + t.Errorf("status = %q, want empty (not found)", status) + } + }) + + t.Run("StatusSubscriptionYieldsCurrentAndChanges", func(t *testing.T) { + store := newFileStore(t) + if _, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusPending}, nil + }); err != nil { + t.Fatalf("seed: %v", err) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ch := store.OnSnapshotStatusChange(ctx, "snap-1") + + select { + case s := <-ch: + if s != exp.SnapshotStatusPending { + t.Errorf("initial status = %q, want %q", s, exp.SnapshotStatusPending) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for initial status") + } + + if _, err := store.AbortSnapshot(context.Background(), "snap-1"); err != nil { + t.Fatalf("AbortSnapshot: %v", err) + } + select { + case s := <-ch: + if s != exp.SnapshotStatusAborted { + t.Errorf("post-abort status = %q, want %q", s, exp.SnapshotStatusAborted) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for aborted status") + } + }) + + t.Run("StatusSubscriptionOnMissingIsClosed", func(t *testing.T) { + store := newFileStore(t) + ch := store.OnSnapshotStatusChange(context.Background(), "nonexistent") + select { + case _, ok := <-ch: + if ok { + t.Error("expected closed channel for missing snapshot") + } + case <-time.After(time.Second): + t.Fatal("timeout waiting on closed channel") + } + }) + + t.Run("StatusSubscriptionClosesOnCtxCancel", func(t *testing.T) { + store := newFileStore(t) + if _, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusPending}, nil + }); err != nil { + t.Fatalf("seed: %v", err) + } + ctx, cancel := context.WithCancel(context.Background()) + ch := store.OnSnapshotStatusChange(ctx, "snap-1") + <-ch // drain initial + cancel() + select { + case _, ok := <-ch: + if ok { + select { + case _, ok2 := <-ch: + if ok2 { + t.Error("expected channel closed after ctx cancel") + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for channel close") + } + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for channel close") + } + }) + + t.Run("PersistsAcrossStoreInstances", func(t *testing.T) { + dir := t.TempDir() + store1, err := NewFileSessionStore[testState](dir) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + if _, err := store1.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{ + Status: exp.SnapshotStatusSucceeded, + State: &exp.SessionState[testState]{Custom: testState{Counter: 42}}, + }, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + + store2, err := NewFileSessionStore[testState](dir) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + got, err := store2.GetSnapshot(context.Background(), "snap-1") + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if got == nil { + t.Fatal("expected snapshot to persist across store instances") + } + if got.State.Custom.Counter != 42 { + t.Errorf("counter = %d, want 42", got.State.Custom.Counter) + } + }) + + t.Run("InvalidIDRejected", func(t *testing.T) { + store := newFileStore(t) + cases := []string{ + "../escape", + "a/b", + `a\b`, + ".hidden", + "foo..bar", + } + for _, id := range cases { + t.Run(id, func(t *testing.T) { + if _, err := store.GetSnapshot(context.Background(), id); err == nil { + t.Errorf("GetSnapshot(%q): expected error, got nil", id) + } + if _, err := store.SaveSnapshot(context.Background(), id, + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{}, nil + }); err == nil { + t.Errorf("SaveSnapshot(%q): expected error, got nil", id) + } + if _, err := store.AbortSnapshot(context.Background(), id); err == nil { + t.Errorf("AbortSnapshot(%q): expected error, got nil", id) + } + }) + } + }) + + t.Run("FileWrittenOnDisk", func(t *testing.T) { + dir := t.TempDir() + store, err := NewFileSessionStore[testState](dir) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + if _, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + if _, err := os.Stat(filepath.Join(dir, "snap-1.json")); err != nil { + t.Errorf("expected snap-1.json on disk: %v", err) + } + }) + + t.Run("ImplementsSessionStoreAndAborter", func(t *testing.T) { + var _ exp.SessionStore[testState] = (*FileSessionStore[testState])(nil) + var _ exp.SnapshotAborter = (*FileSessionStore[testState])(nil) + }) +} diff --git a/go/ai/exp/localstore/inmemory.go b/go/ai/exp/localstore/inmemory.go new file mode 100644 index 0000000000..a50f47d13f --- /dev/null +++ b/go/ai/exp/localstore/inmemory.go @@ -0,0 +1,213 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package localstore provides single-process [exp.SessionStore] implementations +// suitable for local development, tests, and single-instance apps (CLI tools, +// desktop apps, local web services). For multi-instance production deployments +// use a real database-backed store. +package localstore + +import ( + "context" + "encoding/json" + "fmt" + "slices" + "sync" + "time" + + "github.com/firebase/genkit/go/ai/exp" + "github.com/google/uuid" +) + +// InMemorySessionStore provides a thread-safe in-memory snapshot store. State +// is lost when the process exits; use [FileSessionStore] or a real backend +// when persistence is needed. +// +// It implements [exp.SessionStore] and [exp.SnapshotAborter]. +type InMemorySessionStore[State any] struct { + // mu is RWMutex so GetSnapshot (which JSON-marshals while holding the + // lock) can run concurrently with other readers. All writers (Save, + // Abort, OnSnapshotStatusChange, removeSub) take the full Lock(). + mu sync.RWMutex + snapshots map[string]*exp.SessionSnapshot[State] + subs map[string][]chan exp.SnapshotStatus +} + +// NewInMemorySessionStore creates a new in-memory snapshot store. +func NewInMemorySessionStore[State any]() *InMemorySessionStore[State] { + return &InMemorySessionStore[State]{ + snapshots: make(map[string]*exp.SessionSnapshot[State]), + subs: make(map[string][]chan exp.SnapshotStatus), + } +} + +// GetSnapshot retrieves a snapshot by ID. Returns nil if not found. +func (s *InMemorySessionStore[State]) GetSnapshot(_ context.Context, snapshotID string) (*exp.SessionSnapshot[State], error) { + s.mu.RLock() + defer s.mu.RUnlock() + snap, ok := s.snapshots[snapshotID] + if !ok { + return nil, nil + } + return copySnapshot(snap) +} + +// AbortSnapshot atomically flips a pending snapshot to aborted. If the +// snapshot is already terminal the existing status is returned unchanged. +// Returns an empty status if the snapshot is not found. +func (s *InMemorySessionStore[State]) AbortSnapshot(_ context.Context, snapshotID string) (exp.SnapshotStatus, error) { + s.mu.Lock() + defer s.mu.Unlock() + snap, ok := s.snapshots[snapshotID] + if !ok { + return "", nil + } + if snap.Status == exp.SnapshotStatusPending { + snap.Status = exp.SnapshotStatusAborted + snap.UpdatedAt = time.Now() + s.notifyLocked(snapshotID, snap.Status) + } + return snap.Status, nil +} + +// SaveSnapshot atomically reads, applies fn, and persists. See the +// [exp.SnapshotWriter] interface for the full contract; this implementation +// satisfies it by holding s.mu for the entire read-modify-write so fn is +// called exactly once per SaveSnapshot call. +func (s *InMemorySessionStore[State]) SaveSnapshot( + _ context.Context, + id string, + fn func(existing *exp.SessionSnapshot[State]) (*exp.SessionSnapshot[State], error), +) (*exp.SessionSnapshot[State], error) { + s.mu.Lock() + defer s.mu.Unlock() + + if id == "" { + id = uuid.New().String() + } + + var existing *exp.SessionSnapshot[State] + if stored, ok := s.snapshots[id]; ok { + copied, err := copySnapshot(stored) + if err != nil { + return nil, err + } + existing = copied + } + + next, err := fn(existing) + if err != nil { + return nil, err + } + if next == nil { + return nil, nil + } + + next.SnapshotID = id + now := time.Now() + if existing != nil { + next.CreatedAt = existing.CreatedAt + } else { + next.CreatedAt = now + } + next.UpdatedAt = now + if next.Status == "" { + next.Status = exp.SnapshotStatusSucceeded + } + + copied, err := copySnapshot(next) + if err != nil { + return nil, err + } + s.snapshots[id] = copied + if existing == nil || existing.Status != next.Status { + s.notifyLocked(id, next.Status) + } + // Return next (the freshly-allocated struct from fn) rather than + // copied: copied is the pointer the store retains, so returning it + // would alias the caller's view with the stored row and let future + // in-place mutations (e.g. AbortSnapshot updating UpdatedAt) leak + // through. + return next, nil +} + +// OnSnapshotStatusChange subscribes to status changes for a snapshot. The +// returned channel yields the current status (if any) and any subsequent +// changes, until ctx is cancelled. +func (s *InMemorySessionStore[State]) OnSnapshotStatusChange(ctx context.Context, snapshotID string) <-chan exp.SnapshotStatus { + ch := make(chan exp.SnapshotStatus, 1) + + s.mu.Lock() + snap, ok := s.snapshots[snapshotID] + if !ok { + s.mu.Unlock() + close(ch) + return ch + } + ch <- snap.Status + s.subs[snapshotID] = append(s.subs[snapshotID], ch) + s.mu.Unlock() + + context.AfterFunc(ctx, func() { s.removeSub(snapshotID, ch) }) + return ch +} + +// removeSub detaches a subscriber and closes its channel. +func (s *InMemorySessionStore[State]) removeSub(snapshotID string, ch chan exp.SnapshotStatus) { + s.mu.Lock() + defer s.mu.Unlock() + subs := s.subs[snapshotID] + i := slices.Index(subs, ch) + if i < 0 { + return + } + subs = slices.Delete(subs, i, i+1) + if len(subs) == 0 { + delete(s.subs, snapshotID) + } else { + s.subs[snapshotID] = subs + } + close(ch) +} + +// notifyLocked publishes status to all live subscribers of snapshotID. +// Caller must hold s.mu. Sends are best-effort: a slow subscriber may miss +// intermediate values, but the store guarantees the latest value visible +// to the subscription is the one persisted at notify time. +func (s *InMemorySessionStore[State]) notifyLocked(snapshotID string, status exp.SnapshotStatus) { + for _, ch := range s.subs[snapshotID] { + select { + case ch <- status: + default: + } + } +} + +// copySnapshot creates a deep copy of a snapshot using JSON marshaling. +func copySnapshot[State any](snap *exp.SessionSnapshot[State]) (*exp.SessionSnapshot[State], error) { + if snap == nil { + return nil, nil + } + bytes, err := json.Marshal(snap) + if err != nil { + return nil, fmt.Errorf("copy snapshot: marshal: %w", err) + } + var copied exp.SessionSnapshot[State] + if err := json.Unmarshal(bytes, &copied); err != nil { + return nil, fmt.Errorf("copy snapshot: unmarshal: %w", err) + } + return &copied, nil +} diff --git a/go/ai/exp/localstore/inmemory_test.go b/go/ai/exp/localstore/inmemory_test.go new file mode 100644 index 0000000000..0f8c5c8fc3 --- /dev/null +++ b/go/ai/exp/localstore/inmemory_test.go @@ -0,0 +1,159 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package localstore + +import ( + "context" + "testing" + "time" + + "github.com/firebase/genkit/go/ai/exp" +) + +func TestInMemorySessionStore(t *testing.T) { + t.Run("GetMissing", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + snap, err := store.GetSnapshot(context.Background(), "nonexistent") + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap != nil { + t.Errorf("expected nil, got %v", snap) + } + }) + + t.Run("SaveWithFixedID", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + saved, err := store.SaveSnapshot(context.Background(), "snap-1", + func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + if existing != nil { + t.Errorf("expected nil existing on first save, got %+v", existing) + } + return &exp.SessionSnapshot[testState]{ + Status: exp.SnapshotStatusSucceeded, + State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, + }, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot failed: %v", err) + } + if saved.SnapshotID != "snap-1" { + t.Errorf("saved SnapshotID = %q, want %q", saved.SnapshotID, "snap-1") + } + if saved.CreatedAt.IsZero() || saved.UpdatedAt.IsZero() { + t.Errorf("expected CreatedAt/UpdatedAt stamped, got created=%v updated=%v", + saved.CreatedAt, saved.UpdatedAt) + } + }) + + t.Run("GetReturnsCopy", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + if _, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{ + Status: exp.SnapshotStatusSucceeded, + State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, + }, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + retrieved, _ := store.GetSnapshot(context.Background(), "snap-1") + retrieved.State.Custom.Counter = 999 + retrieved2, _ := store.GetSnapshot(context.Background(), "snap-1") + if retrieved2.State.Custom.Counter != 1 { + t.Errorf("expected counter=1 (isolation), got %d", retrieved2.State.Custom.Counter) + } + }) + + t.Run("DefaultsEmptyStatusToComplete", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + saved, err := store.SaveSnapshot(context.Background(), "", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{}, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + if saved.SnapshotID == "" { + t.Error("expected store to generate SnapshotID") + } + if saved.Status != exp.SnapshotStatusSucceeded { + t.Errorf("expected Status=complete by default, got %q", saved.Status) + } + }) + + t.Run("NoopFnSkipsWrite", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + if _, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + }); err != nil { + t.Fatalf("seed: %v", err) + } + before, _ := store.GetSnapshot(context.Background(), "snap-1") + noop, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return nil, nil + }) + if err != nil { + t.Fatalf("noop SaveSnapshot: %v", err) + } + if noop != nil { + t.Errorf("expected nil return on noop, got %+v", noop) + } + after, _ := store.GetSnapshot(context.Background(), "snap-1") + if before.UpdatedAt != after.UpdatedAt { + t.Errorf("noop should not bump UpdatedAt: before=%v after=%v", before.UpdatedAt, after.UpdatedAt) + } + }) + + t.Run("PreservesCreatedAtOnUpdate", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + saved, err := store.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + }) + if err != nil { + t.Fatalf("seed: %v", err) + } + time.Sleep(time.Millisecond) // ensure measurable UpdatedAt delta + updated, err := store.SaveSnapshot(context.Background(), "snap-1", + func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + if existing == nil { + t.Fatal("expected non-nil existing on update") + } + return &exp.SessionSnapshot[testState]{ + Status: exp.SnapshotStatusSucceeded, + State: &exp.SessionState[testState]{Custom: testState{Counter: 2}}, + }, nil + }) + if err != nil { + t.Fatalf("update: %v", err) + } + if !updated.CreatedAt.Equal(saved.CreatedAt) { + t.Errorf("CreatedAt not preserved: before=%v after=%v", saved.CreatedAt, updated.CreatedAt) + } + if !updated.UpdatedAt.After(saved.UpdatedAt) { + t.Errorf("UpdatedAt did not advance: before=%v after=%v", saved.UpdatedAt, updated.UpdatedAt) + } + }) + + t.Run("ImplementsSessionStoreAndAborter", func(t *testing.T) { + var _ exp.SessionStore[testState] = (*InMemorySessionStore[testState])(nil) + var _ exp.SnapshotAborter = (*InMemorySessionStore[testState])(nil) + }) +} diff --git a/go/ai/exp/localstore/store_test.go b/go/ai/exp/localstore/store_test.go new file mode 100644 index 0000000000..c009e3be8f --- /dev/null +++ b/go/ai/exp/localstore/store_test.go @@ -0,0 +1,23 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package localstore + +// testState is the custom-state type used by store unit tests. +type testState struct { + Counter int `json:"counter"` + Topics []string `json:"topics,omitempty"` +} diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 2870ee4a59..48ebdcf280 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -20,15 +20,12 @@ import ( "context" "encoding/json" "fmt" - "slices" "sync" - "time" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/internal/base" - "github.com/google/uuid" ) // --- Snapshot --- @@ -156,184 +153,6 @@ type SessionStore[State any] interface { SnapshotWriter[State] } -// InMemorySessionStore provides a thread-safe in-memory snapshot store. It -// implements the full set of optional store interfaces (reader, writer, -// aborter, status subscriber). -type InMemorySessionStore[State any] struct { - // mu is RWMutex so GetSnapshot (which JSON-marshals while holding the - // lock) can run concurrently with other readers. All writers (Save, - // Abort, OnSnapshotStatusChange, removeSub) take the full Lock(). - mu sync.RWMutex - snapshots map[string]*SessionSnapshot[State] - subs map[string][]chan SnapshotStatus -} - -// NewInMemorySessionStore creates a new in-memory snapshot store. -func NewInMemorySessionStore[State any]() *InMemorySessionStore[State] { - return &InMemorySessionStore[State]{ - snapshots: make(map[string]*SessionSnapshot[State]), - subs: make(map[string][]chan SnapshotStatus), - } -} - -// GetSnapshot retrieves a snapshot by ID. Returns nil if not found. -func (s *InMemorySessionStore[State]) GetSnapshot(_ context.Context, snapshotID string) (*SessionSnapshot[State], error) { - s.mu.RLock() - defer s.mu.RUnlock() - snap, ok := s.snapshots[snapshotID] - if !ok { - return nil, nil - } - return copySnapshot(snap) -} - -// AbortSnapshot atomically flips a pending snapshot to aborted. If the -// snapshot is already terminal the existing status is returned unchanged. -// Returns an empty status if the snapshot is not found. -func (s *InMemorySessionStore[State]) AbortSnapshot(_ context.Context, snapshotID string) (SnapshotStatus, error) { - s.mu.Lock() - defer s.mu.Unlock() - snap, ok := s.snapshots[snapshotID] - if !ok { - return "", nil - } - if snap.Status == SnapshotStatusPending { - snap.Status = SnapshotStatusAborted - snap.UpdatedAt = time.Now() - s.notifyLocked(snapshotID, snap.Status) - } - return snap.Status, nil -} - -// SaveSnapshot atomically reads, applies fn, and persists. See the -// [SnapshotWriter] interface for the full contract; this implementation -// satisfies it by holding s.mu for the entire read-modify-write so fn -// is called exactly once per SaveSnapshot call. -func (s *InMemorySessionStore[State]) SaveSnapshot( - _ context.Context, - id string, - fn func(existing *SessionSnapshot[State]) (*SessionSnapshot[State], error), -) (*SessionSnapshot[State], error) { - s.mu.Lock() - defer s.mu.Unlock() - - if id == "" { - id = uuid.New().String() - } - - var existing *SessionSnapshot[State] - if stored, ok := s.snapshots[id]; ok { - copied, err := copySnapshot(stored) - if err != nil { - return nil, err - } - existing = copied - } - - next, err := fn(existing) - if err != nil { - return nil, err - } - if next == nil { - return nil, nil - } - - next.SnapshotID = id - now := time.Now() - if existing != nil { - next.CreatedAt = existing.CreatedAt - } else { - next.CreatedAt = now - } - next.UpdatedAt = now - if next.Status == "" { - next.Status = SnapshotStatusSucceeded - } - - copied, err := copySnapshot(next) - if err != nil { - return nil, err - } - s.snapshots[id] = copied - if existing == nil || existing.Status != next.Status { - s.notifyLocked(id, next.Status) - } - // Return next (the freshly-allocated struct from fn) rather than - // copied: copied is the pointer the store retains, so returning it - // would alias the caller's view with the stored row and let future - // in-place mutations (e.g. AbortSnapshot updating UpdatedAt) leak - // through. - return next, nil -} - -// OnSnapshotStatusChange subscribes to status changes for a snapshot. The -// returned channel yields the current status (if any) and any subsequent -// changes, until ctx is cancelled. -func (s *InMemorySessionStore[State]) OnSnapshotStatusChange(ctx context.Context, snapshotID string) <-chan SnapshotStatus { - ch := make(chan SnapshotStatus, 1) - - s.mu.Lock() - snap, ok := s.snapshots[snapshotID] - if !ok { - s.mu.Unlock() - close(ch) - return ch - } - ch <- snap.Status - s.subs[snapshotID] = append(s.subs[snapshotID], ch) - s.mu.Unlock() - - context.AfterFunc(ctx, func() { s.removeSub(snapshotID, ch) }) - return ch -} - -// removeSub detaches a subscriber and closes its channel. -func (s *InMemorySessionStore[State]) removeSub(snapshotID string, ch chan SnapshotStatus) { - s.mu.Lock() - defer s.mu.Unlock() - subs := s.subs[snapshotID] - i := slices.Index(subs, ch) - if i < 0 { - return - } - subs = slices.Delete(subs, i, i+1) - if len(subs) == 0 { - delete(s.subs, snapshotID) - } else { - s.subs[snapshotID] = subs - } - close(ch) -} - -// notifyLocked publishes status to all live subscribers of snapshotID. -// Caller must hold s.mu. Sends are best-effort: a slow subscriber may miss -// intermediate values, but the store guarantees the latest value visible -// to the subscription is the one persisted at notify time. -func (s *InMemorySessionStore[State]) notifyLocked(snapshotID string, status SnapshotStatus) { - for _, ch := range s.subs[snapshotID] { - select { - case ch <- status: - default: - } - } -} - -// copySnapshot creates a deep copy of a snapshot using JSON marshaling. -func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[State], error) { - if snap == nil { - return nil, nil - } - bytes, err := json.Marshal(snap) - if err != nil { - return nil, fmt.Errorf("copy snapshot: marshal: %w", err) - } - var copied SessionSnapshot[State] - if err := json.Unmarshal(bytes, &copied); err != nil { - return nil, fmt.Errorf("copy snapshot: unmarshal: %w", err) - } - return &copied, nil -} - // jsonClone deep-copies v via JSON marshal/unmarshal. Returns nil if v // is nil. Panics on marshal/unmarshal failure: callers use this for // types we control (messages, artifacts) where serialization failure diff --git a/go/ai/exp/teststore_test.go b/go/ai/exp/teststore_test.go new file mode 100644 index 0000000000..9778b970eb --- /dev/null +++ b/go/ai/exp/teststore_test.go @@ -0,0 +1,187 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +// This file is a private session store fixture used only by the agent's +// internal tests (which need access to unexported package symbols and so +// must remain in [package exp]). The production in-memory and file stores +// live in [github.com/firebase/genkit/go/ai/exp/localstore]; importing that +// package here would create an import cycle, since localstore depends on +// exp. + +import ( + "context" + "encoding/json" + "fmt" + "slices" + "sync" + "time" + + "github.com/google/uuid" +) + +// testInMemStore is a thread-safe in-memory snapshot store. Its semantics +// mirror localstore.InMemorySessionStore so the agent's internal tests +// exercise the same store behavior that production users see. +type testInMemStore[State any] struct { + mu sync.RWMutex + snapshots map[string]*SessionSnapshot[State] + subs map[string][]chan SnapshotStatus +} + +func newTestInMemStore[State any]() *testInMemStore[State] { + return &testInMemStore[State]{ + snapshots: make(map[string]*SessionSnapshot[State]), + subs: make(map[string][]chan SnapshotStatus), + } +} + +func (s *testInMemStore[State]) GetSnapshot(_ context.Context, snapshotID string) (*SessionSnapshot[State], error) { + s.mu.RLock() + defer s.mu.RUnlock() + snap, ok := s.snapshots[snapshotID] + if !ok { + return nil, nil + } + return testCopySnapshot(snap) +} + +func (s *testInMemStore[State]) AbortSnapshot(_ context.Context, snapshotID string) (SnapshotStatus, error) { + s.mu.Lock() + defer s.mu.Unlock() + snap, ok := s.snapshots[snapshotID] + if !ok { + return "", nil + } + if snap.Status == SnapshotStatusPending { + snap.Status = SnapshotStatusAborted + snap.UpdatedAt = time.Now() + s.notifyLocked(snapshotID, snap.Status) + } + return snap.Status, nil +} + +func (s *testInMemStore[State]) SaveSnapshot( + _ context.Context, + id string, + fn func(existing *SessionSnapshot[State]) (*SessionSnapshot[State], error), +) (*SessionSnapshot[State], error) { + s.mu.Lock() + defer s.mu.Unlock() + + if id == "" { + id = uuid.New().String() + } + + var existing *SessionSnapshot[State] + if stored, ok := s.snapshots[id]; ok { + copied, err := testCopySnapshot(stored) + if err != nil { + return nil, err + } + existing = copied + } + + next, err := fn(existing) + if err != nil { + return nil, err + } + if next == nil { + return nil, nil + } + + next.SnapshotID = id + now := time.Now() + if existing != nil { + next.CreatedAt = existing.CreatedAt + } else { + next.CreatedAt = now + } + next.UpdatedAt = now + if next.Status == "" { + next.Status = SnapshotStatusSucceeded + } + + copied, err := testCopySnapshot(next) + if err != nil { + return nil, err + } + s.snapshots[id] = copied + if existing == nil || existing.Status != next.Status { + s.notifyLocked(id, next.Status) + } + return next, nil +} + +func (s *testInMemStore[State]) OnSnapshotStatusChange(ctx context.Context, snapshotID string) <-chan SnapshotStatus { + ch := make(chan SnapshotStatus, 1) + + s.mu.Lock() + snap, ok := s.snapshots[snapshotID] + if !ok { + s.mu.Unlock() + close(ch) + return ch + } + ch <- snap.Status + s.subs[snapshotID] = append(s.subs[snapshotID], ch) + s.mu.Unlock() + + context.AfterFunc(ctx, func() { s.removeSub(snapshotID, ch) }) + return ch +} + +func (s *testInMemStore[State]) removeSub(snapshotID string, ch chan SnapshotStatus) { + s.mu.Lock() + defer s.mu.Unlock() + subs := s.subs[snapshotID] + i := slices.Index(subs, ch) + if i < 0 { + return + } + subs = slices.Delete(subs, i, i+1) + if len(subs) == 0 { + delete(s.subs, snapshotID) + } else { + s.subs[snapshotID] = subs + } + close(ch) +} + +func (s *testInMemStore[State]) notifyLocked(snapshotID string, status SnapshotStatus) { + for _, ch := range s.subs[snapshotID] { + select { + case ch <- status: + default: + } + } +} + +func testCopySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[State], error) { + if snap == nil { + return nil, nil + } + bytes, err := json.Marshal(snap) + if err != nil { + return nil, fmt.Errorf("copy snapshot: marshal: %w", err) + } + var copied SessionSnapshot[State] + if err := json.Unmarshal(bytes, &copied); err != nil { + return nil, fmt.Errorf("copy snapshot: unmarshal: %w", err) + } + return &copied, nil +} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index e4c30981b9..a7656e518f 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -470,7 +470,7 @@ func DefineBidiFlow[In, Out, StreamOut, StreamIn any](g *Genkit, name string, fn // ai.WithModelName("googleai/gemini-3-flash-preview"), // ai.WithSystem("You are a helpful assistant."), // ), -// aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), +// aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), // ) // // Example (existing prompt): @@ -481,7 +481,7 @@ func DefineBidiFlow[In, Out, StreamOut, StreamIn any](g *Genkit, name string, fn // // chatAgent := genkit.DefineAgent(g, "chat", // aix.FromPrompt(ChatInput{Personality: "a sarcastic pirate"}), -// aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), +// aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), // ) func DefineAgent[State any]( g *Genkit, diff --git a/go/samples/agent-custom/main.go b/go/samples/agent-custom/main.go index cc5488023b..a5b5ef9d0a 100644 --- a/go/samples/agent-custom/main.go +++ b/go/samples/agent-custom/main.go @@ -26,17 +26,21 @@ import ( "context" "fmt" "os" + "os/signal" "strings" + "syscall" "github.com/firebase/genkit/go/ai" aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/ai/exp/localstore" "github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/plugins/googlegenai" "google.golang.org/genai" ) func main() { - ctx := context.Background() + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) chatAgent := genkit.DefineCustomAgent(g, "chat", @@ -67,11 +71,11 @@ func main() { } return sess.Result(), nil }, - aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), + aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), ) - fmt.Println("Agent Chat (type 'quit' to exit)") + fmt.Println("Agent Chat (type 'quit' to exit, Ctrl+C to abort)") fmt.Println() conn, err := chatAgent.StreamBidi(ctx) @@ -79,12 +83,25 @@ func main() { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } + defer conn.Close() - reader := bufio.NewReader(os.Stdin) + inputCh := readLines(ctx) + +repl: for { fmt.Print("> ") - input, _ := reader.ReadString('\n') - input = strings.TrimSpace(input) + var input string + select { + case <-ctx.Done(): + fmt.Println() + break repl + case line, ok := <-inputCh: + if !ok { + fmt.Println() + break repl + } + input = strings.TrimSpace(line) + } if input == "quit" || input == "exit" { break @@ -119,6 +136,32 @@ func main() { } } - conn.Close() fmt.Println(conn.Output()) } + +// readLines reads lines from stdin on a background goroutine and yields +// them via the returned channel. The channel is closed on EOF, read error, +// or ctx cancellation. The goroutine cannot interrupt a blocked stdin +// read; on ctx cancellation it exits as soon as a line completes (or, in +// practice, when the process terminates). +func readLines(ctx context.Context) <-chan string { + ch := make(chan string) + go func() { + defer close(ch) + reader := bufio.NewReader(os.Stdin) + for { + line, err := reader.ReadString('\n') + if line != "" { + select { + case ch <- line: + case <-ctx.Done(): + return + } + } + if err != nil { + return + } + } + }() + return ch +} diff --git a/go/samples/agent-inline/main.go b/go/samples/agent-inline/main.go index 12ccca30a8..12000a2d74 100644 --- a/go/samples/agent-inline/main.go +++ b/go/samples/agent-inline/main.go @@ -26,19 +26,29 @@ import ( "context" "fmt" "os" + "os/signal" "strings" + "syscall" "github.com/firebase/genkit/go/ai" aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/ai/exp/localstore" "github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/plugins/googlegenai" "google.golang.org/genai" ) func main() { - ctx := context.Background() + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + store, err := localstore.NewFileSessionStore[any]("./sessions") + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + chatAgent := genkit.DefineAgent(g, "chat", aix.FromInline( ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ @@ -48,11 +58,11 @@ func main() { })), ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), ), - aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), + aix.WithSessionStore(store), aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), ) - fmt.Println("Agent Chat (type 'quit' to exit)") + fmt.Println("Agent Chat (type 'quit' to exit, Ctrl+C to abort)") fmt.Println() conn, err := chatAgent.StreamBidi(ctx) @@ -60,12 +70,25 @@ func main() { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } + defer conn.Close() + + inputCh := readLines(ctx) - reader := bufio.NewReader(os.Stdin) +repl: for { fmt.Print("> ") - input, _ := reader.ReadString('\n') - input = strings.TrimSpace(input) + var input string + select { + case <-ctx.Done(): + fmt.Println() + break repl + case line, ok := <-inputCh: + if !ok { + fmt.Println() + break repl + } + input = strings.TrimSpace(line) + } if input == "quit" || input == "exit" { break @@ -99,6 +122,31 @@ func main() { } } } +} - conn.Close() +// readLines reads lines from stdin on a background goroutine and yields +// them via the returned channel. The channel is closed on EOF, read error, +// or ctx cancellation. The goroutine cannot interrupt a blocked stdin +// read; on ctx cancellation it exits as soon as a line completes (or, in +// practice, when the process terminates). +func readLines(ctx context.Context) <-chan string { + ch := make(chan string) + go func() { + defer close(ch) + reader := bufio.NewReader(os.Stdin) + for { + line, err := reader.ReadString('\n') + if line != "" { + select { + case ch <- line: + case <-ctx.Done(): + return + } + } + if err != nil { + return + } + } + }() + return ch } diff --git a/go/samples/agent-prompt/main.go b/go/samples/agent-prompt/main.go index e271eb17f0..af29470c9a 100644 --- a/go/samples/agent-prompt/main.go +++ b/go/samples/agent-prompt/main.go @@ -26,9 +26,12 @@ import ( "context" "fmt" "os" + "os/signal" "strings" + "syscall" aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/ai/exp/localstore" "github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/plugins/googlegenai" ) @@ -38,20 +41,21 @@ type ChatPromptInput struct { } func main() { - ctx := context.Background() + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) genkit.DefineSchemaFor[ChatPromptInput](g) chatAgent := genkit.DefineAgent(g, "chat", aix.FromPrompt(ChatPromptInput{Personality: "a sarcastic pirate"}), - aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), + aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[any]) bool { return sc.Event == aix.SnapshotEventInvocationEnd || sc.TurnIndex%5 == 0 }), ) - fmt.Println("Agent Chat (type 'quit' to exit)") + fmt.Println("Agent Chat (type 'quit' to exit, Ctrl+C to abort)") fmt.Println() conn, err := chatAgent.StreamBidi(ctx) @@ -59,12 +63,25 @@ func main() { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } + defer conn.Close() - reader := bufio.NewReader(os.Stdin) + inputCh := readLines(ctx) + +repl: for { fmt.Print("> ") - input, _ := reader.ReadString('\n') - input = strings.TrimSpace(input) + var input string + select { + case <-ctx.Done(): + fmt.Println() + break repl + case line, ok := <-inputCh: + if !ok { + fmt.Println() + break repl + } + input = strings.TrimSpace(line) + } if input == "quit" || input == "exit" { break @@ -98,6 +115,31 @@ func main() { } } } +} - conn.Close() +// readLines reads lines from stdin on a background goroutine and yields +// them via the returned channel. The channel is closed on EOF, read error, +// or ctx cancellation. The goroutine cannot interrupt a blocked stdin +// read; on ctx cancellation it exits as soon as a line completes (or, in +// practice, when the process terminates). +func readLines(ctx context.Context) <-chan string { + ch := make(chan string) + go func() { + defer close(ch) + reader := bufio.NewReader(os.Stdin) + for { + line, err := reader.ReadString('\n') + if line != "" { + select { + case ch <- line: + case <-ctx.Done(): + return + } + } + if err != nil { + return + } + } + }() + return ch } From 66f4eb810cb37aa7f22060728c97308d2871613e Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 14 May 2026 09:15:36 -0700 Subject: [PATCH 082/141] refactor(go/exp): AgentConnection.Output() implicitly closes and drains Output() is the single "I'm done" call: it implicitly closes the input side and drains any chunks the caller did not consume via Receive, then blocks until the agent finalizes. The drain is required because the underlying stream buffer is shallow; without it a producing fn would block on chunk sends and never reach completion. Close remains public and idempotent for defer-style cleanup and fire-and-forget patterns. Simplifies Agent.Run accordingly: drop the explicit Close and Receive drain loop, lean on Output for both. Updates the three agent samples to the new pattern: drop defer Close, print the final snapshot ID and a friendly goodbye on any user-initiated exit, and treat context.Canceled as a clean exit rather than a fatal error. --- go/ai/exp/agent.go | 57 +++++++++++++++++++-------------- go/samples/agent-custom/main.go | 20 ++++++++++-- go/samples/agent-inline/main.go | 12 ++++++- go/samples/agent-prompt/main.go | 20 ++++++++++-- 4 files changed, 79 insertions(+), 30 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index e1c56e8684..5ef736befd 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -1322,30 +1322,16 @@ func (a *Agent[Stream, State]) Run( return nil, err } // If the bidi function fails fast (e.g. resuming from an errored - // snapshot rejects in newAgentRuntime), Send / Close / Receive - // see a closed connection and return generic "action has completed" - // errors. The real fn error is on Output(). Prefer it whenever it's - // non-nil so callers get the meaningful failure. + // snapshot rejects in newAgentRuntime), Send sees a closed connection + // and returns a generic "action has completed" error. The real fn + // error is on Output(). Prefer it whenever it's non-nil so callers + // get the meaningful failure. if err := conn.Send(input); err != nil { if _, outErr := conn.Output(); outErr != nil { return nil, outErr } return nil, err } - if err := conn.Close(); err != nil { - if _, outErr := conn.Output(); outErr != nil { - return nil, outErr - } - return nil, err - } - for _, err := range conn.Receive() { - if err != nil { - if _, outErr := conn.Output(); outErr != nil { - return nil, outErr - } - return nil, err - } - } return conn.Output() } @@ -1444,15 +1430,38 @@ func (c *AgentConnection[Stream, State]) Receive() iter.Seq2[*AgentStreamChunk[S return c.conn.Receive() } -// Output returns the final response after the agent completes. +// Output finalizes the connection and returns the agent's result. +// +// Output is the single "I'm done" call: it implicitly closes the input +// side and drains any chunks the caller did not consume via Receive, +// then blocks until the agent finalizes. The drain is required because +// the underlying stream buffer is shallow; without it, a producing fn +// would block on chunk sends and never reach completion. Calling Close +// before Output is allowed but redundant; both are idempotent. +// +// Output is itself idempotent: subsequent calls return the same +// (*AgentOutput, error) from cache. The returned pointer is shared +// across calls; treat it as read-only. // -// Unlike the underlying BidiConnection, Output waits for the agent to -// finalize before returning. This is important for detached invocations: -// when the client sends Detach, the agent function returns promptly with a -// pending snapshot ID, and callers need to observe that output rather than -// the context cancellation error. +// Detach: when the client sends Detach, the agent function returns +// promptly with a pending snapshot ID. Output returns that output +// rather than a context cancellation error. +// +// Do not call Output concurrently with a goroutine iterating Receive — +// both consume from the same stream and chunks would be split between +// them. Sequence the calls: finish Receive first, then call Output. func (c *AgentConnection[Stream, State]) Output() (*AgentOutput[State], error) { + _ = c.conn.Close() + + drainDone := make(chan struct{}) + go func() { + defer close(drainDone) + for range c.conn.Receive() { + } + }() + <-c.conn.Done() + <-drainDone return c.conn.Output() } diff --git a/go/samples/agent-custom/main.go b/go/samples/agent-custom/main.go index a5b5ef9d0a..e83dcd372a 100644 --- a/go/samples/agent-custom/main.go +++ b/go/samples/agent-custom/main.go @@ -24,6 +24,7 @@ package main import ( "bufio" "context" + "errors" "fmt" "os" "os/signal" @@ -43,6 +44,12 @@ func main() { defer stop() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + store, err := localstore.NewFileSessionStore[any]("./sessions") + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + chatAgent := genkit.DefineCustomAgent(g, "chat", func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) error { @@ -71,7 +78,7 @@ func main() { } return sess.Result(), nil }, - aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), + aix.WithSessionStore(store), aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), ) @@ -83,7 +90,6 @@ func main() { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } - defer conn.Close() inputCh := readLines(ctx) @@ -136,7 +142,15 @@ repl: } } - fmt.Println(conn.Output()) + out, err := conn.Output() + if err != nil && !errors.Is(err, context.Canceled) { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + if out != nil && out.SnapshotID != "" { + fmt.Printf("[snapshot: %s]\n", out.SnapshotID) + } + fmt.Println("You left the conversation. Goodbye!") } // readLines reads lines from stdin on a background goroutine and yields diff --git a/go/samples/agent-inline/main.go b/go/samples/agent-inline/main.go index 12000a2d74..df5af80f5d 100644 --- a/go/samples/agent-inline/main.go +++ b/go/samples/agent-inline/main.go @@ -24,6 +24,7 @@ package main import ( "bufio" "context" + "errors" "fmt" "os" "os/signal" @@ -70,7 +71,6 @@ func main() { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } - defer conn.Close() inputCh := readLines(ctx) @@ -122,6 +122,16 @@ repl: } } } + + out, err := conn.Output() + if err != nil && !errors.Is(err, context.Canceled) { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + if out != nil && out.SnapshotID != "" { + fmt.Printf("[snapshot: %s]\n", out.SnapshotID) + } + fmt.Println("You left the conversation. Goodbye!") } // readLines reads lines from stdin on a background goroutine and yields diff --git a/go/samples/agent-prompt/main.go b/go/samples/agent-prompt/main.go index af29470c9a..65722e981b 100644 --- a/go/samples/agent-prompt/main.go +++ b/go/samples/agent-prompt/main.go @@ -24,6 +24,7 @@ package main import ( "bufio" "context" + "errors" "fmt" "os" "os/signal" @@ -47,9 +48,15 @@ func main() { genkit.DefineSchemaFor[ChatPromptInput](g) + store, err := localstore.NewFileSessionStore[any]("./sessions") + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + chatAgent := genkit.DefineAgent(g, "chat", aix.FromPrompt(ChatPromptInput{Personality: "a sarcastic pirate"}), - aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), + aix.WithSessionStore(store), aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[any]) bool { return sc.Event == aix.SnapshotEventInvocationEnd || sc.TurnIndex%5 == 0 }), @@ -63,7 +70,6 @@ func main() { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } - defer conn.Close() inputCh := readLines(ctx) @@ -115,6 +121,16 @@ repl: } } } + + out, err := conn.Output() + if err != nil && !errors.Is(err, context.Canceled) { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + if out != nil && out.SnapshotID != "" { + fmt.Printf("[snapshot: %s]\n", out.SnapshotID) + } + fmt.Println("You left the conversation. Goodbye!") } // readLines reads lines from stdin on a background goroutine and yields From 53a9878fa6fc2493ad22b1aea666599ebd203da4 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 14 May 2026 12:31:13 -0700 Subject: [PATCH 083/141] refactor(go): rename ActionDesc schema fields and lowercase json tags Rename StreamOutSchema to StreamSchema and StreamInSchema to InitSchema on ActionDesc, and lowercase the json tags for the four schema fields (inputschema, outputschema, streamschema, initschema). --- go/core/action.go | 22 +++++++++++----------- go/core/api/action.go | 18 +++++++++--------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index c39f0cd7eb..576b78b02a 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -126,15 +126,15 @@ func NewBidiAction[In, Out, StreamOut, StreamIn any]( a.desc.OutputSchema = opts.OutputSchema } if opts.StreamOutSchema != nil { - a.desc.StreamOutSchema = opts.StreamOutSchema + a.desc.StreamSchema = opts.StreamOutSchema } if opts.StreamInSchema != nil { - a.desc.StreamInSchema = opts.StreamInSchema + a.desc.InitSchema = opts.StreamInSchema } else { var inStream StreamIn if reflect.ValueOf(inStream).Kind() != reflect.Invalid { - a.desc.StreamInSchema = InferSchemaMap(inStream) + a.desc.InitSchema = InferSchemaMap(inStream) } } @@ -236,14 +236,14 @@ func newAction[In, Out, StreamOut, StreamIn any]( return fn(ctx, input, cb) }, desc: &api.ActionDesc{ - Type: atype, - Key: api.KeyFromName(atype, name), - Name: name, - Description: description, - InputSchema: inputSchema, - OutputSchema: outputSchema, - StreamOutSchema: outStreamSchema, - Metadata: metadata, + Type: atype, + Key: api.KeyFromName(atype, name), + Name: name, + Description: description, + InputSchema: inputSchema, + OutputSchema: outputSchema, + StreamSchema: outStreamSchema, + Metadata: metadata, }, } } diff --git a/go/core/api/action.go b/go/core/api/action.go index d0df3ee613..91ee2a646d 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -68,13 +68,13 @@ const ( // ActionDesc is a descriptor of an action. type ActionDesc struct { - Type ActionType `json:"type"` // Type of the action. - Key string `json:"key"` // Key of the action. - Name string `json:"name"` // Name of the action. - Description string `json:"description"` // Description of the action. - InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. - OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. - StreamOutSchema map[string]any `json:"streamOutSchema,omitempty"` // JSON schema to validate against the action's outgoing streamed chunks. - StreamInSchema map[string]any `json:"streamInSchema,omitempty"` // JSON schema to validate against the action's incoming stream messages (bidi only). - Metadata map[string]any `json:"metadata"` // Metadata for the action. + Type ActionType `json:"type"` // Type of the action. + Key string `json:"key"` // Key of the action. + Name string `json:"name"` // Name of the action. + Description string `json:"description"` // Description of the action. + InputSchema map[string]any `json:"inputschema"` // JSON schema to validate against the action's input. + OutputSchema map[string]any `json:"outputschema"` // JSON schema to validate against the action's output. + StreamSchema map[string]any `json:"streamschema,omitempty"` // JSON schema to validate against the action's outgoing streamed chunks. + InitSchema map[string]any `json:"initschema,omitempty"` // JSON schema to validate against the action's incoming stream messages (bidi only). + Metadata map[string]any `json:"metadata"` // Metadata for the action. } From 678bae9e3d622a5b3c2a65f95c222a84f98ce3ad Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 14 May 2026 12:31:13 -0700 Subject: [PATCH 084/141] refactor(go): rename ActionDesc schema fields and lowercase json tags Rename StreamOutSchema to StreamSchema and StreamInSchema to InitSchema on ActionDesc, and lowercase the json tags for the four schema fields (inputschema, outputschema, streamschema, initschema). --- go/core/action.go | 22 +++++++++++----------- go/core/api/action.go | 18 +++++++++--------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index c39f0cd7eb..576b78b02a 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -126,15 +126,15 @@ func NewBidiAction[In, Out, StreamOut, StreamIn any]( a.desc.OutputSchema = opts.OutputSchema } if opts.StreamOutSchema != nil { - a.desc.StreamOutSchema = opts.StreamOutSchema + a.desc.StreamSchema = opts.StreamOutSchema } if opts.StreamInSchema != nil { - a.desc.StreamInSchema = opts.StreamInSchema + a.desc.InitSchema = opts.StreamInSchema } else { var inStream StreamIn if reflect.ValueOf(inStream).Kind() != reflect.Invalid { - a.desc.StreamInSchema = InferSchemaMap(inStream) + a.desc.InitSchema = InferSchemaMap(inStream) } } @@ -236,14 +236,14 @@ func newAction[In, Out, StreamOut, StreamIn any]( return fn(ctx, input, cb) }, desc: &api.ActionDesc{ - Type: atype, - Key: api.KeyFromName(atype, name), - Name: name, - Description: description, - InputSchema: inputSchema, - OutputSchema: outputSchema, - StreamOutSchema: outStreamSchema, - Metadata: metadata, + Type: atype, + Key: api.KeyFromName(atype, name), + Name: name, + Description: description, + InputSchema: inputSchema, + OutputSchema: outputSchema, + StreamSchema: outStreamSchema, + Metadata: metadata, }, } } diff --git a/go/core/api/action.go b/go/core/api/action.go index d0df3ee613..91ee2a646d 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -68,13 +68,13 @@ const ( // ActionDesc is a descriptor of an action. type ActionDesc struct { - Type ActionType `json:"type"` // Type of the action. - Key string `json:"key"` // Key of the action. - Name string `json:"name"` // Name of the action. - Description string `json:"description"` // Description of the action. - InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. - OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. - StreamOutSchema map[string]any `json:"streamOutSchema,omitempty"` // JSON schema to validate against the action's outgoing streamed chunks. - StreamInSchema map[string]any `json:"streamInSchema,omitempty"` // JSON schema to validate against the action's incoming stream messages (bidi only). - Metadata map[string]any `json:"metadata"` // Metadata for the action. + Type ActionType `json:"type"` // Type of the action. + Key string `json:"key"` // Key of the action. + Name string `json:"name"` // Name of the action. + Description string `json:"description"` // Description of the action. + InputSchema map[string]any `json:"inputschema"` // JSON schema to validate against the action's input. + OutputSchema map[string]any `json:"outputschema"` // JSON schema to validate against the action's output. + StreamSchema map[string]any `json:"streamschema,omitempty"` // JSON schema to validate against the action's outgoing streamed chunks. + InitSchema map[string]any `json:"initschema,omitempty"` // JSON schema to validate against the action's incoming stream messages (bidi only). + Metadata map[string]any `json:"metadata"` // Metadata for the action. } From a64602240e49121c31391372ed286dc4bfc7656a Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 14 May 2026 12:46:30 -0700 Subject: [PATCH 085/141] feat(go): unified basic-agents sample; add LatestSnapshot and Agent.Name() Consolidates agent-custom, agent-inline, and agent-prompt into a single basic-agents sample that exposes all three styles through one interactive CLI with agent picker, resume-from-last-snapshot, and /detach demo. Adds FileSessionStore.LatestSnapshot and InMemorySessionStore.LatestSnapshot so callers can surface "where did I leave off" without indexing the directory themselves, plus Agent.Name() for use in list/picker UIs. --- go/ai/exp/agent.go | 7 + go/ai/exp/agent_test.go | 11 + go/ai/exp/localstore/file.go | 64 +++ go/ai/exp/localstore/file_test.go | 79 +++ go/ai/exp/localstore/inmemory.go | 22 + go/ai/exp/localstore/inmemory_test.go | 53 ++ go/samples/agent-custom/main.go | 181 ------- go/samples/agent-inline/main.go | 162 ------ go/samples/agent-prompt/main.go | 161 ------ go/samples/basic-agents/cli.go | 505 ++++++++++++++++++ go/samples/basic-agents/main.go | 210 ++++++++ .../prompts/chef.prompt} | 2 +- 12 files changed, 952 insertions(+), 505 deletions(-) delete mode 100644 go/samples/agent-custom/main.go delete mode 100644 go/samples/agent-inline/main.go delete mode 100644 go/samples/agent-prompt/main.go create mode 100644 go/samples/basic-agents/cli.go create mode 100644 go/samples/basic-agents/main.go rename go/samples/{agent-prompt/prompts/chat.prompt => basic-agents/prompts/chef.prompt} (77%) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 5ef736befd..b76b8f4fc4 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -267,6 +267,13 @@ type Agent[Stream, State any] struct { action *core.Action[*AgentInit[State], *AgentOutput[State], *AgentStreamChunk[Stream], *AgentInput] } +// Name returns the agent's registered name. This is also the name under +// which any inline-defined prompt and companion actions (getSnapshot, +// abortSnapshot) are registered. +func (a *Agent[Stream, State]) Name() string { + return a.action.Name() +} + // DefineAgent defines a prompt-backed agent and registers it. Each turn // renders the agent's prompt, appends conversation history, calls the // model with streaming, and updates session state. diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 40891911ab..ea2e269339 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -3260,3 +3260,14 @@ func TestAgent_ResultAndOutput_IsolatedFromSession(t *testing.T) { t.Errorf("snapshot artifact tainted: got %q, want %q", snap.State.Artifacts[0].Name, "orig") } } + +func TestAgent_Name(t *testing.T) { + reg := newTestRegistry(t) + a := DefineCustomAgent(reg, "name-accessor", + func(ctx context.Context, _ Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return sess.Result(), nil + }) + if got := a.Name(); got != "name-accessor" { + t.Errorf("Name() = %q, want %q", got, "name-accessor") + } +} diff --git a/go/ai/exp/localstore/file.go b/go/ai/exp/localstore/file.go index 0c7a26a8dc..855dc1aec2 100644 --- a/go/ai/exp/localstore/file.go +++ b/go/ai/exp/localstore/file.go @@ -136,6 +136,70 @@ func (s *FileSessionStore[State]) SaveSnapshot( return next, nil } +// LatestSnapshot returns the snapshot whose backing file has the most +// recent on-disk modification time, or nil if the directory has no +// snapshots yet. +// +// Selecting by file mtime (rather than parsing every file to compare +// [SessionSnapshot.UpdatedAt]) makes the operation O(N) stats plus a +// single read of the winner in the common case, rather than O(N) +// reads + JSON parses. For snapshots written by this package mtime +// and UpdatedAt advance together (each save creates a fresh temp +// file and renames it into place), so the result is identical to a +// sort by UpdatedAt. If a snapshot file is touched externally, mtime +// wins. +// +// Files that fail to stat (e.g. removed between the directory read +// and the stat) or fail to parse are skipped silently and the scan +// falls back to the next-newest candidate, so a single corrupted +// file does not poison the result. The directory listing is not +// atomic with respect to concurrent writes — a snapshot that appears +// or disappears mid-scan may or may not be observed. +// +// This is not part of the [exp.SessionStore] interface; it is a +// FileSessionStore-specific convenience for UIs and CLIs that need to +// surface "where did I leave off" without indexing the directory +// themselves. +func (s *FileSessionStore[State]) LatestSnapshot(ctx context.Context) (*exp.SessionSnapshot[State], error) { + entries, err := os.ReadDir(s.dir) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return nil, fmt.Errorf("FileSessionStore: list dir: %w", err) + } + + type candidate struct { + name string + modTime time.Time + } + var cands []candidate + for _, e := range entries { + if e.IsDir() || !strings.HasSuffix(e.Name(), ".json") { + continue + } + info, err := e.Info() + if err != nil { + // Entry vanished between ReadDir and Info; ignore and keep + // scanning. Any caller-visible inconsistency is bounded to + // "a snapshot disappeared mid-scan" which the doc allows. + continue + } + cands = append(cands, candidate{e.Name(), info.ModTime()}) + } + slices.SortFunc(cands, func(a, b candidate) int { + return b.modTime.Compare(a.modTime) // newest first + }) + for _, c := range cands { + snap, err := s.GetSnapshot(ctx, strings.TrimSuffix(c.name, ".json")) + if err != nil || snap == nil { + continue + } + return snap, nil + } + return nil, nil +} + // AbortSnapshot atomically flips a pending snapshot to aborted. If the // snapshot is already terminal the existing status is returned unchanged. // Returns an empty status if the snapshot is not found. diff --git a/go/ai/exp/localstore/file_test.go b/go/ai/exp/localstore/file_test.go index c992e4d2f5..327efdf3a4 100644 --- a/go/ai/exp/localstore/file_test.go +++ b/go/ai/exp/localstore/file_test.go @@ -396,4 +396,83 @@ func TestFileSessionStore(t *testing.T) { var _ exp.SessionStore[testState] = (*FileSessionStore[testState])(nil) var _ exp.SnapshotAborter = (*FileSessionStore[testState])(nil) }) + + t.Run("LatestSnapshotEmpty", func(t *testing.T) { + store := newFileStore(t) + latest, err := store.LatestSnapshot(context.Background()) + if err != nil { + t.Fatalf("LatestSnapshot: %v", err) + } + if latest != nil { + t.Errorf("expected nil on empty store, got %+v", latest) + } + }) + + t.Run("LatestSnapshotMissingDir", func(t *testing.T) { + // Construct a store but then remove the dir to simulate a + // pre-use store. LatestSnapshot should treat "no dir" as + // "no snapshots" rather than error. + dir := filepath.Join(t.TempDir(), "nope") + store, err := NewFileSessionStore[testState](dir) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + if err := os.RemoveAll(dir); err != nil { + t.Fatalf("RemoveAll: %v", err) + } + latest, err := store.LatestSnapshot(context.Background()) + if err != nil { + t.Fatalf("LatestSnapshot: %v", err) + } + if latest != nil { + t.Errorf("expected nil on missing dir, got %+v", latest) + } + }) + + t.Run("LatestSnapshotReturnsMostRecent", func(t *testing.T) { + store := newFileStore(t) + ctx := context.Background() + for _, id := range []string{"snap-1", "snap-2", "snap-3"} { + if _, err := store.SaveSnapshot(ctx, id, + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + }); err != nil { + t.Fatalf("SaveSnapshot %q: %v", id, err) + } + // Force monotonically increasing UpdatedAt — the store + // stamps to wall-clock time and successive writes within + // the same tick can tie. + time.Sleep(2 * time.Millisecond) + } + latest, err := store.LatestSnapshot(ctx) + if err != nil { + t.Fatalf("LatestSnapshot: %v", err) + } + if latest == nil || latest.SnapshotID != "snap-3" { + t.Errorf("expected latest=snap-3, got %+v", latest) + } + }) + + t.Run("LatestSnapshotSkipsCorruptedFiles", func(t *testing.T) { + store := newFileStore(t) + ctx := context.Background() + good, err := store.SaveSnapshot(ctx, "good", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot good: %v", err) + } + // Drop an unparseable file alongside. + if err := os.WriteFile(filepath.Join(store.dir, "broken.json"), []byte("not json"), 0o600); err != nil { + t.Fatalf("seed broken file: %v", err) + } + latest, err := store.LatestSnapshot(ctx) + if err != nil { + t.Fatalf("LatestSnapshot: %v", err) + } + if latest == nil || latest.SnapshotID != good.SnapshotID { + t.Errorf("expected to skip broken.json and return %q, got %+v", good.SnapshotID, latest) + } + }) } diff --git a/go/ai/exp/localstore/inmemory.go b/go/ai/exp/localstore/inmemory.go index a50f47d13f..4fd179c6a3 100644 --- a/go/ai/exp/localstore/inmemory.go +++ b/go/ai/exp/localstore/inmemory.go @@ -65,6 +65,28 @@ func (s *InMemorySessionStore[State]) GetSnapshot(_ context.Context, snapshotID return copySnapshot(snap) } +// LatestSnapshot returns the snapshot with the most recent +// [SessionSnapshot.UpdatedAt] in the store, or nil if there are none. +// +// This is not part of the [exp.SessionStore] interface; it is an +// InMemorySessionStore-specific convenience that mirrors +// [FileSessionStore.LatestSnapshot] so callers that swap stores during +// tests don't have to special-case the in-memory implementation. +func (s *InMemorySessionStore[State]) LatestSnapshot(_ context.Context) (*exp.SessionSnapshot[State], error) { + s.mu.RLock() + defer s.mu.RUnlock() + var latest *exp.SessionSnapshot[State] + for _, snap := range s.snapshots { + if latest == nil || snap.UpdatedAt.After(latest.UpdatedAt) { + latest = snap + } + } + if latest == nil { + return nil, nil + } + return copySnapshot(latest) +} + // AbortSnapshot atomically flips a pending snapshot to aborted. If the // snapshot is already terminal the existing status is returned unchanged. // Returns an empty status if the snapshot is not found. diff --git a/go/ai/exp/localstore/inmemory_test.go b/go/ai/exp/localstore/inmemory_test.go index 0f8c5c8fc3..312eaa2d7f 100644 --- a/go/ai/exp/localstore/inmemory_test.go +++ b/go/ai/exp/localstore/inmemory_test.go @@ -156,4 +156,57 @@ func TestInMemorySessionStore(t *testing.T) { var _ exp.SessionStore[testState] = (*InMemorySessionStore[testState])(nil) var _ exp.SnapshotAborter = (*InMemorySessionStore[testState])(nil) }) + + t.Run("LatestSnapshotEmpty", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + latest, err := store.LatestSnapshot(context.Background()) + if err != nil { + t.Fatalf("LatestSnapshot: %v", err) + } + if latest != nil { + t.Errorf("expected nil on empty store, got %+v", latest) + } + }) + + t.Run("LatestSnapshotReturnsMostRecent", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + ctx := context.Background() + for _, id := range []string{"snap-1", "snap-2", "snap-3"} { + if _, err := store.SaveSnapshot(ctx, id, + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + }); err != nil { + t.Fatalf("SaveSnapshot %q: %v", id, err) + } + // Wall-clock UpdatedAt can tie within a tick; force order. + time.Sleep(2 * time.Millisecond) + } + latest, err := store.LatestSnapshot(ctx) + if err != nil { + t.Fatalf("LatestSnapshot: %v", err) + } + if latest == nil || latest.SnapshotID != "snap-3" { + t.Errorf("expected latest=snap-3, got %+v", latest) + } + }) + + t.Run("LatestSnapshotReturnsCopy", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + ctx := context.Background() + if _, err := store.SaveSnapshot(ctx, "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{ + Status: exp.SnapshotStatusSucceeded, + State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, + }, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + first, _ := store.LatestSnapshot(ctx) + first.State.Custom.Counter = 999 + second, _ := store.LatestSnapshot(ctx) + if second.State.Custom.Counter != 1 { + t.Errorf("expected counter=1 (isolation), got %d", second.State.Custom.Counter) + } + }) } diff --git a/go/samples/agent-custom/main.go b/go/samples/agent-custom/main.go deleted file mode 100644 index e83dcd372a..0000000000 --- a/go/samples/agent-custom/main.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This sample demonstrates DefineCustomAgent, which gives the caller -// full control over the per-turn loop (model selection, history -// management, streaming chunks). It runs a CLI REPL where conversation -// history is maintained by the session. Compare with agent-prompt -// (DefineAgent + aix.FromPrompt) and agent-inline (DefineAgent + -// aix.FromInline), which both auto-wire the loop and differ only in -// where the prompt is defined. -package main - -import ( - "bufio" - "context" - "errors" - "fmt" - "os" - "os/signal" - "strings" - "syscall" - - "github.com/firebase/genkit/go/ai" - aix "github.com/firebase/genkit/go/ai/exp" - "github.com/firebase/genkit/go/ai/exp/localstore" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/googlegenai" - "google.golang.org/genai" -) - -func main() { - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - defer stop() - g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - - store, err := localstore.NewFileSessionStore[any]("./sessions") - if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } - - chatAgent := genkit.DefineCustomAgent(g, "chat", - func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { - if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) error { - for chunk, err := range genkit.GenerateStream(ctx, g, - ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ - ThinkingConfig: &genai.ThinkingConfig{ - ThinkingBudget: genai.Ptr[int32](0), - }, - })), - ai.WithSystem("You are a helpful assistant. Keep responses concise."), - ai.WithMessages(sess.Messages()...), - ) { - if err != nil { - return err - } - if chunk.Done { - sess.AddMessages(chunk.Response.Message) - break - } - resp.SendModelChunk(chunk.Chunk) - } - - return nil - }); err != nil { - return nil, err - } - return sess.Result(), nil - }, - aix.WithSessionStore(store), - aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), - ) - - fmt.Println("Agent Chat (type 'quit' to exit, Ctrl+C to abort)") - fmt.Println() - - conn, err := chatAgent.StreamBidi(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } - - inputCh := readLines(ctx) - -repl: - for { - fmt.Print("> ") - var input string - select { - case <-ctx.Done(): - fmt.Println() - break repl - case line, ok := <-inputCh: - if !ok { - fmt.Println() - break repl - } - input = strings.TrimSpace(line) - } - - if input == "quit" || input == "exit" { - break - } - if input == "" { - continue - } - - if err := conn.SendText(input); err != nil { - fmt.Fprintf(os.Stderr, "Send error: %v\n", err) - break - } - - fmt.Println() - - for chunk, err := range conn.Receive() { - if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - break - } - if chunk.ModelChunk != nil { - fmt.Print(chunk.ModelChunk.Text()) - } - if chunk.TurnEnd != nil { - if chunk.TurnEnd.SnapshotID != "" { - fmt.Printf("\n[snapshot: %s]", chunk.TurnEnd.SnapshotID) - } - fmt.Println() - fmt.Println() - break - } - } - } - - out, err := conn.Output() - if err != nil && !errors.Is(err, context.Canceled) { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } - if out != nil && out.SnapshotID != "" { - fmt.Printf("[snapshot: %s]\n", out.SnapshotID) - } - fmt.Println("You left the conversation. Goodbye!") -} - -// readLines reads lines from stdin on a background goroutine and yields -// them via the returned channel. The channel is closed on EOF, read error, -// or ctx cancellation. The goroutine cannot interrupt a blocked stdin -// read; on ctx cancellation it exits as soon as a line completes (or, in -// practice, when the process terminates). -func readLines(ctx context.Context) <-chan string { - ch := make(chan string) - go func() { - defer close(ch) - reader := bufio.NewReader(os.Stdin) - for { - line, err := reader.ReadString('\n') - if line != "" { - select { - case ch <- line: - case <-ctx.Done(): - return - } - } - if err != nil { - return - } - } - }() - return ch -} diff --git a/go/samples/agent-inline/main.go b/go/samples/agent-inline/main.go deleted file mode 100644 index df5af80f5d..0000000000 --- a/go/samples/agent-inline/main.go +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This sample demonstrates DefineAgent with aix.FromInline, which -// creates a multi-turn conversational agent backed by a prompt defined -// inline alongside the agent. The conversation loop (render prompt, -// call model, stream chunks, update history) is handled automatically. -// Compare with agent-prompt (DefineAgent + aix.FromPrompt), which -// sources its prompt from a .prompt file, and agent-custom -// (DefineCustomAgent), which wires the same loop manually. -package main - -import ( - "bufio" - "context" - "errors" - "fmt" - "os" - "os/signal" - "strings" - "syscall" - - "github.com/firebase/genkit/go/ai" - aix "github.com/firebase/genkit/go/ai/exp" - "github.com/firebase/genkit/go/ai/exp/localstore" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/googlegenai" - "google.golang.org/genai" -) - -func main() { - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - defer stop() - g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - - store, err := localstore.NewFileSessionStore[any]("./sessions") - if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } - - chatAgent := genkit.DefineAgent(g, "chat", - aix.FromInline( - ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ - ThinkingConfig: &genai.ThinkingConfig{ - ThinkingBudget: genai.Ptr[int32](0), - }, - })), - ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), - ), - aix.WithSessionStore(store), - aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), - ) - - fmt.Println("Agent Chat (type 'quit' to exit, Ctrl+C to abort)") - fmt.Println() - - conn, err := chatAgent.StreamBidi(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } - - inputCh := readLines(ctx) - -repl: - for { - fmt.Print("> ") - var input string - select { - case <-ctx.Done(): - fmt.Println() - break repl - case line, ok := <-inputCh: - if !ok { - fmt.Println() - break repl - } - input = strings.TrimSpace(line) - } - - if input == "quit" || input == "exit" { - break - } - if input == "" { - continue - } - - if err := conn.SendText(input); err != nil { - fmt.Fprintf(os.Stderr, "Send error: %v\n", err) - break - } - - fmt.Println() - - for chunk, err := range conn.Receive() { - if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - break - } - if chunk.ModelChunk != nil { - fmt.Print(chunk.ModelChunk.Text()) - } - if chunk.TurnEnd != nil { - if chunk.TurnEnd.SnapshotID != "" { - fmt.Printf("\n[snapshot: %s]", chunk.TurnEnd.SnapshotID) - } - fmt.Println() - fmt.Println() - break - } - } - } - - out, err := conn.Output() - if err != nil && !errors.Is(err, context.Canceled) { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } - if out != nil && out.SnapshotID != "" { - fmt.Printf("[snapshot: %s]\n", out.SnapshotID) - } - fmt.Println("You left the conversation. Goodbye!") -} - -// readLines reads lines from stdin on a background goroutine and yields -// them via the returned channel. The channel is closed on EOF, read error, -// or ctx cancellation. The goroutine cannot interrupt a blocked stdin -// read; on ctx cancellation it exits as soon as a line completes (or, in -// practice, when the process terminates). -func readLines(ctx context.Context) <-chan string { - ch := make(chan string) - go func() { - defer close(ch) - reader := bufio.NewReader(os.Stdin) - for { - line, err := reader.ReadString('\n') - if line != "" { - select { - case ch <- line: - case <-ctx.Done(): - return - } - } - if err != nil { - return - } - } - }() - return ch -} diff --git a/go/samples/agent-prompt/main.go b/go/samples/agent-prompt/main.go deleted file mode 100644 index 65722e981b..0000000000 --- a/go/samples/agent-prompt/main.go +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This sample demonstrates DefineAgent with aix.FromPrompt, which -// creates a multi-turn conversational agent backed by a .prompt file. -// The conversation loop (render prompt, call model, stream chunks, -// update history) is handled automatically. Compare with agent-custom -// (DefineCustomAgent), which wires the same loop manually, and -// agent-inline (DefineAgent + aix.FromInline), which defines the -// prompt inline alongside the agent. -package main - -import ( - "bufio" - "context" - "errors" - "fmt" - "os" - "os/signal" - "strings" - "syscall" - - aix "github.com/firebase/genkit/go/ai/exp" - "github.com/firebase/genkit/go/ai/exp/localstore" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/googlegenai" -) - -type ChatPromptInput struct { - Personality string `json:"personality"` -} - -func main() { - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - defer stop() - g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - - genkit.DefineSchemaFor[ChatPromptInput](g) - - store, err := localstore.NewFileSessionStore[any]("./sessions") - if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } - - chatAgent := genkit.DefineAgent(g, "chat", - aix.FromPrompt(ChatPromptInput{Personality: "a sarcastic pirate"}), - aix.WithSessionStore(store), - aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[any]) bool { - return sc.Event == aix.SnapshotEventInvocationEnd || sc.TurnIndex%5 == 0 - }), - ) - - fmt.Println("Agent Chat (type 'quit' to exit, Ctrl+C to abort)") - fmt.Println() - - conn, err := chatAgent.StreamBidi(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } - - inputCh := readLines(ctx) - -repl: - for { - fmt.Print("> ") - var input string - select { - case <-ctx.Done(): - fmt.Println() - break repl - case line, ok := <-inputCh: - if !ok { - fmt.Println() - break repl - } - input = strings.TrimSpace(line) - } - - if input == "quit" || input == "exit" { - break - } - if input == "" { - continue - } - - if err := conn.SendText(input); err != nil { - fmt.Fprintf(os.Stderr, "Send error: %v\n", err) - break - } - - fmt.Println() - - for chunk, err := range conn.Receive() { - if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - break - } - if chunk.ModelChunk != nil { - fmt.Print(chunk.ModelChunk.Text()) - } - if chunk.TurnEnd != nil { - if chunk.TurnEnd.SnapshotID != "" { - fmt.Printf("\n[snapshot: %s]", chunk.TurnEnd.SnapshotID) - } - fmt.Println() - fmt.Println() - break - } - } - } - - out, err := conn.Output() - if err != nil && !errors.Is(err, context.Canceled) { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } - if out != nil && out.SnapshotID != "" { - fmt.Printf("[snapshot: %s]\n", out.SnapshotID) - } - fmt.Println("You left the conversation. Goodbye!") -} - -// readLines reads lines from stdin on a background goroutine and yields -// them via the returned channel. The channel is closed on EOF, read error, -// or ctx cancellation. The goroutine cannot interrupt a blocked stdin -// read; on ctx cancellation it exits as soon as a line completes (or, in -// practice, when the process terminates). -func readLines(ctx context.Context) <-chan string { - ch := make(chan string) - go func() { - defer close(ch) - reader := bufio.NewReader(os.Stdin) - for { - line, err := reader.ReadString('\n') - if line != "" { - select { - case ch <- line: - case <-ctx.Done(): - return - } - } - if err != nil { - return - } - } - }() - return ch -} diff --git a/go/samples/basic-agents/cli.go b/go/samples/basic-agents/cli.go new file mode 100644 index 0000000000..f17558b73f --- /dev/null +++ b/go/samples/basic-agents/cli.go @@ -0,0 +1,505 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is the user-facing half of the sample: the interactive CLI +// that ties the three server-side agent definitions in main.go to a +// single experience. The CLI is deliberately small. It has three +// responsibilities: +// +// 1. Pick an agent ("thread list"). +// 2. For the picked agent, pick between resuming from the latest +// snapshot or starting fresh. If the latest snapshot is still +// pending (a detached invocation is still processing in the +// background), offer to wait, start fresh, or back out. +// 3. Run a small REPL against the agent: stream the model's reply each +// turn, accept text input, and offer /detach, /back, and /quit as +// control commands. +// +// The detach demo is woven into step 3: typing "/detach " sends +// the text as the final input and detaches, so the agent keeps +// processing in the background and the caller gets a pending snapshot +// ID. Re-picking the same agent in step 2 then surfaces the wait/resume +// flow. + +package main + +import ( + "bufio" + "context" + "errors" + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/ai/exp/localstore" +) + +// sampleAgent pairs an agent with the store it persists to and a +// one-line description for the CLI list view. The embedded +// *aix.Agent[any, any] makes Name(), StreamBidi() etc. callable +// directly on a sampleAgent value, so the CLI does not need a +// separate field-threading layer. +// +// Store is tracked alongside the agent (rather than fished out of it) +// because we use FileSessionStore-specific helpers like +// LatestSnapshot and OnSnapshotStatusChange; carrying the concrete +// type avoids a type assertion at every call site. +type sampleAgent struct { + *aix.Agent[any, any] + Store *localstore.FileSessionStore[any] + Description string +} + +// errQuit signals that the user typed /quit somewhere in the CLI; it +// bubbles up through openAgent and breaks runCLI's outer loop. +var errQuit = errors.New("quit") + +// runCLI is the entry point for the interactive client. It alternates +// between two screens forever: the agent list and a per-agent chat. +// Returning from a chat brings the user back to the agent list. /quit +// (anywhere) and Ctrl-C both unwind back here and exit cleanly. +func runCLI(ctx context.Context, agents []sampleAgent) error { + fmt.Println("Genkit Basic Agents") + fmt.Println("===================") + fmt.Println() + fmt.Println("Pick an agent below, choose to resume the last conversation") + fmt.Println("or start a new one, and chat. Inside a chat:") + fmt.Println(" (text) send a message and stream the reply") + fmt.Println(" /detach (text...) send the text (optional) as the final") + fmt.Println(" input, then detach. The agent finishes") + fmt.Println(" in the background and writes a pending") + fmt.Println(" snapshot. Re-pick the agent later to") + fmt.Println(" wait for it and resume from the final") + fmt.Println(" state.") + fmt.Println(" /back return to the agent list") + fmt.Println(" /quit exit the program") + + inputCh := readLines(ctx) + for { + choice, ok := pickAgent(ctx, inputCh, agents) + if !ok { + return nil + } + if err := openAgent(ctx, inputCh, agents[choice]); err != nil { + if errors.Is(err, errQuit) { + return nil + } + return err + } + } +} + +// pickAgent renders the agent list and reads the user's choice. The +// list is re-rendered between selections so the user can see updated +// pending/terminal status after returning from a chat. +func pickAgent(ctx context.Context, inputCh <-chan string, agents []sampleAgent) (int, bool) { + for { + fmt.Println() + fmt.Println("Agents:") + for i, a := range agents { + fmt.Printf(" %d. %s — %s\n", i+1, a.Name(), a.Description) + if summary := summarizeLatest(ctx, a); summary != "" { + fmt.Printf(" last: %s\n", summary) + } + } + fmt.Println(" q. quit") + fmt.Println() + fmt.Print("> ") + + line, ok := readLine(ctx, inputCh) + if !ok { + return -1, false + } + line = strings.TrimSpace(line) + if line == "q" || line == "quit" || line == "exit" { + return -1, false + } + idx, err := strconv.Atoi(line) + if err != nil || idx < 1 || idx > len(agents) { + fmt.Println("Invalid choice. Type a number from the list or 'q' to quit.") + continue + } + return idx - 1, true + } +} + +// openAgent is the per-agent screen: it surfaces the latest snapshot +// (asking the user how to handle a still-pending detached invocation), +// asks whether to resume or start fresh, and then hands off to the +// chat REPL. +// +// The pending and non-pending paths return the same (resume, ok) shape, +// so the rest of the flow is uniform: ok=false means the user backed +// out, otherwise hand the chosen snapshot (or nil for fresh) to +// runChat. +func openAgent(ctx context.Context, inputCh <-chan string, a sampleAgent) error { + latest, err := a.Store.LatestSnapshot(ctx) + if err != nil { + return fmt.Errorf("read snapshots for %q: %w", a.Name(), err) + } + + var ( + resume *aix.SessionSnapshot[any] + ok bool + ) + if latest != nil && latest.Status == aix.SnapshotStatusPending { + // Background invocation still in flight. handlePending makes + // the final decision itself (wait & resume, new, or back), so + // we don't fall through to pickSession — the user already + // chose; asking again would just be noise. + resume, ok = handlePending(ctx, inputCh, a, latest) + } else { + resume, ok = pickSession(ctx, inputCh, a, latest) + } + if !ok { + return nil + } + return runChat(ctx, inputCh, a, resume) +} + +// handlePending offers the three reasonable responses when a previous +// invocation of this agent is still running in the background: +// +// 1. wait for it to finalize and resume from it directly, +// 2. ignore it and start a fresh conversation, +// 3. go back to the agent list. +// +// Returns the snapshot to resume from (option 1, succeeded) or nil +// (option 2, or option 1 when the snapshot terminated non-succeeded). +// ok=false means the user chose 3 or the context was canceled. +// +// Crucially, options that imply "use this conversation" return the +// snapshot directly so the caller can skip the resume / new prompt: +// the user already committed to the choice by waiting, and re-asking +// would be redundant. +func handlePending(ctx context.Context, inputCh <-chan string, a sampleAgent, pending *aix.SessionSnapshot[any]) (*aix.SessionSnapshot[any], bool) { + for { + fmt.Printf("\nThe last %s session is still running in the background (%s).\n", a.Name(), shortID(pending.SnapshotID)) + fmt.Println(" 1. Wait for it to finalize") + fmt.Println(" 2. Start a new conversation") + fmt.Println(" 3. Back to agent list") + fmt.Print("> ") + + line, ok := readLine(ctx, inputCh) + if !ok { + return nil, false + } + switch strings.TrimSpace(line) { + case "1": + fmt.Println("Waiting for it to finalize...") + final, err := waitForFinalize(ctx, a.Store, pending.SnapshotID) + if err != nil { + fmt.Fprintf(os.Stderr, "Wait error: %v\n", err) + return nil, false + } + if final == nil { + fmt.Println("Snapshot disappeared while waiting. Starting a new conversation.") + return nil, true + } + fmt.Printf("Done (%s).\n", final.Status) + if final.Status != aix.SnapshotStatusSucceeded { + // failed / aborted snapshots aren't resumable; the + // agent runtime would reject WithSnapshotID on them. + // Fall through to a fresh chat instead. + fmt.Println("Cannot resume this snapshot. Starting a new conversation.") + return nil, true + } + return final, true + case "2": + // Ignore the pending snapshot; start a fresh chat. The + // background invocation keeps running and writes its + // terminal status — this CLI just stops tracking it. + return nil, true + case "3": + return nil, false + default: + fmt.Println("Invalid choice. Type 1, 2, or 3.") + } + } +} + +// pickSession decides which snapshot (if any) to resume from. It only +// offers two paths so the demo stays focused: resume from the most +// recent terminal snapshot (returns the snapshot pointer), or start +// fresh (returns nil). +func pickSession(ctx context.Context, inputCh <-chan string, a sampleAgent, latest *aix.SessionSnapshot[any]) (*aix.SessionSnapshot[any], bool) { + if latest == nil || latest.Status != aix.SnapshotStatusSucceeded { + fmt.Printf("\nStarting a new conversation with %s.\n", a.Name()) + return nil, true + } + + msgs := 0 + if latest.State != nil { + msgs = len(latest.State.Messages) + } + fmt.Printf("\nLast %s session: %s (%s, %d msgs).\n", + a.Name(), shortID(latest.SnapshotID), latest.UpdatedAt.Format(time.RFC822), msgs) + fmt.Println("Resume from it? [Y/n] (n = start a new conversation)") + fmt.Print("> ") + + line, ok := readLine(ctx, inputCh) + if !ok { + return nil, false + } + switch strings.ToLower(strings.TrimSpace(line)) { + case "", "y", "yes": + return latest, true + default: + return nil, true + } +} + +// runChat opens the bidi connection (optionally resuming from a +// snapshot) and runs the per-turn REPL. When resuming, the prior +// conversation is replayed first so the user sees the context they're +// picking up, then the REPL takes over. /detach is the one interesting +// branch — it sends the optional trailing text as the final input and +// detaches the connection, returning the pending snapshot ID for the +// user to observe. +func runChat(ctx context.Context, inputCh <-chan string, a sampleAgent, resume *aix.SessionSnapshot[any]) error { + fmt.Printf("\n=== Chatting with %s ===\n", a.Name()) + if resume != nil { + fmt.Printf("Resumed from %s\n", shortID(resume.SnapshotID)) + } + fmt.Println("Commands: /detach [text], /back, /quit") + + if resume != nil && resume.State != nil && len(resume.State.Messages) > 0 { + fmt.Println() + fmt.Println("(picking up where you left off)") + printHistory(resume.State.Messages) + } + fmt.Println() + + var opts []aix.InvocationOption[any] + if resume != nil { + opts = append(opts, aix.WithSnapshotID[any](resume.SnapshotID)) + } + conn, err := a.StreamBidi(ctx, opts...) + if err != nil { + return fmt.Errorf("open agent %q: %w", a.Name(), err) + } + + var ( + detached bool + quit bool + ) + +repl: + for { + fmt.Print("> ") + line, ok := readLine(ctx, inputCh) + if !ok { + break + } + text := strings.TrimSpace(line) + if text == "" { + continue + } + + switch { + case text == "/back": + break repl + case text == "/quit" || text == "/exit": + quit = true + break repl + case text == "/detach" || strings.HasPrefix(text, "/detach "): + trailing := strings.TrimSpace(strings.TrimPrefix(text, "/detach")) + // Send (optional message) + detach in a single wire input so + // the trailing text becomes the last buffered message. The + // agent will process it in the background after the + // connection closes. + input := &aix.AgentInput{Detach: true} + if trailing != "" { + input.Message = ai.NewUserTextMessage(trailing) + } + if err := conn.Send(input); err != nil { + fmt.Fprintf(os.Stderr, "Detach error: %v\n", err) + break repl + } + detached = true + break repl + case strings.HasPrefix(text, "/"): + fmt.Println("Unknown command. Try /detach, /back, or /quit.") + continue + } + + if err := conn.SendText(text); err != nil { + fmt.Fprintf(os.Stderr, "Send error: %v\n", err) + break + } + fmt.Println() + for chunk, err := range conn.Receive() { + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + break + } + if chunk.ModelChunk != nil { + fmt.Print(chunk.ModelChunk.Text()) + } + if chunk.TurnEnd != nil { + if chunk.TurnEnd.SnapshotID != "" { + fmt.Printf("\n [snapshot %s]", shortID(chunk.TurnEnd.SnapshotID)) + } + fmt.Println() + fmt.Println() + break + } + } + } + + out, outErr := conn.Output() + if outErr != nil && !errors.Is(outErr, context.Canceled) { + fmt.Fprintf(os.Stderr, "Output error: %v\n", outErr) + } + + switch { + case detached && out != nil && out.SnapshotID != "": + fmt.Printf("Detached. Pending snapshot: %s.\n", shortID(out.SnapshotID)) + fmt.Println("The agent keeps processing in the background. Pick this") + fmt.Println("agent again from the list to wait for it to finalize and") + fmt.Println("resume from the cumulative final state.") + case out != nil && out.SnapshotID != "": + fmt.Printf("Done. Final snapshot: %s.\n", shortID(out.SnapshotID)) + } + + if quit { + return errQuit + } + return nil +} + +// summarizeLatest is the one-line summary printed under each agent in +// the list. Empty if there is no snapshot yet, so a freshly-installed +// sample doesn't show clutter. +func summarizeLatest(ctx context.Context, a sampleAgent) string { + latest, err := a.Store.LatestSnapshot(ctx) + if err != nil || latest == nil { + return "" + } + msgs := 0 + if latest.State != nil { + msgs = len(latest.State.Messages) + } + return fmt.Sprintf("%s (%s, %d msgs, %s)", + shortID(latest.SnapshotID), latest.Status, msgs, latest.UpdatedAt.Format(time.RFC822)) +} + +// waitForFinalize subscribes to a snapshot's status and blocks until it +// transitions out of pending. The returned snapshot is the final one (or +// nil if it disappeared). OnSnapshotStatusChange yields the current +// status first, so a snapshot that finalized between the directory scan +// and the subscription is observed immediately. +func waitForFinalize(ctx context.Context, store *localstore.FileSessionStore[any], snapshotID string) (*aix.SessionSnapshot[any], error) { + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + statusCh := store.OnSnapshotStatusChange(subCtx, snapshotID) + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case status, ok := <-statusCh: + if !ok { + // Subscription closed (e.g. snapshot deleted under us). + return store.GetSnapshot(ctx, snapshotID) + } + if status == aix.SnapshotStatusPending { + continue + } + return store.GetSnapshot(ctx, snapshotID) + } + } +} + +// printHistory replays prior turns in the same format the live REPL +// uses, so a resumed chat reads continuously rather than dumping the +// user into an empty prompt. Non-user/model roles (e.g. tool messages) +// and empty content are skipped — they would only be noise here, and +// the agent's per-turn loop still has the full history under the hood. +func printHistory(msgs []*ai.Message) { + for _, m := range msgs { + text := strings.TrimSpace(m.Text()) + if text == "" { + continue + } + switch m.Role { + case ai.RoleUser: + fmt.Println() + fmt.Printf("> %s\n", text) + case ai.RoleModel: + fmt.Println() + fmt.Println(text) + } + } +} + +// shortID trims a UUID-shaped snapshot ID to its first segment so the +// CLI stays readable. The full ID is still available on disk in +// .genkit/snapshots//. +func shortID(id string) string { + if i := strings.Index(id, "-"); i > 0 { + return id[:i] + } + if len(id) > 8 { + return id[:8] + } + return id +} + +// readLine reads one line from inputCh, returning false if the channel +// closes (EOF) or ctx is canceled (Ctrl-C). All CLI prompts go through +// this so cancellation is honored uniformly. +func readLine(ctx context.Context, inputCh <-chan string) (string, bool) { + select { + case <-ctx.Done(): + fmt.Println() + return "", false + case line, ok := <-inputCh: + if !ok { + fmt.Println() + return "", false + } + return line, true + } +} + +// readLines reads lines from stdin on a background goroutine and yields +// them via the returned channel. The channel is closed on EOF, read +// error, or ctx cancellation. The goroutine cannot interrupt a blocked +// stdin read; on ctx cancellation it exits as soon as a line completes +// (or, in practice, when the process terminates). +func readLines(ctx context.Context) <-chan string { + ch := make(chan string) + go func() { + defer close(ch) + reader := bufio.NewReader(os.Stdin) + for { + line, err := reader.ReadString('\n') + if line != "" { + select { + case ch <- line: + case <-ctx.Done(): + return + } + } + if err != nil { + return + } + } + }() + return ch +} diff --git a/go/samples/basic-agents/main.go b/go/samples/basic-agents/main.go new file mode 100644 index 0000000000..643cba5885 --- /dev/null +++ b/go/samples/basic-agents/main.go @@ -0,0 +1,210 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This sample demonstrates Genkit's agent APIs by defining three agents in +// three different styles and exposing all of them through a single CLI: +// +// - "pirate" uses DefineAgent + aix.FromInline. The prompt is declared +// inline next to the agent. +// - "chef" uses DefineAgent + aix.FromPrompt. The prompt is loaded from +// ./prompts/chef.prompt by the agent's name. +// - "coder" uses DefineCustomAgent. The per-turn loop (model selection, +// history management, streaming) is wired by hand. +// +// All three agents persist their conversation state to a per-agent +// FileSessionStore under ./.genkit/snapshots//. +// +// To run: +// +// go run . +// +// The CLI prints a numbered list of agents. Pick one, choose to resume +// from the last snapshot or start fresh, and chat. Inside a chat: +// +// (text) send a message and stream the reply +// /detach (text...) send the text (optional) as the final input, then +// detach. The server keeps processing in the +// background; you get a pending snapshot ID and +// return to the agent list. Re-pick the agent later +// to wait for the snapshot to finalize and resume +// from the cumulative final state. +// /back return to the agent list (snapshot is still +// written by the agent's normal turn-end hook) +// /quit exit the program +// +// Tip: try "/detach write me a long pirate story" to see the detach loop +// end-to-end. After the CLI returns to the agent list, pick "pirate" +// again; if the snapshot is still pending, you'll get a three-way menu +// (wait, start new, back). Picking wait blocks on the in-process +// status subscription and resumes from the cumulative final state. +package main + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/ai/exp/localstore" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "google.golang.org/genai" +) + +// ChatPromptInput is the input schema referenced by ./prompts/chef.prompt. +// Registering it via DefineSchemaFor lets the .prompt file refer to it by +// name in its YAML frontmatter. +type ChatPromptInput struct { + Personality string `json:"personality"` +} + +func main() { + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + genkit.DefineSchemaFor[ChatPromptInput](g) + + // Each define function returns a fully-populated sampleAgent that + // pairs the registered agent with its FileSessionStore and a + // one-line description for the CLI list view. The CLI then calls + // a.Name(), a.StreamBidi(...) on the embedded agent and a.Store + // for snapshot operations. + agents := []sampleAgent{ + defineInlineAgent(g), + definePromptAgent(g), + defineCustomAgent(g), + } + + if err := runCLI(ctx, agents); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} + +// defineInlineAgent demonstrates DefineAgent with aix.FromInline. The +// prompt is declared right next to the agent definition; the registered +// prompt and the agent share a name. Each turn the framework renders the +// prompt, appends the conversation history, calls the model, and updates +// session state. This is the shortest path from "I want a chat agent" to +// a working one. +func defineInlineAgent(g *genkit.Genkit) sampleAgent { + const name = "pirate" + store := mustStore(name) + return sampleAgent{ + Agent: genkit.DefineAgent(g, name, + aix.FromInline( + ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), + ), + aix.WithSessionStore(store), + aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), + ), + Store: store, + Description: "Sarcastic pirate (inline-defined prompt)", + } +} + +// definePromptAgent demonstrates DefineAgent with aix.FromPrompt. The +// prompt is loaded from ./prompts/.prompt by genkit's prompt +// registry. Defining the prompt in a file lets you tune model, config, +// schema, and template independently of the Go code — useful when prompt +// authors are not the same people writing the agent wiring. +// +// FromPrompt's argument is the default input passed to the prompt's +// Render on every turn; the inline-prompt variant has no per-turn input +// of its own. +func definePromptAgent(g *genkit.Genkit) sampleAgent { + const name = "chef" + store := mustStore(name) + return sampleAgent{ + Agent: genkit.DefineAgent(g, name, + aix.FromPrompt(ChatPromptInput{Personality: "a Michelin-starred chef who loves explaining technique"}), + aix.WithSessionStore(store), + aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), + ), + Store: store, + Description: "Michelin-starred chef (prompt loaded from ./prompts/chef.prompt)", + } +} + +// defineCustomAgent demonstrates DefineCustomAgent. The per-turn function +// is fully under your control: it picks the model, manages the message +// list, streams chunks back to the client, and decides what to put in the +// final result. Use this form when the prompt-backed agent loop doesn't +// fit (e.g. you want to pre/post-process every turn, swap models +// dynamically, or wire up custom tool plumbing). +// +// Even with full control over the loop, the framework still owns session +// state, snapshot writes, and the detach lifecycle. +func defineCustomAgent(g *genkit.Genkit) sampleAgent { + const name = "coder" + store := mustStore(name) + return sampleAgent{ + Agent: genkit.DefineCustomAgent(g, name, + func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) error { + for chunk, err := range genkit.GenerateStream(ctx, g, + ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are a senior software engineer. Answer briefly. Use fenced code blocks when showing code."), + ai.WithMessages(sess.Messages()...), + ) { + if err != nil { + return err + } + if chunk.Done { + sess.AddMessages(chunk.Response.Message) + break + } + resp.SendModelChunk(chunk.Chunk) + } + return nil + }); err != nil { + return nil, err + } + return sess.Result(), nil + }, + aix.WithSessionStore(store), + aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), + ), + Store: store, + Description: "Concise code helper (custom per-turn loop)", + } +} + +// mustStore creates a FileSessionStore rooted at the per-agent dir under +// ./.genkit/snapshots/, or exits the process on failure. Used during +// agent setup where there's nowhere sensible to return an error. +// +// Per-agent dirs keep one agent's history out of another's listing +// (FileSessionStore.LatestSnapshot scans the directory it was given). +func mustStore(agentName string) *localstore.FileSessionStore[any] { + store, err := localstore.NewFileSessionStore[any]("./.genkit/snapshots/" + agentName) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating store for %q: %v\n", agentName, err) + os.Exit(1) + } + return store +} diff --git a/go/samples/agent-prompt/prompts/chat.prompt b/go/samples/basic-agents/prompts/chef.prompt similarity index 77% rename from go/samples/agent-prompt/prompts/chat.prompt rename to go/samples/basic-agents/prompts/chef.prompt index 4c11a18f72..c40edfcbaa 100644 --- a/go/samples/agent-prompt/prompts/chat.prompt +++ b/go/samples/basic-agents/prompts/chef.prompt @@ -1,5 +1,5 @@ --- -model: googleai/gemini-3-flash-preview +model: googleai/gemini-flash-latest config: thinkingConfig: thinkingBudget: 0 From e760e944719acecd6b01fe22da37a8c9c0b08071 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 9 Jun 2026 14:31:02 -0700 Subject: [PATCH 086/141] fix(go/exp/localstore): never drop the latest snapshot status notification notifyLocked used a non-blocking send, so a new status was silently dropped whenever the subscriber's size-1 buffer still held an unread value (typically the seed delivered at subscription time). A detached invocation's abort watcher could miss the aborted transition that way, leaving the background work running. Drain the stale value before sending (coalesceSend, shared by both stores); the exp test store mirrors the same logic inline since it cannot import localstore. --- go/ai/exp/localstore/file.go | 10 +++------- go/ai/exp/localstore/inmemory.go | 28 +++++++++++++++++++++------- go/ai/exp/teststore_test.go | 8 ++++++++ 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/go/ai/exp/localstore/file.go b/go/ai/exp/localstore/file.go index 855dc1aec2..64e5587574 100644 --- a/go/ai/exp/localstore/file.go +++ b/go/ai/exp/localstore/file.go @@ -330,15 +330,11 @@ func (s *FileSessionStore[State]) removeSub(snapshotID string, ch chan exp.Snaps } // notifyLocked publishes status to all live subscribers of snapshotID. -// Caller must hold s.mu. Sends are best-effort: a slow subscriber may miss -// intermediate values, but the latest value visible to the subscription is -// always one of the values persisted to disk. +// Caller must hold s.mu. A slow subscriber may miss intermediate values, but +// the latest value is always delivered (see [coalesceSend]). func (s *FileSessionStore[State]) notifyLocked(snapshotID string, status exp.SnapshotStatus) { for _, ch := range s.subs[snapshotID] { - select { - case ch <- status: - default: - } + coalesceSend(ch, status) } } diff --git a/go/ai/exp/localstore/inmemory.go b/go/ai/exp/localstore/inmemory.go index 4fd179c6a3..7718abd607 100644 --- a/go/ai/exp/localstore/inmemory.go +++ b/go/ai/exp/localstore/inmemory.go @@ -206,15 +206,29 @@ func (s *InMemorySessionStore[State]) removeSub(snapshotID string, ch chan exp.S } // notifyLocked publishes status to all live subscribers of snapshotID. -// Caller must hold s.mu. Sends are best-effort: a slow subscriber may miss -// intermediate values, but the store guarantees the latest value visible -// to the subscription is the one persisted at notify time. +// Caller must hold s.mu. A slow subscriber may miss intermediate values, but +// the latest value is always delivered (see [coalesceSend]). func (s *InMemorySessionStore[State]) notifyLocked(snapshotID string, status exp.SnapshotStatus) { for _, ch := range s.subs[snapshotID] { - select { - case ch <- status: - default: - } + coalesceSend(ch, status) + } +} + +// coalesceSend delivers status on a size-1 buffered subscriber channel, +// guaranteeing the latest value stays observable even if an earlier value is +// still unread. Each channel is seeded at subscription time, so a plain +// non-blocking send would drop a newer status while the seed (or a prior +// status) sits in the buffer. Drop any stale unread value first; the caller +// holds the store mutex and is the only sender, so after the drain the send +// always has room. Shared by [InMemorySessionStore] and [FileSessionStore]. +func coalesceSend(ch chan exp.SnapshotStatus, status exp.SnapshotStatus) { + select { + case <-ch: + default: + } + select { + case ch <- status: + default: } } diff --git a/go/ai/exp/teststore_test.go b/go/ai/exp/teststore_test.go index 9778b970eb..94939303a3 100644 --- a/go/ai/exp/teststore_test.go +++ b/go/ai/exp/teststore_test.go @@ -163,7 +163,15 @@ func (s *testInMemStore[State]) removeSub(snapshotID string, ch chan SnapshotSta } func (s *testInMemStore[State]) notifyLocked(snapshotID string, status SnapshotStatus) { + // Coalesce to the latest status, mirroring localstore's coalesceSend: drop + // a stale unread value before sending so a newer status is never dropped + // while the seed sits in the size-1 buffer. (Can't share that helper here: + // localstore imports exp, so exp's test fixture can't import localstore.) for _, ch := range s.subs[snapshotID] { + select { + case <-ch: + default: + } select { case ch <- status: default: From b77a81b84e681785e19838e29ca58e9074fda28e Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 9 Jun 2026 14:31:12 -0700 Subject: [PATCH 087/141] feat(go/exp): agent finish reasons on TurnEnd, AgentOutput, and snapshots Adds AgentFinishReason, mirroring the model-level enum (stop, length, blocked, interrupted, other, unknown) plus agent-only outcomes (aborted, detached, failed), threaded end to end: - SessionRunner.Run's per-turn callback now returns (*TurnResult, error) so a custom agent can report how each turn ended; nil reports nothing (no implicit inference). Prompt-backed agents forward the generate response's reason automatically, including "interrupted". - The reason rides the TurnEnd chunk, persists on the turn-end snapshot, and defaults the invocation's reason on AgentOutput; a custom agent can override the invocation reason via AgentResult.FinishReason. - Detach returns "detached" to the client; the finalizer later stamps the snapshot with how the background work actually ended (last turn's reason or override, "failed", or "aborted"). On abort the reason lands in a second finalizer write and may briefly lag status. - Snapshot dedup now skips a write only when both state and finish reason are unchanged, so an invocation-end override is never collapsed onto the turn-end row. Breaking for the experimental ai/exp package: the Run callback signature changed; all in-repo callers, the DefineCustomAgent doc example, and the basic-agents sample are updated. Schema changes originate in the zod source and are regenerated into genkit-schema.json, go/ai/exp/gen.go, and the Python typings. --- genkit-tools/common/src/types/agent.ts | 73 ++ genkit-tools/genkit-schema.json | 29 + go/ai/exp/agent.go | 157 +++- go/ai/exp/agent_test.go | 876 ++++++++++++++++-- go/ai/exp/gen.go | 75 ++ go/ai/exp/localstore/file_test.go | 37 + go/ai/exp/session.go | 11 +- go/core/schemas.config | 107 +++ go/genkit/genkit.go | 10 +- go/samples/basic-agents/cli.go | 8 +- go/samples/basic-agents/main.go | 13 +- .../genkit/src/genkit/_core/_typing.py | 19 + 12 files changed, 1294 insertions(+), 121 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 0164f4b1db..9643e0b247 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -73,6 +73,42 @@ export const SnapshotStatusSchema = z.enum([ ]); export type SnapshotStatus = z.infer; +/** + * Zod schema for the reason an agent turn or invocation finished. + * + * The first group mirrors the model-level `FinishReasonSchema` so a turn + * backed by a single `generate` call can forward its reason verbatim: + * + * - `stop`: the model stopped naturally. + * - `length`: generation hit the token limit. + * - `blocked`: generation was blocked (e.g. safety). + * - `interrupted`: the model paused on a tool request awaiting input + * (e.g. human approval); the turn can be resumed with a `resume` payload. + * - `other` / `unknown`: anything else / unspecified. + * + * The remaining values are agent-specific outcomes with no `generate`-level + * equivalent (they never arise from forwarding a model finish reason): + * + * - `aborted`: the turn or invocation was aborted (e.g. a detached + * invocation aborted via the `abortSnapshot` companion action). + * - `detached`: the invocation was moved to the background (the client + * detached). The returned output reports `detached`; the persisted + * snapshot is later finalized with how the background work actually ended. + * - `failed`: the turn or invocation terminated with an error. + */ +export const AgentFinishReasonSchema = z.enum([ + 'stop', + 'length', + 'blocked', + 'interrupted', + 'other', + 'unknown', + 'aborted', + 'detached', + 'failed', +]); +export type AgentFinishReason = z.infer; + /** * Zod schema for session state. */ @@ -131,6 +167,11 @@ export const AgentResultSchema = z.object({ message: MessageSchema.optional(), /** Artifacts produced during the session. */ artifacts: z.array(ArtifactSchema).optional(), + /** + * Why the invocation finished. Set by a custom agent to override the + * default (the last turn's reason); omitted to accept the default. + */ + finishReason: AgentFinishReasonSchema.optional(), }); export type AgentResult = z.infer; @@ -146,6 +187,12 @@ export const AgentOutputSchema = z.object({ message: MessageSchema.optional(), /** Artifacts produced during the session. */ artifacts: z.array(ArtifactSchema).optional(), + /** + * Why the invocation finished. `detached` when the client detached and + * the work continues in the background; otherwise the last turn's reason + * (or the value a custom agent set on its result). + */ + finishReason: AgentFinishReasonSchema.optional(), }); export type AgentOutput = z.infer; @@ -164,6 +211,12 @@ export const TurnEndSchema = z.object({ * snapshots were suspended after detach). */ snapshotId: z.string().optional(), + /** + * Why this turn finished (e.g. `stop`, `length`, `interrupted`). Lets a + * caller react to a turn boundary (e.g. pause on `interrupted`) without + * scanning the message content. Omitted when the turn reported no reason. + */ + finishReason: AgentFinishReasonSchema.optional(), }); export type TurnEnd = z.infer; @@ -204,6 +257,19 @@ export const SessionSnapshotSchema = z.object({ event: SnapshotEventSchema, /** Lifecycle state of this snapshot. Empty is treated as `succeeded`. */ status: SnapshotStatusSchema.optional(), + /** + * Semantic reason the turn or invocation captured here ended (e.g. + * `stop`, `interrupted`, `failed`, `aborted`). Complements `status` (the + * persistence lifecycle) so a resumed or background task can report how it + * ended without re-deriving it from the messages. + * + * On an aborted snapshot this is best-effort and may briefly lag `status`: + * `abortSnapshot` flips `status` to `aborted` immediately, while the + * `aborted` finish reason is stamped by a subsequent finalizer write. If + * the process running the invocation never finalizes, the reason can + * remain empty even though `status` is `aborted`. + */ + finishReason: AgentFinishReasonSchema.optional(), /** Structured failure information for a snapshot in `failed` status. */ error: z .object({ @@ -249,6 +315,13 @@ export const GetSnapshotResponseSchema = z.object({ updatedAt: z.string().optional(), /** Lifecycle state of the snapshot. */ status: SnapshotStatusSchema.optional(), + /** + * Semantic reason the captured turn or invocation ended (e.g. `stop`, + * `interrupted`, `failed`, `aborted`). Lets a remote or background poller + * report how a detached/resumed invocation ended without re-deriving it. + * Empty on a pending snapshot. + */ + finishReason: AgentFinishReasonSchema.optional(), /** Structured failure information; populated when status is `error`. */ error: z.any().optional(), /** diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index e77e5cf63a..c478cd7d65 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -28,6 +28,20 @@ ], "additionalProperties": false }, + "AgentFinishReason": { + "type": "string", + "enum": [ + "stop", + "length", + "blocked", + "interrupted", + "other", + "unknown", + "aborted", + "detached", + "failed" + ] + }, "AgentInit": { "type": "object", "properties": { @@ -103,6 +117,9 @@ "items": { "$ref": "#/$defs/Artifact" } + }, + "finishReason": { + "$ref": "#/$defs/AgentFinishReason" } }, "additionalProperties": false @@ -118,6 +135,9 @@ "items": { "$ref": "#/$defs/Artifact" } + }, + "finishReason": { + "$ref": "#/$defs/AgentFinishReason" } }, "additionalProperties": false @@ -194,6 +214,9 @@ "status": { "$ref": "#/$defs/SnapshotStatus" }, + "finishReason": { + "$ref": "#/$defs/AgentFinishReason" + }, "error": {}, "state": { "$ref": "#/$defs/SessionState" @@ -225,6 +248,9 @@ "status": { "$ref": "#/$defs/SnapshotStatus" }, + "finishReason": { + "$ref": "#/$defs/AgentFinishReason" + }, "error": { "type": "object", "properties": { @@ -293,6 +319,9 @@ "properties": { "snapshotId": { "type": "string" + }, + "finishReason": { + "$ref": "#/$defs/AgentFinishReason" } }, "additionalProperties": false diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index b76b8f4fc4..79b66642f4 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -59,6 +59,16 @@ type SessionRunner[State any] struct { lastSnapshotVersion uint64 collectTurnOutput func() any + // lastTurnFinishReason is the finish reason reported by the most recent + // turn (via the [TurnResult] its callback returned), or "" if the turn + // reported none. It is written by Run before [SessionRunner.onEndTurn] + // and read by the runtime when emitting the turn-end signal and when + // defaulting the invocation's finish reason. All accesses are confined + // to the fn goroutine (Run and its synchronous onEndTurn callback) until + // fn completes, after which the terminal paths read it with a + // happens-before edge through the fnDone channel, so no lock is needed. + lastTurnFinishReason AgentFinishReason + // intake is the source of truth for in-flight tracking, queue state, // and suspended state. The session consults it via beginTurnEnd (in // maybeSnapshot) so per-turn snapshot writes and detach captures @@ -76,11 +86,29 @@ func (s *SessionRunner[State]) parentSnapshotID() string { return s.lastSnapshot.SnapshotID } +// TurnResult is the optional return value of a [SessionRunner.Run] per-turn +// callback. It lets a custom agent report how the turn ended; the framework +// forwards the reason on the turn's [TurnEnd] chunk, persists it on the +// turn-end snapshot, and uses it to default the invocation's finish reason. +// +// Returning nil (or a zero TurnResult) omits the reason: the framework +// performs no implicit inference. A prompt-backed agent populates it +// automatically from the underlying generate response. +type TurnResult struct { + // FinishReason is why this turn ended (e.g. [AgentFinishReasonStop], + // [AgentFinishReasonInterrupted]). Empty to report no reason. + FinishReason AgentFinishReason +} + // Run loops over the input channel, calling fn for each turn. Each turn is // wrapped in a trace span for observability. Input messages are automatically // added to the session before fn is called. After fn returns successfully, a // TurnEnd chunk is sent and a snapshot check is triggered. -func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Context, input *AgentInput) error) error { +// +// fn may return a [TurnResult] to report how the turn ended (e.g. its finish +// reason); returning nil reports nothing. The reason rides the turn's +// [TurnEnd] chunk and is persisted on the turn-end snapshot. +func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Context, input *AgentInput) (*TurnResult, error)) error { for input := range s.InputCh { spanMeta := &tracing.SpanMetadata{ Name: fmt.Sprintf("agent/turn/%d", s.TurnIndex), @@ -92,9 +120,16 @@ func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont if input.Message != nil { s.AddMessages(input.Message) } - if err := fn(ctx, input); err != nil { + tr, err := fn(ctx, input) + if err != nil { return nil, err } + // Reset each turn; a returned TurnResult sets the reason, + // nil reports none. + s.lastTurnFinishReason = "" + if tr != nil { + s.lastTurnFinishReason = tr.FinishReason + } s.onEndTurn(ctx) s.TurnIndex++ if s.collectTurnOutput != nil { @@ -131,16 +166,28 @@ func (s *SessionRunner[State]) Result() *AgentResult { return result } +// invocationReason resolves the finish reason for the whole invocation: the +// last turn's reason, unless the agent's result overrides it with a non-empty +// one. Shared by the synchronous-completion and detach-finalize paths. +func (s *SessionRunner[State]) invocationReason(result *AgentResult) AgentFinishReason { + if result != nil && result.FinishReason != "" { + return result.FinishReason + } + return s.lastTurnFinishReason +} + // maybeSnapshot creates a snapshot if conditions are met (store configured, // callback approves, state changed, detach has not suspended snapshots). -// Returns the snapshot ID or empty string. +// Returns the snapshot ID or empty string. finishReason is recorded on the +// snapshot so a resumed or background task can report how the captured turn +// or invocation ended. // // For turn-end events, the session asks the intake whether snapshots // have been suspended (i.e. detach has landed). If so, the session skips // the turn-end snapshot — the pending row already captures the // invocation and a single finalize rewrite will record the cumulative // state once the queued inputs drain. -func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { +func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent, finishReason AgentFinishReason) string { if event == SnapshotEventTurnEnd && s.intake != nil { if suspended := s.intake.beginTurnEnd(); suspended { return "" @@ -156,11 +203,16 @@ func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event Snapshot currentState := s.copyStateLocked() s.mu.RUnlock() - // Skip if state hasn't changed since the last snapshot. This avoids - // redundant snapshots, e.g. the invocation-end snapshot after a - // single-turn Run where the turn-end snapshot already captured the - // same state. - if s.lastSnapshot != nil && currentVersion == s.lastSnapshotVersion { + // Skip only if this snapshot would be identical to the last one: same + // state AND same finish reason. This dedups the common invocation-end + // snapshot after a single-turn Run (the turn-end snapshot already + // captured the same state and reason), but still writes when the + // invocation reports a different reason than the last turn (e.g. a + // custom agent overrode it on its AgentResult) — that snapshot is not + // redundant, it carries a new reason. + if s.lastSnapshot != nil && + currentVersion == s.lastSnapshotVersion && + finishReason == s.lastSnapshot.FinishReason { return "" } @@ -184,10 +236,11 @@ func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event Snapshot saved, err := s.store.SaveSnapshot(ctx, "", func(_ *SessionSnapshot[State]) (*SessionSnapshot[State], error) { return &SessionSnapshot[State]{ - ParentID: parentID, - Event: event, - Status: SnapshotStatusSucceeded, - State: ¤tState, + ParentID: parentID, + Event: event, + Status: SnapshotStatusSucceeded, + FinishReason: finishReason, + State: ¤tState, }, nil }) if err != nil { @@ -439,9 +492,11 @@ func newAgentRuntime[Stream, State any]( // a turn-end snapshot (if applicable) and forwards the resulting [TurnEnd] // chunk through the router so clients see it on the output stream. func (rt *agentRuntime[Stream, State]) emitTurnEnd(ctx context.Context) { - snapshotID := rt.sess.maybeSnapshot(ctx, SnapshotEventTurnEnd) + reason := rt.sess.lastTurnFinishReason + snapshotID := rt.sess.maybeSnapshot(ctx, SnapshotEventTurnEnd, reason) rt.router.sendChunk(ctx, &AgentStreamChunk[Stream]{TurnEnd: &TurnEnd{ - SnapshotID: snapshotID, + SnapshotID: snapshotID, + FinishReason: reason, }}) } @@ -573,14 +628,19 @@ func (rt *agentRuntime[Stream, State]) handleFnDone( return nil, res.err } - snapshotID := rt.sess.maybeSnapshot(ctx, SnapshotEventInvocationEnd) + invocationReason := rt.sess.invocationReason(res.result) + snapshotID := rt.sess.maybeSnapshot(ctx, SnapshotEventInvocationEnd, invocationReason) if snapshotID == "" && rt.sess.lastSnapshot != nil { - // State unchanged since the last turn-end snapshot — reuse it so - // the response always carries an ID when a store is configured. + // No new row was written; reuse the last snapshot so the response + // always carries an ID when a store is configured. On the dedup path + // the reused row is genuinely identical (same state and reason). If + // the snapshot callback declined the write or the save failed, the + // reused row is the last turn-end snapshot, whose reason (and state) + // may lag what this output reports. snapshotID = rt.sess.lastSnapshot.SnapshotID } - out := &AgentOutput[State]{SnapshotID: snapshotID} + out := &AgentOutput[State]{SnapshotID: snapshotID, FinishReason: invocationReason} if res.result != nil { // Deep-copy at the framework boundary so the caller cannot // mutate session contents through the returned output, even @@ -657,11 +717,17 @@ func (rt *agentRuntime[Stream, State]) handleDetach( stopSub() rt.intake.stopAndWait() rt.router.close() - rt.finalizePendingSnapshot(finalizeCtx, pending, res.err, abortedByUser.Load()) + rt.finalizePendingSnapshot(finalizeCtx, pending, res.result, res.err, abortedByUser.Load()) cancelWork() }() - return &AgentOutput[State]{SnapshotID: pending.SnapshotID}, nil + // The invocation, from the client's perspective, ended by detaching. The + // pending snapshot is finalized later with how the background work + // actually ended (see finalizePendingSnapshot). + return &AgentOutput[State]{ + SnapshotID: pending.SnapshotID, + FinishReason: AgentFinishReasonDetached, + }, nil } // finalizePendingSnapshot rewrites the pending snapshot row with the @@ -674,37 +740,58 @@ func (rt *agentRuntime[Stream, State]) handleDetach( func (rt *agentRuntime[Stream, State]) finalizePendingSnapshot( ctx context.Context, pending *SessionSnapshot[State], + result *AgentResult, fnErr error, abortedByUser bool, ) { finalState := *rt.session.State() + // Captured outside the SaveSnapshot callback (which must stay pure): the + // finalizer runs after fn returned, so this is stable. The abort/error + // branches below own their reasons and ignore this clean-success default. + succeededReason := rt.sess.invocationReason(result) _, err := rt.cfg.store.SaveSnapshot(ctx, pending.SnapshotID, func(existing *SessionSnapshot[State]) (*SessionSnapshot[State], error) { - // Late abort wins over the terminal we were about to land. + // Late abort wins over the terminal we were about to land: keep + // the aborted status and whatever state the abort left, but + // stamp the aborted finish reason so the snapshot is + // self-describing. (AbortSnapshot only flips status; the runtime + // owns the semantic reason.) Skip the write once already stamped. if existing != nil && existing.Status == SnapshotStatusAborted { - return nil, nil + if existing.FinishReason == AgentFinishReasonAborted { + return nil, nil + } + annotated := *existing + annotated.FinishReason = AgentFinishReasonAborted + return &annotated, nil } status := SnapshotStatusSucceeded + // The persisted finish reason records how the background work + // actually ended, distinct from the detached reason the client + // already saw on AgentOutput. + finishReason := succeededReason var snapErr *core.GenkitError switch { case abortedByUser: status = SnapshotStatusAborted + finishReason = AgentFinishReasonAborted if fnErr != nil { snapErr = core.AsGenkitError(fnErr) // aborted wins, preserve text } case fnErr != nil: status = SnapshotStatusFailed + finishReason = AgentFinishReasonFailed snapErr = core.AsGenkitError(fnErr) } return &SessionSnapshot[State]{ - ParentID: pending.ParentID, - Event: SnapshotEventDetach, - Status: status, - Error: snapErr, - State: &finalState, + ParentID: pending.ParentID, + Event: SnapshotEventDetach, + Status: status, + FinishReason: finishReason, + Error: snapErr, + State: &finalState, }, nil }) if err != nil { @@ -1218,14 +1305,14 @@ func validateUserMessage(m *ai.Message) error { // input. func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) AgentFunc[any, State] { return func(ctx context.Context, resp Responder[any], sess *SessionRunner[State]) (*AgentResult, error) { - if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { if err := validateUserMessage(input.Message); err != nil { - return err + return nil, err } actionOpts, err := prompt.Render(ctx, defaultInput) if err != nil { - return fmt.Errorf("prompt render: %w", err) + return nil, fmt.Errorf("prompt render: %w", err) } // Tag base messages so they can be filtered out of session @@ -1257,7 +1344,7 @@ func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) Ag }, ) if err != nil { - return fmt.Errorf("generate: %w", err) + return nil, fmt.Errorf("generate: %w", err) } // Replace session messages with the full history minus base @@ -1288,7 +1375,11 @@ func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) Ag } } - return nil + // Forward the generate response's finish reason verbatim: the + // agent enum is a superset of the model enum for these values, + // so the turn (and a single-turn invocation) reports e.g. + // "interrupted" without the client scanning message content. + return &TurnResult{FinishReason: AgentFinishReason(modelResp.FinishReason)}, nil }); err != nil { return nil, err } diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index ea2e269339..b74241e130 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -50,7 +50,7 @@ func TestAgent_BasicMultiTurn(t *testing.T) { af := DefineCustomAgent(reg, "basicFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { resp.SendStatus(testStatus{Phase: "generating"}) // Echo back the user's message. if input.Message != nil { @@ -62,7 +62,7 @@ func TestAgent_BasicMultiTurn(t *testing.T) { return s }) resp.SendStatus(testStatus{Phase: "complete"}) - return nil + return nil, nil }) }, ) @@ -126,7 +126,7 @@ func TestAgent_WithSessionStore(t *testing.T) { af := DefineCustomAgent(reg, "snapshotFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -134,7 +134,7 @@ func TestAgent_WithSessionStore(t *testing.T) { s.Counter++ return s }) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -196,7 +196,7 @@ func TestAgent_ResumeFromSnapshot(t *testing.T) { af := DefineCustomAgent(reg, "resumeFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -204,7 +204,7 @@ func TestAgent_ResumeFromSnapshot(t *testing.T) { s.Counter++ return s }) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -285,7 +285,7 @@ func TestAgent_ClientManagedState(t *testing.T) { af := DefineCustomAgent(reg, "clientStateFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("reply")) } @@ -293,7 +293,7 @@ func TestAgent_ClientManagedState(t *testing.T) { s.Counter++ return s }) - return nil + return nil, nil }) }, ) @@ -348,7 +348,7 @@ func TestAgent_Artifacts(t *testing.T) { af := DefineCustomAgent(reg, "artifactFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { resp.SendArtifact(&Artifact{ Name: "code.go", @@ -368,7 +368,7 @@ func TestAgent_Artifacts(t *testing.T) { }) sess.AddMessages(ai.NewModelTextMessage("done")) - return nil + return nil, nil }) if err != nil { return nil, err @@ -421,13 +421,13 @@ func TestAgent_SnapshotCallback(t *testing.T) { callbackCalls := 0 af := DefineCustomAgent(reg, "callbackFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ return s }) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -477,8 +477,8 @@ func TestAgent_SendMessage(t *testing.T) { af := DefineCustomAgent(reg, "sendMsgFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { - return nil + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + return nil, nil }) }, ) @@ -521,19 +521,19 @@ func TestAgent_SessionContext(t *testing.T) { var retrievedCounter int af := DefineCustomAgent(reg, "contextFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { // Session should be retrievable from context. ctxSess := SessionFromContext[testState](ctx) if ctxSess == nil { t.Error("expected session from context") - return nil + return nil, nil } ctxSess.UpdateCustom(func(s testState) testState { s.Counter = 42 return s }) retrievedCounter = ctxSess.Custom().Counter - return nil + return nil, nil }) }, ) @@ -566,8 +566,8 @@ func TestAgent_ErrorInTurn(t *testing.T) { af := DefineCustomAgent(reg, "errorFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { - return fmt.Errorf("turn failed") + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + return nil, fmt.Errorf("turn failed") }) }, ) @@ -592,10 +592,10 @@ func TestAgent_SetMessages(t *testing.T) { af := DefineCustomAgent(reg, "setMsgsFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { // Replace all messages with just one. sess.SetMessages([]*ai.Message{ai.NewModelTextMessage("replaced")}) - return nil + return nil, nil }) }, ) @@ -643,7 +643,7 @@ func TestAgent_TurnSpanOutput(t *testing.T) { return output } - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { resp.SendStatus(testStatus{Phase: "thinking"}) resp.SendModelChunk(&ai.ModelResponseChunk{ Content: []*ai.Part{ai.NewTextPart("reply")}, @@ -653,7 +653,7 @@ func TestAgent_TurnSpanOutput(t *testing.T) { Parts: []*ai.Part{ai.NewTextPart("content")}, }) sess.AddMessages(ai.NewModelTextMessage("reply")) - return nil + return nil, nil }) }, ) @@ -721,10 +721,10 @@ func TestAgent_TurnSpanOutput_WithSnapshots(t *testing.T) { return output } - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { resp.SendStatus(testStatus{Phase: "working"}) sess.AddMessages(ai.NewModelTextMessage("reply")) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -1203,7 +1203,7 @@ func TestAgent_RunText(t *testing.T) { af := DefineCustomAgent(reg, "runTextFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("echo: " + input.Message.Content[0].Text)) } @@ -1211,7 +1211,7 @@ func TestAgent_RunText(t *testing.T) { s.Counter++ return s }) - return nil + return nil, nil }) }, ) @@ -1236,11 +1236,11 @@ func TestAgent_Run(t *testing.T) { af := DefineCustomAgent(reg, "runFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("reply")) } - return nil + return nil, nil }) }, ) @@ -1266,13 +1266,13 @@ func TestAgent_RunText_WithState(t *testing.T) { af := DefineCustomAgent(reg, "runStateFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ return s }) - return nil + return nil, nil }) }, ) @@ -1307,13 +1307,13 @@ func TestAgent_RunText_WithSnapshot(t *testing.T) { af := DefineCustomAgent(reg, "runSnapshotFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ return s }) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -1447,13 +1447,13 @@ func TestAgent_SingleTurnSnapshotDedup(t *testing.T) { af := DefineCustomAgent(reg, "dedupFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ return s }) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -1491,13 +1491,13 @@ func TestAgent_MultiTurnSnapshotDedup(t *testing.T) { af := DefineCustomAgent(reg, "multiDedupFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ return s }) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -1553,9 +1553,9 @@ func TestAgent_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { af := DefineCustomAgent(reg, "postRunMutateFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) - return nil + return nil, nil }); err != nil { return nil, err } @@ -1606,7 +1606,7 @@ func TestAgent_FnPanicReturnsError(t *testing.T) { af := DefineCustomAgent(reg, "panicFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { resp.SendStatus(testStatus{Phase: "before-panic"}) panic("boom") }) @@ -1661,7 +1661,7 @@ func TestAgent_CancelDuringStreamReleasesGoroutine(t *testing.T) { af := DefineCustomAgent(reg, "cancelFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { defer close(fnDone) - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { close(emitting) // Emit until ctx cancels. Without the goroutine's // ctx-aware drain, this would deadlock once the consumer @@ -1669,7 +1669,7 @@ func TestAgent_CancelDuringStreamReleasesGoroutine(t *testing.T) { for { select { case <-ctx.Done(): - return ctx.Err() + return nil, ctx.Err() default: } resp.SendStatus(testStatus{Phase: "tick"}) @@ -1724,6 +1724,25 @@ func waitForSnapshot[State any]( return nil } +// nextTurnEnd consumes conn's stream until the next TurnEnd chunk and returns +// a copy of it, failing the test if the stream errors or ends first. Use it +// in tests that only need to advance to a turn boundary; tests that must +// inspect intermediate chunks should range over Receive directly. +func nextTurnEnd[Stream, State any](t *testing.T, conn *AgentConnection[Stream, State]) *TurnEnd { + t.Helper() + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive: %v", err) + } + if chunk.TurnEnd != nil { + te := *chunk.TurnEnd + return &te + } + } + t.Fatal("no TurnEnd chunk observed") + return nil +} + func TestAgent_TurnEnd_CarriesSnapshotID(t *testing.T) { // Sanity: each TurnEnd chunk carries the snapshot ID of the turn-end // snapshot, and the snapshots themselves are persisted. @@ -1732,9 +1751,9 @@ func TestAgent_TurnEnd_CarriesSnapshotID(t *testing.T) { af := DefineCustomAgent(reg, "turnEndSnapshotFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -1798,7 +1817,7 @@ func TestAgent_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { af := DefineCustomAgent(reg, "detachInFlight", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { entered <- struct{}{} <-release sess.AddMessages(ai.NewModelTextMessage("reply-" + input.Message.Text())) @@ -1806,7 +1825,7 @@ func TestAgent_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { s.Counter++ return s }) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -1897,11 +1916,11 @@ func TestAgent_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { af := DefineCustomAgent(reg, "detachChainParent", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { enter <- struct{}{} <-release sess.AddMessages(ai.NewModelTextMessage("ok")) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -1975,8 +1994,8 @@ func TestAgent_Detach_RequiresStore(t *testing.T) { af := DefineCustomAgent(reg, "detachNoStore", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { - return nil + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + return nil, nil }) }, ) @@ -2010,7 +2029,7 @@ func TestAgent_Detach_PendingThenComplete(t *testing.T) { af := DefineCustomAgent(reg, "detachComplete", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: case <-ctx.Done(): @@ -2021,7 +2040,7 @@ func TestAgent_Detach_PendingThenComplete(t *testing.T) { s.Counter = 42 return s }) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -2105,7 +2124,7 @@ func TestAgent_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { af := DefineCustomAgent(reg, "detachArtifact", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { resp.SendArtifact(&Artifact{ Name: "before.txt", Parts: []*ai.Part{ai.NewTextPart("pre-detach")}, @@ -2113,14 +2132,14 @@ func TestAgent_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { select { case <-detached: case <-ctx.Done(): - return ctx.Err() + return nil, ctx.Err() } resp.SendArtifact(&Artifact{ Name: "after.txt", Parts: []*ai.Part{ai.NewTextPart("post-detach")}, }) <-release - return nil + return nil, nil }) }, WithSessionStore(store), @@ -2182,13 +2201,13 @@ func TestAgent_Detach_FlowErrorsBecomesError(t *testing.T) { af := DefineCustomAgent(reg, "detachErr", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: case <-time.After(time.Second): } <-release - return boom + return nil, boom }) }, WithSessionStore(store), @@ -2252,13 +2271,13 @@ func TestAgent_Detach_AbortSnapshotStopsFlow(t *testing.T) { af := DefineCustomAgent(reg, "detachAbort", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: case <-time.After(time.Second): } <-ctx.Done() - return ctx.Err() + return nil, ctx.Err() }) }, WithSessionStore(store), @@ -2323,9 +2342,9 @@ func TestAgent_Detach_NormalCompletionStillEmitsTurnEnd(t *testing.T) { af := DefineCustomAgent(reg, "syncStillWorks", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -2381,13 +2400,13 @@ func TestAgent_Detach_ClientDisconnectBeforeDetachCancels(t *testing.T) { af := DefineCustomAgent(reg, "syncCancel", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: case <-ctx.Done(): } <-ctx.Done() - return ctx.Err() + return nil, ctx.Err() }) exited <- err return nil, err @@ -2478,9 +2497,9 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { af := DefineCustomAgent(reg, "transformedFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("the secret is out")) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -2541,6 +2560,44 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { } } +// TestAgent_GetSnapshotAction_ReturnsFinishReason verifies the remote +// getSnapshot companion action surfaces the persisted finish reason, so a +// non-Go client or the Dev UI polling a detached/background invocation can +// report how it ended without re-deriving it from the messages. +func TestAgent_GetSnapshotAction_ReturnsFinishReason(t *testing.T) { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + af := DefineCustomAgent(reg, "finishReasonActionFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("done")) + return &TurnResult{FinishReason: AgentFinishReasonStop}, nil + }) + }, + WithSessionStore(store), + ) + + ctx := context.Background() + out, err := af.RunText(ctx, "hi") + if err != nil { + t.Fatalf("RunText: %v", err) + } + + action := core.ResolveActionFor[*GetSnapshotRequest, *GetSnapshotResponse[testState], struct{}, struct{}]( + reg, api.ActionTypeAgentSnapshot, "finishReasonActionFlow") + if action == nil { + t.Fatal("getSnapshot action not registered") + } + resp, err := action.Run(ctx, &GetSnapshotRequest{SnapshotID: out.SnapshotID}, nil) + if err != nil { + t.Fatalf("getSnapshot action: %v", err) + } + if resp.FinishReason != AgentFinishReasonStop { + t.Errorf("GetSnapshotResponse.FinishReason = %q, want %q", resp.FinishReason, AgentFinishReasonStop) + } +} + func TestInMemorySessionStore_GetSnapshot_NotFound(t *testing.T) { store := newTestInMemStore[testState]() @@ -2832,12 +2889,12 @@ func TestAgent_StateTransform_ClientManagedState(t *testing.T) { af := DefineCustomAgent(reg, "clientXformFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.UpdateCustom(func(s testState) testState { s.Counter = 7 return s }) - return nil + return nil, nil }) }, WithStateTransform[testState](transform), @@ -2863,13 +2920,13 @@ func TestAgent_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { af := DefineCustomAgent(reg, "resumeDetachedFlow", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { s.Counter++ return s }) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -3008,7 +3065,7 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { af := DefineCustomAgent(reg, "raceFinalize", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: case <-time.After(time.Second): @@ -3017,7 +3074,7 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { // Return cleanly without observing ctx. Without the // subscriber/recheck, this would land status=succeeded and // clobber the abort. - return nil + return nil, nil }) }, WithSessionStore(store), @@ -3140,9 +3197,9 @@ func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { af := DefineCustomAgent(reg, "abortNoop", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) - return nil + return nil, nil }) }, WithSessionStore(store), @@ -3191,13 +3248,13 @@ func TestAgent_ResultAndOutput_IsolatedFromSession(t *testing.T) { af := DefineCustomAgent(reg, "isolation", func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { - if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) error { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("session-msg")) sess.AddArtifacts(&Artifact{ Name: "orig", Parts: []*ai.Part{ai.NewTextPart("orig-part")}, }) - return nil + return nil, nil }); err != nil { return nil, err } @@ -3271,3 +3328,672 @@ func TestAgent_Name(t *testing.T) { t.Errorf("Name() = %q, want %q", got, "name-accessor") } } + +// --- Finish reasons --------------------------------------------------------- + +// TestAgent_FinishReason_TurnAndInvocation verifies that the reason a custom +// agent reports per turn (via TurnResult) rides the TurnEnd chunk, is +// persisted on the turn-end snapshot, and defaults the invocation's reason on +// AgentOutput. +func TestAgent_FinishReason_TurnAndInvocation(t *testing.T) { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + af := DefineCustomAgent(reg, "finishReasonFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("ok")) + return &TurnResult{FinishReason: AgentFinishReasonStop}, nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText: %v", err) + } + + turnEnd := nextTurnEnd(t, conn) + if turnEnd.FinishReason != AgentFinishReasonStop { + t.Errorf("TurnEnd.FinishReason = %q, want %q", turnEnd.FinishReason, AgentFinishReasonStop) + } + + // The turn-end snapshot records the reason. + snap, err := store.GetSnapshot(context.Background(), turnEnd.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap == nil { + t.Fatalf("turn-end snapshot %q missing", turnEnd.SnapshotID) + } + if snap.FinishReason != AgentFinishReasonStop { + t.Errorf("snapshot.FinishReason = %q, want %q", snap.FinishReason, AgentFinishReasonStop) + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonStop { + t.Errorf("AgentOutput.FinishReason = %q, want %q (defaulted from last turn)", out.FinishReason, AgentFinishReasonStop) + } +} + +// TestAgent_FinishReason_OmittedWhenNil verifies that returning a nil +// TurnResult performs no implicit inference: the reason is omitted on both the +// turn-end signal and the invocation output. +func TestAgent_FinishReason_OmittedWhenNil(t *testing.T) { + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "noReasonFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("ok")) + return nil, nil + }) + }, + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText: %v", err) + } + turnEnd := nextTurnEnd(t, conn) + if turnEnd.FinishReason != "" { + t.Errorf("TurnEnd.FinishReason = %q, want empty", turnEnd.FinishReason) + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != "" { + t.Errorf("AgentOutput.FinishReason = %q, want empty", out.FinishReason) + } +} + +// TestAgent_FinishReason_InvocationOverride verifies that a custom agent can +// override the invocation's finish reason via AgentResult.FinishReason without +// affecting the per-turn reason on TurnEnd. +func TestAgent_FinishReason_InvocationOverride(t *testing.T) { + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "overrideReasonFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("ok")) + return &TurnResult{FinishReason: AgentFinishReasonStop}, nil + }); err != nil { + return nil, err + } + res := sess.Result() + res.FinishReason = AgentFinishReasonOther + return res, nil + }, + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText: %v", err) + } + turnEnd := nextTurnEnd(t, conn) + if turnEnd.FinishReason != AgentFinishReasonStop { + t.Errorf("TurnEnd.FinishReason = %q, want %q (per-turn, unaffected by override)", turnEnd.FinishReason, AgentFinishReasonStop) + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonOther { + t.Errorf("AgentOutput.FinishReason = %q, want %q (override)", out.FinishReason, AgentFinishReasonOther) + } +} + +// TestAgent_FinishReason_MultiTurnDistinct verifies that each turn's TurnEnd +// carries that turn's own reason, and the invocation defaults to the last +// turn's reason. +func TestAgent_FinishReason_MultiTurnDistinct(t *testing.T) { + reg := newTestRegistry(t) + + // Turn 0 reports "stop"; turn 1 reports "interrupted". + reasons := []AgentFinishReason{AgentFinishReasonStop, AgentFinishReasonInterrupted} + + af := DefineCustomAgent(reg, "multiReasonFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("ok")) + return &TurnResult{FinishReason: reasons[sess.TurnIndex]}, nil + }) + }, + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + + var got []AgentFinishReason + for i := 0; i < len(reasons); i++ { + if err := conn.SendText("turn"); err != nil { + t.Fatalf("SendText: %v", err) + } + got = append(got, nextTurnEnd(t, conn).FinishReason) + } + for i, want := range reasons { + if got[i] != want { + t.Errorf("turn %d TurnEnd.FinishReason = %q, want %q", i, got[i], want) + } + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonInterrupted { + t.Errorf("AgentOutput.FinishReason = %q, want %q (last turn)", out.FinishReason, AgentFinishReasonInterrupted) + } +} + +// TestPromptAgent_ForwardsFinishReason verifies that a prompt-backed agent +// forwards the underlying generate response's finish reason automatically. +func TestPromptAgent_ForwardsFinishReason(t *testing.T) { + ctx := context.Background() + reg := registry.New() + ai.ConfigureFormats(reg) + ai.DefineModel(reg, "test/length", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return &ai.ModelResponse{ + Request: req, + Message: ai.NewModelTextMessage("partial"), + FinishReason: ai.FinishReasonLength, + }, nil + }, + ) + ai.DefineGenerateAction(ctx, reg) + ai.DefinePrompt(reg, "lengthPrompt", ai.WithModelName("test/length")) + + af := DefineAgent[testState](reg, "lengthPrompt", FromPrompt()) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText: %v", err) + } + turnEnd := nextTurnEnd(t, conn) + if turnEnd.FinishReason != AgentFinishReasonLength { + t.Errorf("TurnEnd.FinishReason = %q, want %q", turnEnd.FinishReason, AgentFinishReasonLength) + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonLength { + t.Errorf("AgentOutput.FinishReason = %q, want %q", out.FinishReason, AgentFinishReasonLength) + } +} + +// TestAgent_Detach_FinishReasons covers the three detach outcomes: the output +// returned to the detaching client always reports "detached", while the +// persisted snapshot records how the background work actually ended +// (succeeded -> last turn's reason, failed, or aborted). +func TestAgent_Detach_FinishReasons(t *testing.T) { + t.Run("complete", func(t *testing.T) { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + release := make(chan struct{}) + entered := make(chan struct{}) + + af := DefineCustomAgent(reg, "detachReasonComplete", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + select { + case entered <- struct{}{}: + case <-ctx.Done(): + } + <-release + sess.AddMessages(ai.NewModelTextMessage("done")) + return &TurnResult{FinishReason: AgentFinishReasonStop}, nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + if err := conn.SendText("go"); err != nil { + t.Fatalf("SendText: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + <-entered + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonDetached { + t.Errorf("AgentOutput.FinishReason = %q, want %q", out.FinishReason, AgentFinishReasonDetached) + } + + close(release) + snap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusSucceeded + }) + if snap.FinishReason != AgentFinishReasonStop { + t.Errorf("finalized snapshot.FinishReason = %q, want %q", snap.FinishReason, AgentFinishReasonStop) + } + }) + + t.Run("failed", func(t *testing.T) { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + release := make(chan struct{}) + entered := make(chan struct{}) + + af := DefineCustomAgent(reg, "detachReasonFailed", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + select { + case entered <- struct{}{}: + case <-time.After(time.Second): + } + <-release + return nil, errors.New("kaboom") + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + if err := conn.SendText("go"); err != nil { + t.Fatalf("SendText: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + <-entered + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonDetached { + t.Errorf("AgentOutput.FinishReason = %q, want %q", out.FinishReason, AgentFinishReasonDetached) + } + + close(release) + snap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusFailed + }) + if snap.FinishReason != AgentFinishReasonFailed { + t.Errorf("finalized snapshot.FinishReason = %q, want %q", snap.FinishReason, AgentFinishReasonFailed) + } + }) + + t.Run("aborted", func(t *testing.T) { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + entered := make(chan struct{}) + + af := DefineCustomAgent(reg, "detachReasonAborted", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + select { + case entered <- struct{}{}: + case <-time.After(time.Second): + } + <-ctx.Done() + return nil, ctx.Err() + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + if err := conn.SendText("go"); err != nil { + t.Fatalf("SendText: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + <-entered + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if _, err := store.AbortSnapshot(context.Background(), out.SnapshotID); err != nil { + t.Fatalf("AbortSnapshot: %v", err) + } + // AbortSnapshot flips status=aborted (finishReason still empty); the + // finalizer then annotates the row with finishReason=aborted. Wait + // for that second write rather than the bare status flip. + snap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusAborted && s.FinishReason == AgentFinishReasonAborted + }) + if snap.Status != SnapshotStatusAborted { + t.Errorf("finalized snapshot.Status = %q, want %q", snap.Status, SnapshotStatusAborted) + } + }) +} + +// TestAgent_FinishReason_InvocationOverride_Persisted verifies that when a +// custom agent overrides the invocation reason (differing from the last +// turn's), the dedup does not collapse it onto the turn-end snapshot: a +// distinct invocation-end snapshot is written carrying the override, so the +// snapshot AgentOutput points at agrees with AgentOutput.FinishReason. +func TestAgent_FinishReason_InvocationOverride_Persisted(t *testing.T) { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + af := DefineCustomAgent(reg, "overridePersistedFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("ok")) + return &TurnResult{FinishReason: AgentFinishReasonStop}, nil + }); err != nil { + return nil, err + } + res := sess.Result() + res.FinishReason = AgentFinishReasonOther + return res, nil + }, + WithSessionStore(store), + ) + + ctx := context.Background() + out, err := af.RunText(ctx, "hi") + if err != nil { + t.Fatalf("RunText: %v", err) + } + if out.FinishReason != AgentFinishReasonOther { + t.Fatalf("AgentOutput.FinishReason = %q, want %q", out.FinishReason, AgentFinishReasonOther) + } + if out.SnapshotID == "" { + t.Fatal("expected a snapshot ID") + } + snap, err := store.GetSnapshot(ctx, out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap.FinishReason != out.FinishReason { + t.Errorf("persisted snapshot.FinishReason = %q, want %q (must agree with AgentOutput)", snap.FinishReason, out.FinishReason) + } + // The override must be a fresh invocation-end snapshot, not the turn-end + // row mutated or reused: the divergent reason busts the dedup. + if snap.Event != SnapshotEventInvocationEnd { + t.Errorf("snapshot.Event = %q, want %q (a distinct invocation-end snapshot)", snap.Event, SnapshotEventInvocationEnd) + } +} + +// TestAgent_FinishReason_MultiTurnDistinct_Persisted verifies each turn-end +// snapshot persists that turn's own reason and chains to its parent. +func TestAgent_FinishReason_MultiTurnDistinct_Persisted(t *testing.T) { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + reasons := []AgentFinishReason{AgentFinishReasonStop, AgentFinishReasonInterrupted} + + af := DefineCustomAgent(reg, "multiReasonPersistedFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("ok")) + return &TurnResult{FinishReason: reasons[sess.TurnIndex]}, nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + var ids []string + for i := 0; i < len(reasons); i++ { + if err := conn.SendText("turn"); err != nil { + t.Fatalf("SendText: %v", err) + } + te := nextTurnEnd(t, conn) + if te.FinishReason != reasons[i] { + t.Errorf("turn %d TurnEnd.FinishReason = %q, want %q", i, te.FinishReason, reasons[i]) + } + ids = append(ids, te.SnapshotID) + } + if _, err := conn.Output(); err != nil { + t.Fatalf("Output: %v", err) + } + + for i, id := range ids { + snap, err := store.GetSnapshot(context.Background(), id) + if err != nil { + t.Fatalf("GetSnapshot[%d]: %v", i, err) + } + if snap.FinishReason != reasons[i] { + t.Errorf("snapshot[%d].FinishReason = %q, want %q", i, snap.FinishReason, reasons[i]) + } + } + // The second turn's snapshot chains to the first. + snap1, _ := store.GetSnapshot(context.Background(), ids[1]) + if snap1.ParentID != ids[0] { + t.Errorf("snapshot[1].ParentID = %q, want %q", snap1.ParentID, ids[0]) + } +} + +// TestAgent_FinishReason_OmittedPersisted verifies a turn that reports no +// reason persists an empty finishReason (no implicit inference at rest). +func TestAgent_FinishReason_OmittedPersisted(t *testing.T) { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + af := DefineCustomAgent(reg, "noReasonPersistedFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("ok")) + return nil, nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText: %v", err) + } + snapID := nextTurnEnd(t, conn).SnapshotID + if _, err := conn.Output(); err != nil { + t.Fatalf("Output: %v", err) + } + snap, err := store.GetSnapshot(context.Background(), snapID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap.FinishReason != "" { + t.Errorf("snapshot.FinishReason = %q, want empty", snap.FinishReason) + } +} + +// TestPromptAgent_ForwardsInterruptedFinishReason drives a real interrupted +// generate response through the prompt-backed loop: the turn must stream the +// interrupt parts AND report finishReason=interrupted (the proposal's +// motivating case), without the client scanning message content. +func TestPromptAgent_ForwardsInterruptedFinishReason(t *testing.T) { + ctx := context.Background() + reg := registry.New() + ai.ConfigureFormats(reg) + + interruptTool := ai.DefineTool(reg, "interruptor", "always interrupts", + func(tc *ai.ToolContext, input any) (any, error) { + return nil, tc.Interrupt(&ai.InterruptOptions{ + Metadata: map[string]any{"reason": "needs approval"}, + }) + }, + ) + ai.DefineModel(reg, "test/interrupt", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true, Tools: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return &ai.ModelResponse{ + Request: req, + Message: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewToolRequestPart(&ai.ToolRequest{Name: "interruptor"})}, + }, + }, nil + }) + ai.DefineGenerateAction(ctx, reg) + ai.DefinePrompt(reg, "interruptPrompt", + ai.WithModelName("test/interrupt"), + ai.WithTools(interruptTool), + ) + + af := DefineAgent[testState](reg, "interruptPrompt", FromPrompt()) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + if err := conn.SendText("do it"); err != nil { + t.Fatalf("SendText: %v", err) + } + var ( + turnEnd *TurnEnd + gotToolChunk bool + ) + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive: %v", err) + } + if chunk.ModelChunk != nil && chunk.ModelChunk.Role == ai.RoleTool { + gotToolChunk = true + } + if chunk.TurnEnd != nil { + te := *chunk.TurnEnd + turnEnd = &te + break + } + } + if !gotToolChunk { + t.Error("expected a tool-role chunk carrying the interrupt parts") + } + if turnEnd == nil || turnEnd.FinishReason != AgentFinishReasonInterrupted { + t.Fatalf("TurnEnd.FinishReason = %v, want %q", turnEnd, AgentFinishReasonInterrupted) + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonInterrupted { + t.Errorf("AgentOutput.FinishReason = %q, want %q", out.FinishReason, AgentFinishReasonInterrupted) + } +} + +// TestAgent_Detach_SucceededHonorsResultOverride verifies the detach finalizer +// applies an AgentResult.FinishReason override on clean success, matching the +// synchronous path (the override does not leak into the failed/aborted cases, +// which are covered by TestAgent_Detach_FinishReasons). +func TestAgent_Detach_SucceededHonorsResultOverride(t *testing.T) { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + release := make(chan struct{}) + entered := make(chan struct{}) + + af := DefineCustomAgent(reg, "detachOverride", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + select { + case entered <- struct{}{}: + case <-ctx.Done(): + } + <-release + sess.AddMessages(ai.NewModelTextMessage("done")) + return &TurnResult{FinishReason: AgentFinishReasonStop}, nil + }); err != nil { + return nil, err + } + res := sess.Result() + res.FinishReason = AgentFinishReasonOther + return res, nil + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + if err := conn.SendText("go"); err != nil { + t.Fatalf("SendText: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + <-entered + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonDetached { + t.Errorf("AgentOutput.FinishReason = %q, want %q", out.FinishReason, AgentFinishReasonDetached) + } + + close(release) + snap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusSucceeded + }) + if snap.FinishReason != AgentFinishReasonOther { + t.Errorf("finalized snapshot.FinishReason = %q, want %q (AgentResult override)", snap.FinishReason, AgentFinishReasonOther) + } +} diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index 59427361ba..261ce2c7c4 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -41,6 +41,48 @@ type AbortSnapshotResponse struct { Status SnapshotStatus `json:"status,omitempty"` } +// AgentFinishReason is why an agent turn or invocation finished. +// +// The first group mirrors the model-level [ai.FinishReason] so a turn backed +// by a single generate call can forward its reason verbatim: +// [AgentFinishReasonStop], [AgentFinishReasonLength], +// [AgentFinishReasonBlocked], [AgentFinishReasonInterrupted], +// [AgentFinishReasonOther] and [AgentFinishReasonUnknown]. +// +// The remaining values are agent-specific outcomes with no generate-level +// equivalent (they never arise from forwarding a model finish reason): +// [AgentFinishReasonAborted], [AgentFinishReasonDetached] and +// [AgentFinishReasonFailed]. +type AgentFinishReason string + +const ( + // AgentFinishReasonStop indicates the model stopped naturally. + AgentFinishReasonStop AgentFinishReason = "stop" + // AgentFinishReasonLength indicates generation hit the token limit. + AgentFinishReasonLength AgentFinishReason = "length" + // AgentFinishReasonBlocked indicates generation was blocked (e.g. safety). + AgentFinishReasonBlocked AgentFinishReason = "blocked" + // AgentFinishReasonInterrupted indicates the model paused on a tool request + // awaiting input (e.g. human approval). The turn can be resumed by sending an + // [AgentInput] with a Resume payload. + AgentFinishReasonInterrupted AgentFinishReason = "interrupted" + // AgentFinishReasonOther indicates generation stopped for some other reason. + AgentFinishReasonOther AgentFinishReason = "other" + // AgentFinishReasonUnknown indicates the reason was unspecified. + AgentFinishReasonUnknown AgentFinishReason = "unknown" + // AgentFinishReasonAborted indicates the turn or invocation was aborted + // (e.g. a detached invocation aborted via the abortSnapshot companion action). + AgentFinishReasonAborted AgentFinishReason = "aborted" + // AgentFinishReasonDetached indicates the invocation was moved to the + // background because the client detached. The returned [AgentOutput] reports + // this reason; the persisted snapshot is later finalized with how the + // background work actually ended. + AgentFinishReasonDetached AgentFinishReason = "detached" + // AgentFinishReasonFailed indicates the turn or invocation terminated with an + // error. + AgentFinishReasonFailed AgentFinishReason = "failed" +) + // AgentInit is the input for starting an agent invocation. // Exactly one of SnapshotID or State may be set, and the choice must match // the agent's state management: @@ -108,6 +150,11 @@ type AgentMetadata struct { type AgentOutput[State any] struct { // Artifacts contains artifacts produced during the session. Artifacts []*Artifact `json:"artifacts,omitempty"` + // FinishReason is why the invocation finished. It is + // [AgentFinishReasonDetached] when the client detached and the work continues + // in the background; otherwise it is the last turn's reason (or the value a + // custom agent set on its [AgentResult]). + FinishReason AgentFinishReason `json:"finishReason,omitempty"` // Message is the last model response message from the conversation. Message *ai.Message `json:"message,omitempty"` // SnapshotID is the ID of the snapshot created at the end of this invocation. @@ -123,6 +170,10 @@ type AgentOutput[State any] struct { type AgentResult struct { // Artifacts contains artifacts produced during the session. Artifacts []*Artifact `json:"artifacts,omitempty"` + // FinishReason is why the invocation finished. A custom agent may set it to + // override the default (the last turn's reason); leave it empty to accept the + // default. + FinishReason AgentFinishReason `json:"finishReason,omitempty"` // Message is the last model response message from the conversation. Message *ai.Message `json:"message,omitempty"` } @@ -193,6 +244,12 @@ type GetSnapshotResponse[State any] struct { // Error is the structured failure information; populated when Status // is [SnapshotStatusFailed]. Error *core.GenkitError `json:"error,omitempty"` + // FinishReason is the semantic reason the captured turn or invocation ended + // (e.g. [AgentFinishReasonStop], [AgentFinishReasonInterrupted], + // [AgentFinishReasonFailed], [AgentFinishReasonAborted]). It lets a remote or + // background poller report how a detached or resumed invocation ended without + // re-deriving it. Empty on a pending snapshot. + FinishReason AgentFinishReason `json:"finishReason,omitempty"` // SnapshotID echoes the requested snapshot ID. SnapshotID string `json:"snapshotId"` // State is the session state captured by the snapshot, after any @@ -215,6 +272,19 @@ type SessionSnapshot[State any] struct { Error *core.GenkitError `json:"error,omitempty"` // Event is what triggered this snapshot. Event SnapshotEvent `json:"event"` + // FinishReason is the semantic reason the turn or invocation captured here + // ended (e.g. [AgentFinishReasonStop], [AgentFinishReasonInterrupted], + // [AgentFinishReasonFailed], [AgentFinishReasonAborted]). It complements + // [SessionSnapshot.Status] (the persistence lifecycle) so a resumed or + // background task can report how it ended without re-deriving it from the + // messages. + // + // On an aborted snapshot this is best-effort and may briefly lag Status: + // AbortSnapshot flips Status to [SnapshotStatusAborted] immediately, while the + // [AgentFinishReasonAborted] finish reason is stamped by a subsequent finalizer + // write. If the process running the invocation never finalizes, the reason can + // remain empty even though Status is [SnapshotStatusAborted]. + FinishReason AgentFinishReason `json:"finishReason,omitempty"` // ParentID is the ID of the previous snapshot in this timeline. ParentID string `json:"parentId,omitempty"` // SnapshotID is the unique identifier for this snapshot (UUID). @@ -294,6 +364,11 @@ const ( // A TurnEnd value is emitted exactly once per turn, regardless of whether a // snapshot was persisted. type TurnEnd struct { + // FinishReason is why this turn finished (e.g. [AgentFinishReasonStop], + // [AgentFinishReasonInterrupted]). It lets a caller react to a turn boundary + // (such as pausing on an interrupt) without scanning the message content. + // Empty when the turn reported no reason. + FinishReason AgentFinishReason `json:"finishReason,omitempty"` // SnapshotID is the ID of the snapshot persisted at the end of this turn. // Empty if no snapshot was created (callback returned false or no store // configured, or snapshots were suspended after detach). diff --git a/go/ai/exp/localstore/file_test.go b/go/ai/exp/localstore/file_test.go index 327efdf3a4..a5b5cd02cb 100644 --- a/go/ai/exp/localstore/file_test.go +++ b/go/ai/exp/localstore/file_test.go @@ -476,3 +476,40 @@ func TestFileSessionStore(t *testing.T) { } }) } + +// TestFileSessionStore_FinishReasonPersistsAcrossReopen verifies that a +// snapshot's finish reason survives the disk round-trip: a second store +// opened on the same directory (as after a process restart) reads it back. +func TestFileSessionStore_FinishReasonPersistsAcrossReopen(t *testing.T) { + dir := t.TempDir() + store, err := NewFileSessionStore[testState](dir) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + saved, err := store.SaveSnapshot(context.Background(), "", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{ + Status: exp.SnapshotStatusSucceeded, + FinishReason: exp.AgentFinishReasonInterrupted, + State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, + }, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + + reopened, err := NewFileSessionStore[testState](dir) + if err != nil { + t.Fatalf("reopen NewFileSessionStore: %v", err) + } + got, err := reopened.GetSnapshot(context.Background(), saved.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if got == nil { + t.Fatalf("snapshot %q missing after reopen", saved.SnapshotID) + } + if got.FinishReason != exp.AgentFinishReasonInterrupted { + t.Errorf("FinishReason = %q, want %q", got.FinishReason, exp.AgentFinishReasonInterrupted) + } +} diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 48ebdcf280..21a66d2dd6 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -234,11 +234,12 @@ func registerSnapshotActions[State any]( } resp := &GetSnapshotResponse[State]{ - SnapshotID: snap.SnapshotID, - CreatedAt: snap.CreatedAt, - UpdatedAt: updatedAt, - Status: status, - Error: snap.Error, + SnapshotID: snap.SnapshotID, + CreatedAt: snap.CreatedAt, + UpdatedAt: updatedAt, + Status: status, + FinishReason: snap.FinishReason, + Error: snap.Error, } if status != SnapshotStatusFailed && status != SnapshotStatusPending { resp.State = applyTransform(ctx, transform, snap.State) diff --git a/go/core/schemas.config b/go/core/schemas.config index 8e82cd3331..6eb306ad8a 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1304,6 +1304,12 @@ AgentResult.artifacts doc Artifacts contains artifacts produced during the session. . +AgentResult.finishReason doc +FinishReason is why the invocation finished. A custom agent may set it to +override the default (the last turn's reason); leave it empty to accept the +default. +. + # ---------------------------------------------------------------------------- # AgentOutput # ---------------------------------------------------------------------------- @@ -1335,6 +1341,13 @@ AgentOutput.artifacts doc Artifacts contains artifacts produced during the session. . +AgentOutput.finishReason doc +FinishReason is why the invocation finished. It is +[AgentFinishReasonDetached] when the client detached and the work continues +in the background; otherwise it is the last turn's reason (or the value a +custom agent set on its [AgentResult]). +. + # ---------------------------------------------------------------------------- # AgentStreamChunk # ---------------------------------------------------------------------------- @@ -1388,6 +1401,13 @@ Empty if no snapshot was created (callback returned false or no store configured, or snapshots were suspended after detach). . +TurnEnd.finishReason doc +FinishReason is why this turn finished (e.g. [AgentFinishReasonStop], +[AgentFinishReasonInterrupted]). It lets a caller react to a turn boundary +(such as pausing on an interrupt) without scanning the message content. +Empty when the turn reported no reason. +. + # ---------------------------------------------------------------------------- # SessionState # ---------------------------------------------------------------------------- @@ -1459,6 +1479,21 @@ Status is the lifecycle state of this snapshot. Empty is treated as [SnapshotStatusSucceeded] for backwards compatibility. . +SessionSnapshot.finishReason doc +FinishReason is the semantic reason the turn or invocation captured here +ended (e.g. [AgentFinishReasonStop], [AgentFinishReasonInterrupted], +[AgentFinishReasonFailed], [AgentFinishReasonAborted]). It complements +[SessionSnapshot.Status] (the persistence lifecycle) so a resumed or +background task can report how it ended without re-deriving it from the +messages. + +On an aborted snapshot this is best-effort and may briefly lag Status: +AbortSnapshot flips Status to [SnapshotStatusAborted] immediately, while the +[AgentFinishReasonAborted] finish reason is stamped by a subsequent finalizer +write. If the process running the invocation never finalizes, the reason can +remain empty even though Status is [SnapshotStatusAborted]. +. + SessionSnapshot.error type *core.GenkitError SessionSnapshot.error doc Error is the structured failure information for a snapshot in @@ -1545,6 +1580,70 @@ The snapshot's Error field describes the failure and resume is rejected with that same error. . +# ---------------------------------------------------------------------------- +# AgentFinishReason +# ---------------------------------------------------------------------------- + +AgentFinishReason pkg ai/exp + +AgentFinishReason doc +AgentFinishReason is why an agent turn or invocation finished. + +The first group mirrors the model-level [ai.FinishReason] so a turn backed +by a single generate call can forward its reason verbatim: +[AgentFinishReasonStop], [AgentFinishReasonLength], +[AgentFinishReasonBlocked], [AgentFinishReasonInterrupted], +[AgentFinishReasonOther] and [AgentFinishReasonUnknown]. + +The remaining values are agent-specific outcomes with no generate-level +equivalent (they never arise from forwarding a model finish reason): +[AgentFinishReasonAborted], [AgentFinishReasonDetached] and +[AgentFinishReasonFailed]. +. + +AgentFinishReasonStop doc +AgentFinishReasonStop indicates the model stopped naturally. +. + +AgentFinishReasonLength doc +AgentFinishReasonLength indicates generation hit the token limit. +. + +AgentFinishReasonBlocked doc +AgentFinishReasonBlocked indicates generation was blocked (e.g. safety). +. + +AgentFinishReasonAborted doc +AgentFinishReasonAborted indicates the turn or invocation was aborted +(e.g. a detached invocation aborted via the abortSnapshot companion action). +. + +AgentFinishReasonInterrupted doc +AgentFinishReasonInterrupted indicates the model paused on a tool request +awaiting input (e.g. human approval). The turn can be resumed by sending an +[AgentInput] with a Resume payload. +. + +AgentFinishReasonOther doc +AgentFinishReasonOther indicates generation stopped for some other reason. +. + +AgentFinishReasonUnknown doc +AgentFinishReasonUnknown indicates the reason was unspecified. +. + +AgentFinishReasonDetached doc +AgentFinishReasonDetached indicates the invocation was moved to the +background because the client detached. The returned [AgentOutput] reports +this reason; the persisted snapshot is later finalized with how the +background work actually ended. +. + +AgentFinishReasonFailed doc +AgentFinishReasonFailed indicates the turn or invocation terminated with an +error. +. + # GetSnapshotRequest GetSnapshotRequest pkg ai/exp @@ -1594,6 +1693,14 @@ GetSnapshotResponse.status doc Status is the lifecycle state of the snapshot. See [SnapshotStatus]. . +GetSnapshotResponse.finishReason doc +FinishReason is the semantic reason the captured turn or invocation ended +(e.g. [AgentFinishReasonStop], [AgentFinishReasonInterrupted], +[AgentFinishReasonFailed], [AgentFinishReasonAborted]). It lets a remote or +background poller report how a detached or resumed invocation ended without +re-deriving it. Empty on a pending snapshot. +. + GetSnapshotResponse.error type *core.GenkitError GetSnapshotResponse.error doc Error is the structured failure information; populated when Status diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index a7656e518f..96c1a8e4f5 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -522,22 +522,26 @@ func DefineAgent[State any]( // chatAgent := genkit.DefineCustomAgent(g, "chat", // func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { // var lastMessage *ai.Message -// err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) error { +// err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) (*aix.TurnResult, error) { +// var reason aix.AgentFinishReason // for result, err := range genkit.GenerateStream(ctx, g, // ai.WithModelName("googleai/gemini-3-flash-preview"), // ai.WithMessages(sess.Messages()...), // ) { // if err != nil { -// return err +// return nil, err // } // if result.Done { // lastMessage = result.Response.Message +// reason = aix.AgentFinishReason(result.Response.FinishReason) // sess.AddMessages(lastMessage) // } else { // resp.SendModelChunk(result.Chunk) // } // } -// return nil +// // Report how the turn ended; the framework forwards it on +// // the TurnEnd chunk and persists it on the snapshot. +// return &aix.TurnResult{FinishReason: reason}, nil // }) // if err != nil { // return nil, err diff --git a/go/samples/basic-agents/cli.go b/go/samples/basic-agents/cli.go index f17558b73f..f746283060 100644 --- a/go/samples/basic-agents/cli.go +++ b/go/samples/basic-agents/cli.go @@ -352,6 +352,12 @@ repl: fmt.Print(chunk.ModelChunk.Text()) } if chunk.TurnEnd != nil { + // finishReason rides the turn-end signal, so the client + // learns how the turn ended without scanning the message + // content (e.g. it could pause here on "interrupted"). + if r := chunk.TurnEnd.FinishReason; r != "" { + fmt.Printf("\n [turn: %s]", r) + } if chunk.TurnEnd.SnapshotID != "" { fmt.Printf("\n [snapshot %s]", shortID(chunk.TurnEnd.SnapshotID)) } @@ -374,7 +380,7 @@ repl: fmt.Println("agent again from the list to wait for it to finalize and") fmt.Println("resume from the cumulative final state.") case out != nil && out.SnapshotID != "": - fmt.Printf("Done. Final snapshot: %s.\n", shortID(out.SnapshotID)) + fmt.Printf("Done (%s). Final snapshot: %s.\n", out.FinishReason, shortID(out.SnapshotID)) } if quit { diff --git a/go/samples/basic-agents/main.go b/go/samples/basic-agents/main.go index 643cba5885..11e69867b1 100644 --- a/go/samples/basic-agents/main.go +++ b/go/samples/basic-agents/main.go @@ -161,7 +161,7 @@ func defineCustomAgent(g *genkit.Genkit) sampleAgent { return sampleAgent{ Agent: genkit.DefineCustomAgent(g, name, func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { - if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) error { + if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) (*aix.TurnResult, error) { for chunk, err := range genkit.GenerateStream(ctx, g, ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ @@ -172,15 +172,20 @@ func defineCustomAgent(g *genkit.Genkit) sampleAgent { ai.WithMessages(sess.Messages()...), ) { if err != nil { - return err + return nil, err } if chunk.Done { sess.AddMessages(chunk.Response.Message) - break + // Report how the turn ended so the framework can + // forward it on the TurnEnd chunk and persist it + // on the snapshot. + return &aix.TurnResult{ + FinishReason: aix.AgentFinishReason(chunk.Response.FinishReason), + }, nil } resp.SendModelChunk(chunk.Chunk) } - return nil + return nil, nil }); err != nil { return nil, err } diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index 5e9bd96343..cfcf4a4a84 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -34,6 +34,20 @@ ) +class AgentFinishReason(StrEnum): + """AgentFinishReason data type class.""" + + STOP = 'stop' + LENGTH = 'length' + BLOCKED = 'blocked' + INTERRUPTED = 'interrupted' + OTHER = 'other' + UNKNOWN = 'unknown' + ABORTED = 'aborted' + DETACHED = 'detached' + FAILED = 'failed' + + class AgentStateManagement(StrEnum): """AgentStateManagement data type class.""" @@ -151,6 +165,7 @@ class AgentOutput(GenkitModel): state: SessionState | None = None message: MessageData | None = None artifacts: list[Artifact] | None = None + finish_reason: AgentFinishReason | None = None class AgentResult(GenkitModel): @@ -159,6 +174,7 @@ class AgentResult(GenkitModel): model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) message: MessageData | None = None artifacts: list[Artifact] | None = None + finish_reason: AgentFinishReason | None = None class AgentStreamChunk(GenkitModel): @@ -195,6 +211,7 @@ class GetSnapshotResponse(GenkitModel): created_at: str | None = None updated_at: str | None = None status: SnapshotStatus | None = None + finish_reason: AgentFinishReason | None = None error: Any | None = Field(default=None) state: SessionState | None = None @@ -209,6 +226,7 @@ class SessionSnapshot(GenkitModel): updated_at: str | None = None event: SnapshotEvent = Field(...) status: SnapshotStatus | None = None + finish_reason: AgentFinishReason | None = None error: Error | None = None state: SessionState | None = None @@ -227,6 +245,7 @@ class TurnEnd(GenkitModel): model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) snapshot_id: str | None = None + finish_reason: AgentFinishReason | None = None class DocumentData(GenkitModel): From 166a4da4c4ccc9c17b1d95d6052603d2afba3469 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 9 Jun 2026 15:49:43 -0700 Subject: [PATCH 088/141] feat(go/core): errors.Is sentinels for BidiConnection.Send failures Send previously reported "action has completed" and "connection is closed" as bare FAILED_PRECONDITION errors, forcing callers to infer the cause by probing Output's shape. Send now wraps the new ErrActionCompleted and ErrConnectionClosed sentinels (same message bytes and status, so observable behavior is unchanged) and callers can branch with errors.Is. The new tests pin down two subtleties: Close races action completion (either sentinel can win unless the action is held open), and the buffered input channel can accept one post-completion send before the sentinel is guaranteed. --- go/core/action.go | 20 ++++++++++++++++---- go/core/flow_test.go | 41 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index 585d692d71..8c3162083b 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -19,6 +19,7 @@ package core import ( "context" "encoding/json" + "errors" "iter" "reflect" "sync" @@ -538,8 +539,19 @@ type BidiConnection[StreamIn, StreamOut, Out any] struct { closed bool } -// Send sends an input message to the bidi action. -// Returns an error if the connection is closed or the context is cancelled. +// ErrConnectionClosed indicates a Send on a connection whose input side +// was closed with [BidiConnection.Close]. Test with [errors.Is]. +var ErrConnectionClosed = errors.New("connection is closed") + +// ErrActionCompleted indicates a Send on a connection whose action has +// already returned. Test with [errors.Is]; the action's result is +// available via [BidiConnection.Output]. +var ErrActionCompleted = errors.New("action has completed") + +// Send sends an input message to the bidi action. It fails with an error +// matching [ErrConnectionClosed] after [BidiConnection.Close], with one +// matching [ErrActionCompleted] once the action has returned, or with the +// context's error if the connection's context is cancelled. func (c *BidiConnection[StreamIn, StreamOut, Out]) Send(input StreamIn) (err error) { // Recover from "send on closed channel" panic. A check-then-send under the // mutex would race with Close, and holding the mutex across the send would @@ -548,7 +560,7 @@ func (c *BidiConnection[StreamIn, StreamOut, Out]) Send(input StreamIn) (err err // canonical `for ... range inputCh` idiom. defer func() { if r := recover(); r != nil { - err = NewError(FAILED_PRECONDITION, "connection is closed") + err = NewError(FAILED_PRECONDITION, "%v", ErrConnectionClosed) } }() @@ -558,7 +570,7 @@ func (c *BidiConnection[StreamIn, StreamOut, Out]) Send(input StreamIn) (err err case <-c.ctx.Done(): return c.ctx.Err() case <-c.doneCh: - return NewError(FAILED_PRECONDITION, "action has completed") + return NewError(FAILED_PRECONDITION, "%v", ErrActionCompleted) } } diff --git a/go/core/flow_test.go b/go/core/flow_test.go index 5810ac4625..f240d62bd1 100644 --- a/go/core/flow_test.go +++ b/go/core/flow_test.go @@ -18,6 +18,7 @@ package core import ( "context" + "errors" "fmt" "slices" "strings" @@ -360,9 +361,13 @@ func TestBidiActionWithConfig(t *testing.T) { func TestBidiConnectionSendAfterClose(t *testing.T) { ctx := context.Background() + // Hold the action open so the only Send failure mode is the closed + // input side (a completed action would race in ErrActionCompleted). + release := make(chan struct{}) action := NewBidiAction( "test", api.ActionTypeCustom, nil, func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + <-release for range inCh { } return "", nil @@ -375,11 +380,41 @@ func TestBidiConnectionSendAfterClose(t *testing.T) { } conn.Close() - // Wait for completion so we know the state is settled. + if err := conn.Send("after close"); !errors.Is(err, ErrConnectionClosed) { + t.Errorf("expected error matching ErrConnectionClosed, got %v", err) + } + + close(release) <-conn.Done() +} + +func TestBidiConnectionSendAfterCompletion(t *testing.T) { + ctx := context.Background() - if err := conn.Send("after close"); err == nil { - t.Error("expected error sending after close") + action := NewBidiAction( + "quick", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + return "done", nil // return without consuming inputs + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + <-conn.Done() + + // The input channel is buffered, so the first send after completion + // may still be accepted (and dropped); once the buffer is full a send + // must fail with the sentinel. + var sendErr error + for range 3 { + if sendErr = conn.Send("late"); sendErr != nil { + break + } + } + if !errors.Is(sendErr, ErrActionCompleted) { + t.Errorf("expected error matching ErrActionCompleted, got %v", sendErr) } } From 4a1396005ada98dac04e4365329dda553c278ab0 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 9 Jun 2026 15:49:58 -0700 Subject: [PATCH 089/141] feat(go/exp): resolve agent failures gracefully with last-good state A failed turn previously failed the whole invocation: the action returned an error, no output was produced, and with it went every successfully completed prior turn (client-managed state lives only on the output; server-managed turns skipped by a selective snapshot callback were never persisted). One transient model or tool error could cost the whole conversation. The agent no longer fails the action for in-band failures. Every error (turn, panic, init validation, detach precondition) resolves the invocation with finishReason "failed", the error on the new AgentOutput.error field (original status intact, so callers can still branch on e.g. INVALID_ARGUMENT vs INTERNAL), and the last-good state: the state the failed turn started with, excluding its partial mutations. - SessionRunner captures the last-good state after every successful turn, whether or not the snapshot callback persisted it; the deep copy is skipped when the newest snapshot already covers that version. - Client-managed: AgentOutput.state carries the last-good state. - Server-managed: AgentOutput.snapshotId points at the newest snapshot capturing the last-good state. When a selective callback had skipped it, a recovery snapshot is written retroactively (new SnapshotEventRecovery, status succeeded, the last good turn's finish reason), deliberately bypassing the callback. - A failed turn emits TurnEnd{finishReason: "failed"} with no snapshot of its partial state and ends the invocation (further sends fail with core.ErrActionCompleted); a custom agent may instead swallow Run's error and keep processing, and intake pacing supports that. - AgentOutput.error rides the shared AgentError wire schema, now also reused by SessionSnapshot.error; Agent.Run resolves rejected inits to the failed output via the Send sentinel; the basic-agents sample demonstrates the stop-on-failed-TurnEnd client pattern. Transport-level failures (connection setup, client disconnect) still return errors. --- genkit-tools/common/src/types/agent.ts | 65 +- genkit-tools/genkit-schema.json | 37 +- go/ai/exp/agent.go | 251 +++++-- go/ai/exp/agent_test.go | 668 ++++++++++++++++-- go/ai/exp/gen.go | 34 +- go/ai/exp/teststore_test.go | 7 + go/core/schemas.config | 49 +- go/samples/basic-agents/cli.go | 15 + .../genkit/src/genkit/_core/_typing.py | 29 +- 9 files changed, 971 insertions(+), 184 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 9643e0b247..300c2e7e90 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -45,11 +45,17 @@ export type Artifact = z.infer; * `pending` status (and empty state) and rewritten with a terminal * status and the final cumulative state once the background work * finishes. + * - `recovery`: snapshot was written retroactively by the failure path to + * preserve the last-good state (everything through the last successful + * turn) when a selective snapshot callback had skipped persisting it. + * It is a normal `succeeded` row carrying the last good turn's + * `finishReason`, resumable like any other; the callback is bypassed. */ export const SnapshotEventSchema = z.enum([ 'turnEnd', 'invocationEnd', 'detach', + 'recovery', ]); export type SnapshotEvent = z.infer; @@ -175,13 +181,37 @@ export const AgentResultSchema = z.object({ }); export type AgentResult = z.infer; +/** + * Zod schema for the canonical Genkit error wire shape + * (`{status, message, details}`), carried on agent outputs and snapshots. + */ +export const AgentErrorSchema = z.object({ + /** Canonical status name (e.g. `INTERNAL`, `FAILED_PRECONDITION`). */ + status: z.string().optional(), + /** Human-readable error message. */ + message: z.string(), + /** Optional structured details describing the failure. */ + details: z.any().optional(), +}); +export type AgentError = z.infer; + /** * Zod schema for agent output. */ export const AgentOutputSchema = z.object({ - /** ID of the snapshot created at the end of this invocation. */ + /** + * ID of the snapshot created at the end of this invocation. When + * `finishReason` is `failed` (and a store is configured), this is the + * most recent snapshot capturing the last-good state: everything + * through the last successful turn (see the `recovery` snapshot event). + */ snapshotId: z.string().optional(), - /** Final conversation state (only when client-managed). */ + /** + * Final conversation state (only when client-managed). When + * `finishReason` is `failed`, this is the last-good state: everything + * through the last successful turn, excluding the failed turn's + * partial mutations. + */ state: SessionStateSchema.optional(), /** Last model response message from the conversation. */ message: MessageSchema.optional(), @@ -189,10 +219,18 @@ export const AgentOutputSchema = z.object({ artifacts: z.array(ArtifactSchema).optional(), /** * Why the invocation finished. `detached` when the client detached and - * the work continues in the background; otherwise the last turn's reason + * the work continues in the background; `failed` when the invocation + * ended in failure (see `error`); otherwise the last turn's reason * (or the value a custom agent set on its result). */ finishReason: AgentFinishReasonSchema.optional(), + /** + * Structured failure information when the invocation ended in failure + * (`finishReason` is `failed`). `status` preserves the original error + * category (e.g. `INVALID_ARGUMENT`, `FAILED_PRECONDITION`, + * `INTERNAL`) so callers can still branch on it. + */ + error: AgentErrorSchema.optional(), }); export type AgentOutput = z.infer; @@ -215,6 +253,10 @@ export const TurnEndSchema = z.object({ * Why this turn finished (e.g. `stop`, `length`, `interrupted`). Lets a * caller react to a turn boundary (e.g. pause on `interrupted`) without * scanning the message content. Omitted when the turn reported no reason. + * + * `failed` reports a failed turn; unless the agent recovers and keeps + * processing, the invocation then resolves with a failed output carrying + * the error and the last-good state. */ finishReason: AgentFinishReasonSchema.optional(), }); @@ -262,25 +304,10 @@ export const SessionSnapshotSchema = z.object({ * `stop`, `interrupted`, `failed`, `aborted`). Complements `status` (the * persistence lifecycle) so a resumed or background task can report how it * ended without re-deriving it from the messages. - * - * On an aborted snapshot this is best-effort and may briefly lag `status`: - * `abortSnapshot` flips `status` to `aborted` immediately, while the - * `aborted` finish reason is stamped by a subsequent finalizer write. If - * the process running the invocation never finalizes, the reason can - * remain empty even though `status` is `aborted`. */ finishReason: AgentFinishReasonSchema.optional(), /** Structured failure information for a snapshot in `failed` status. */ - error: z - .object({ - /** Canonical status name (e.g. `INTERNAL`, `FAILED_PRECONDITION`). */ - status: z.string().optional(), - /** Human-readable error message. */ - message: z.string(), - /** Optional structured details describing the failure. */ - details: z.any().optional(), - }) - .optional(), + error: AgentErrorSchema.optional(), /** * Conversation state captured at this point. Empty on a pending snapshot * (the live state is not yet committed); populated on terminal snapshots diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index c478cd7d65..62e9127891 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -28,6 +28,22 @@ ], "additionalProperties": false }, + "AgentError": { + "type": "object", + "properties": { + "status": { + "type": "string" + }, + "message": { + "type": "string" + }, + "details": {} + }, + "required": [ + "message" + ], + "additionalProperties": false + }, "AgentFinishReason": { "type": "string", "enum": [ @@ -120,6 +136,9 @@ }, "finishReason": { "$ref": "#/$defs/AgentFinishReason" + }, + "error": { + "$ref": "#/$defs/AgentError" } }, "additionalProperties": false @@ -252,20 +271,7 @@ "$ref": "#/$defs/AgentFinishReason" }, "error": { - "type": "object", - "properties": { - "status": { - "type": "string" - }, - "message": { - "type": "string" - }, - "details": {} - }, - "required": [ - "message" - ], - "additionalProperties": false + "$ref": "#/$defs/AgentError" }, "state": { "$ref": "#/$defs/SessionState" @@ -302,7 +308,8 @@ "enum": [ "turnEnd", "invocationEnd", - "detach" + "detach", + "recovery" ] }, "SnapshotStatus": { diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 79b66642f4..84c05f8962 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -22,6 +22,7 @@ package exp import ( "context" + "errors" "fmt" "iter" "runtime/debug" @@ -61,19 +62,31 @@ type SessionRunner[State any] struct { // lastTurnFinishReason is the finish reason reported by the most recent // turn (via the [TurnResult] its callback returned), or "" if the turn - // reported none. It is written by Run before [SessionRunner.onEndTurn] + // reported none. It is written by endTurn before [SessionRunner.onEndTurn] // and read by the runtime when emitting the turn-end signal and when // defaulting the invocation's finish reason. All accesses are confined // to the fn goroutine (Run and its synchronous onEndTurn callback) until // fn completes, after which the terminal paths read it with a // happens-before edge through the fnDone channel, so no lock is needed. + // The same confinement applies to lastTurnFailed and the lastGood* + // fields below; the terminal paths that read them (handleFnDone and + // the detach-failure paths) all wait on fnDone first. lastTurnFinishReason AgentFinishReason - // intake is the source of truth for in-flight tracking, queue state, - // and suspended state. The session consults it via beginTurnEnd (in - // maybeSnapshot) so per-turn snapshot writes and detach captures - // cannot race over the same input. - intake *detachIntake + // lastTurnFailed reports whether the most recent turn ended in error. + // Set by endTurn each turn. + lastTurnFailed bool + + // lastGoodState is a deep copy of the session state as of the most + // recent successful turn (or the initial state when no turn has + // completed yet), captured regardless of whether the snapshot callback + // persisted that turn. lastGoodVersion is the session version at that + // capture and lastGoodFinishReason that turn's reported reason. The + // failure path returns (client-managed) or persists (server-managed + // recovery snapshot) this state. + lastGoodState *SessionState[State] + lastGoodVersion uint64 + lastGoodFinishReason AgentFinishReason } // parentSnapshotID returns the ID of the most recent snapshot in this @@ -108,6 +121,14 @@ type TurnResult struct { // fn may return a [TurnResult] to report how the turn ended (e.g. its finish // reason); returning nil reports nothing. The reason rides the turn's // [TurnEnd] chunk and is persisted on the turn-end snapshot. +// +// When fn returns an error, Run records the failure ([TurnEnd] is emitted +// with [AgentFinishReasonFailed] and no snapshot is taken of the turn's +// partial state), stops looping, and returns the error. A custom agent may +// recover (e.g. call Run again to keep processing inputs) or propagate the +// error out of the agent function, which resolves the invocation with a +// failed [AgentOutput] carrying the error and the last-good state rather +// than failing the action. func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Context, input *AgentInput) (*TurnResult, error)) error { for input := range s.InputCh { spanMeta := &tracing.SpanMetadata{ @@ -124,14 +145,12 @@ func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont if err != nil { return nil, err } - // Reset each turn; a returned TurnResult sets the reason, - // nil reports none. - s.lastTurnFinishReason = "" + // A returned TurnResult sets the reason, nil reports none. + var reason AgentFinishReason if tr != nil { - s.lastTurnFinishReason = tr.FinishReason + reason = tr.FinishReason } - s.onEndTurn(ctx) - s.TurnIndex++ + s.endTurn(ctx, reason, false) if s.collectTurnOutput != nil { return s.collectTurnOutput(), nil } @@ -139,12 +158,46 @@ func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont }, ) if err != nil { + s.endTurn(ctx, AgentFinishReasonFailed, true) return err } } return nil } +// endTurn records how the turn ended and runs the shared turn-end tail: +// the turn-end emit, the last-good capture on success, and the turn +// advance. +func (s *SessionRunner[State]) endTurn(ctx context.Context, reason AgentFinishReason, failed bool) { + s.lastTurnFinishReason = reason + s.lastTurnFailed = failed + s.onEndTurn(ctx) + if !failed { + s.recordLastGood() + } + s.TurnIndex++ +} + +// recordLastGood captures the current session state as the last-good +// recovery point. Called once at session start and after every successful +// turn, whether or not the snapshot callback persisted that turn. Runs +// after the turn-end snapshot check so that when the newest snapshot +// already captures this exact version, the deep copy is skipped; +// recoverySnapshotID then resolves to that snapshot's ID without reading +// lastGoodState. +func (s *SessionRunner[State]) recordLastGood() { + s.mu.RLock() + version := s.version + persisted := s.lastSnapshot != nil && version == s.lastSnapshotVersion + if !persisted { + state := s.copyStateLocked() + s.lastGoodState = &state + } + s.mu.RUnlock() + s.lastGoodVersion = version + s.lastGoodFinishReason = s.lastTurnFinishReason +} + // Result returns an [AgentResult] populated from the current session state: // the last message in the conversation history and all artifacts. The // returned value is independent of the session; callers may mutate it @@ -177,23 +230,10 @@ func (s *SessionRunner[State]) invocationReason(result *AgentResult) AgentFinish } // maybeSnapshot creates a snapshot if conditions are met (store configured, -// callback approves, state changed, detach has not suspended snapshots). -// Returns the snapshot ID or empty string. finishReason is recorded on the -// snapshot so a resumed or background task can report how the captured turn -// or invocation ended. -// -// For turn-end events, the session asks the intake whether snapshots -// have been suspended (i.e. detach has landed). If so, the session skips -// the turn-end snapshot — the pending row already captures the -// invocation and a single finalize rewrite will record the cumulative -// state once the queued inputs drain. +// callback approves, state changed). Returns the snapshot ID or empty +// string. finishReason is recorded on the snapshot so a resumed or +// background task can report how the captured turn or invocation ended. func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent, finishReason AgentFinishReason) string { - if event == SnapshotEventTurnEnd && s.intake != nil { - if suspended := s.intake.beginTurnEnd(); suspended { - return "" - } - } - if s.store == nil { return "" } @@ -231,6 +271,17 @@ func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event Snapshot } } + return s.persistSnapshot(ctx, event, finishReason, ¤tState, currentVersion) +} + +// persistSnapshot writes a succeeded snapshot row capturing state (at the +// given session version), chained to the newest persisted snapshot, and +// advances the lastSnapshot bookkeeping. Both the routine cadence +// (maybeSnapshot) and the failure path (recoverySnapshotID) funnel through +// here so the row shape and bookkeeping live in one place. Persistence is +// best-effort: a store failure must not kill the in-flight turn, so it is +// logged and "" is returned. +func (s *SessionRunner[State]) persistSnapshot(ctx context.Context, event SnapshotEvent, finishReason AgentFinishReason, state *SessionState[State], version uint64) string { parentID := s.parentSnapshotID() saved, err := s.store.SaveSnapshot(ctx, "", @@ -240,14 +291,10 @@ func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event Snapshot Event: event, Status: SnapshotStatusSucceeded, FinishReason: finishReason, - State: ¤tState, + State: state, }, nil }) if err != nil { - // Snapshot persistence is best-effort: a store failure must not - // kill the in-flight turn. Surface enough context in the log - // that the failure is diagnosable without the caller having to - // thread the error back up. logger.FromContext(ctx).Error("agent: failed to save snapshot", "parentId", parentID, "event", event, @@ -256,10 +303,38 @@ func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event Snapshot } s.lastSnapshot = saved - s.lastSnapshotVersion = currentVersion + s.lastSnapshotVersion = version return saved.SnapshotID } +// recoverySnapshotID returns the ID of a snapshot holding the last-good +// state, writing one (event [SnapshotEventRecovery]) when the newest +// persisted snapshot is behind it. The write uses the captured +// lastGoodState, never the live state (which may hold the failed turn's +// partial mutations), and intentionally bypasses the snapshot callback so +// a selective cadence cannot lose the conversation. If the write fails, +// the newest persisted snapshot's ID is returned instead. +// +// Returns "" when no store is configured or there is nothing to recover +// (no snapshot exists and no turn ever changed state). +func (s *SessionRunner[State]) recoverySnapshotID(ctx context.Context) string { + if s.store == nil { + return "" + } + // The newest snapshot already captures exactly the last-good state. + if s.lastSnapshot != nil && s.lastGoodVersion == s.lastSnapshotVersion { + return s.lastSnapshot.SnapshotID + } + if s.lastSnapshot == nil && s.lastGoodVersion == 0 { + return "" + } + + if id := s.persistSnapshot(ctx, SnapshotEventRecovery, s.lastGoodFinishReason, s.lastGoodState, s.lastGoodVersion); id != "" { + return id + } + return s.parentSnapshotID() +} + // --- Responder --- // Responder is the output channel for an agent. Artifacts sent through @@ -401,7 +476,13 @@ func DefineCustomAgent[Stream, State any]( ctx = core.WithFlowContext(ctx, name) rt, err := newAgentRuntime(ctx, name, cfg, in, inCh, outCh) if err != nil { - return nil, err + // Init validation failures resolve as a failed output + // too. No state is attached; the invocation never + // started, so the caller lost nothing. + return &AgentOutput[State]{ + FinishReason: AgentFinishReasonFailed, + Error: core.AsGenkitError(err), + }, nil } return rt.run(ctx, fn) }) @@ -480,20 +561,32 @@ func newAgentRuntime[Stream, State any]( InputCh: rt.intake.out(), snapshotCallback: cfg.callback, lastSnapshot: parent, - intake: rt.intake, } rt.sess.collectTurnOutput = func() any { return rt.router.collectTurnChunks() } rt.sess.onEndTurn = rt.emitTurnEnd + // The initial state (fresh, client-provided, or loaded from a + // snapshot) is the last-good recovery point until a turn completes. + rt.sess.recordLastGood() return rt, nil } -// emitTurnEnd is called by the session after each successful turn. It writes -// a turn-end snapshot (if applicable) and forwards the resulting [TurnEnd] -// chunk through the router so clients see it on the output stream. +// emitTurnEnd is called by the session after each turn. It paces the +// intake (releasing the forwarder for the next input), writes a turn-end +// snapshot (if applicable), and forwards the resulting [TurnEnd] chunk +// through the router so clients see it on the output stream. +// +// The snapshot is skipped when the turn failed (the live state holds the +// turn's partial mutations) and when detach has landed (the pending row +// already captures the invocation; a single finalize rewrite records the +// cumulative state once the queued inputs drain). func (rt *agentRuntime[Stream, State]) emitTurnEnd(ctx context.Context) { + suspended := rt.intake.beginTurnEnd() reason := rt.sess.lastTurnFinishReason - snapshotID := rt.sess.maybeSnapshot(ctx, SnapshotEventTurnEnd, reason) + var snapshotID string + if !rt.sess.lastTurnFailed && !suspended { + snapshotID = rt.sess.maybeSnapshot(ctx, SnapshotEventTurnEnd, reason) + } rt.router.sendChunk(ctx, &AgentStreamChunk[Stream]{TurnEnd: &TurnEnd{ SnapshotID: snapshotID, FinishReason: reason, @@ -549,7 +642,7 @@ func (rt *agentRuntime[Stream, State]) run( case <-rt.intake.detachSignal(): if err := rt.checkDetachCapabilities(); err != nil { rt.drainAndWait(cancelWork) - return nil, err + return rt.failedOutput(clientCtx, err), nil } return rt.handleDetach(clientCtx, workCtx, cancelWork, markDetached) @@ -602,7 +695,9 @@ func (rt *agentRuntime[Stream, State]) drainAndWait(cancelWork context.CancelFun // handleFnDone is the synchronous-completion path: fn returned before any // detach signal. Capture an invocation-end snapshot if state advanced past -// the last turn-end snapshot, then assemble the output. +// the last turn-end snapshot, then assemble the output. When fn returned +// an error, the invocation resolves gracefully as a failed output instead +// (see failedOutput). // // When fn returns with an error, the Responder's ctx-aware send may have // dropped a chunk while the router was still pinned on a downstream send @@ -625,7 +720,15 @@ func (rt *agentRuntime[Stream, State]) handleFnDone( rt.router.close() if res.err != nil { - return nil, res.err + // A disconnect-driven failure keeps its error semantics: the + // client is gone, so there is no one to hand a graceful failed + // output to. The clientCtx.Done arm of the run select handles the + // common ordering; this guards the race where fn observes the + // cancellation first and its result wins the select. + if ctx.Err() != nil { + return nil, res.err + } + return rt.failedOutput(ctx, res.err), nil } invocationReason := rt.sess.invocationReason(res.result) @@ -655,6 +758,24 @@ func (rt *agentRuntime[Stream, State]) handleFnDone( return out, nil } +// failedOutput assembles the output for an invocation that ended in +// failure: [AgentFinishReasonFailed], the error with its original status, +// and the last-good state (inline when client-managed, behind a recovery +// snapshot ID when server-managed). Message and Artifacts are left empty; +// they describe the result of a completed run. +func (rt *agentRuntime[Stream, State]) failedOutput(ctx context.Context, cause error) *AgentOutput[State] { + out := &AgentOutput[State]{ + FinishReason: AgentFinishReasonFailed, + Error: core.AsGenkitError(cause), + } + if rt.cfg.store == nil { + out.State = applyTransform(ctx, rt.cfg.transform, rt.sess.lastGoodState) + } else { + out.SnapshotID = rt.sess.recoverySnapshotID(ctx) + } + return out +} + // handleDetach commits the pending snapshot, returns its ID, and spawns the // status-subscriber and finalizer goroutines that own the rest of the // invocation. Per-turn snapshots are suspended for the remainder so the @@ -688,8 +809,8 @@ func (rt *agentRuntime[Stream, State]) handleDetach( }) if err != nil { rt.drainAndWait(cancelWork) - return nil, core.NewError(core.INTERNAL, - "agent %q: detach: save pending snapshot: %v", rt.name, err) + return rt.failedOutput(clientCtx, core.NewError(core.INTERNAL, + "agent %q: detach: save pending snapshot: %v", rt.name, err)), nil } // The router can no longer write to outCh once we return; the bidi @@ -1005,8 +1126,8 @@ func (r *chunkRouter[Stream, State]) close() { // The forwarder goroutine pops the queue and writes to dst, blocking on // the runner via turnDone so it stays in step with turn pacing. // -// The runner asks beginTurnEnd at the end of each turn: if suspended -// (detach has landed), the runner skips its turn-end snapshot — the +// The runtime's emitTurnEnd asks beginTurnEnd at the end of each turn: +// if suspended (detach has landed), it skips the turn-end snapshot — the // pending row already captures the invocation and a single finalize // will rewrite it with the cumulative state once the queued inputs // drain. If not suspended, a normal turn-end snapshot is written. @@ -1235,9 +1356,9 @@ func (i *detachIntake) detachSignal() <-chan struct{} { return i.detachCh } -// beginTurnEnd is called by [SessionRunner.maybeSnapshot] before writing -// a turn-end snapshot. If the intake has been suspended (detach landed), -// it returns suspended=true and the runner skips the snapshot. +// beginTurnEnd is called by the runtime's emitTurnEnd at each turn end, +// before any turn-end snapshot. If the intake has been suspended (detach +// landed), it returns suspended=true and the caller skips the snapshot. // // In all cases (including suspended) the forwarder is released so it can // pop the next queued input — suspension stops snapshot writing, not @@ -1410,6 +1531,9 @@ func (a *Agent[Stream, State]) StreamBidi( // Run starts a single-turn agent invocation with the given input. // It sends the input, waits for the agent to complete, and returns the output. // For multi-turn interactions or streaming, use StreamBidi instead. +// +// In-band failures (a failed turn, a rejected init payload) resolve as a +// failed [AgentOutput] rather than an error; see [AgentConnection.Output]. func (a *Agent[Stream, State]) Run( ctx context.Context, input *AgentInput, @@ -1419,15 +1543,10 @@ func (a *Agent[Stream, State]) Run( if err != nil { return nil, err } - // If the bidi function fails fast (e.g. resuming from an errored - // snapshot rejects in newAgentRuntime), Send sees a closed connection - // and returns a generic "action has completed" error. The real fn - // error is on Output(). Prefer it whenever it's non-nil so callers - // get the meaningful failure. - if err := conn.Send(input); err != nil { - if _, outErr := conn.Output(); outErr != nil { - return nil, outErr - } + // The invocation may resolve before consuming the input (e.g. an init + // validation failure resolves as a failed output); the outcome is on + // Output regardless. + if err := conn.Send(input); err != nil && !errors.Is(err, core.ErrActionCompleted) { return nil, err } return conn.Output() @@ -1472,6 +1591,11 @@ type AgentConnection[Stream, State any] struct { } // Send sends an AgentInput to the agent. +// +// Once the invocation has resolved (e.g. a failed turn ended it), Send +// fails with an error matching [core.ErrActionCompleted]; the outcome is +// on [AgentConnection.Output]. The same applies to the SendMessage, +// SendText, SendResume, and Detach helpers. func (c *AgentConnection[Stream, State]) Send(input *AgentInput) error { return c.conn.Send(input) } @@ -1537,6 +1661,15 @@ func (c *AgentConnection[Stream, State]) Receive() iter.Seq2[*AgentStreamChunk[S // would block on chunk sends and never reach completion. Calling Close // before Output is allowed but redundant; both are idempotent. // +// In-band failures resolve rather than error: a failed turn or a +// rejected request returns an [AgentOutput] with +// [AgentFinishReasonFailed], the error on [AgentOutput.Error] (original +// status intact), and the last-good state on [AgentOutput.State] +// (client-managed) or behind [AgentOutput.SnapshotID] (server-managed), +// so a failure costs the caller only the failed turn, never the +// session. A non-nil error here means the invocation itself could not +// run to a result (e.g. the connection's context was cancelled). +// // Output is itself idempotent: subsequent calls return the same // (*AgentOutput, error) from cache. The returned pointer is shared // across calls; treat it as read-only. diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index b74241e130..e275dc69d3 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "slices" "strings" "testing" "time" @@ -580,9 +581,511 @@ func TestAgent_ErrorInTurn(t *testing.T) { conn.SendText("trigger error") conn.Close() - _, err = conn.Output() - if err == nil { - t.Fatal("expected error from failed turn") + // A failed turn resolves the invocation gracefully rather than + // failing the action: the outcome is on the output. + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) + } + if out.Error == nil || !strings.Contains(out.Error.Message, "turn failed") { + t.Errorf("expected output error containing %q, got %+v", "turn failed", out.Error) + } + // Client-managed: the last-good state rides the output. No turn + // succeeded, so it is the initial empty state — excluding the failed + // turn's partial mutations (the user message added before fn ran). + if out.State == nil { + t.Fatal("expected last-good state on failed output") + } + if got := len(out.State.Messages); got != 0 { + t.Errorf("expected 0 messages in last-good state, got %d", got) + } +} + +// defineLastGoodTestAgent defines a client- or server-managed echo agent +// whose turn fails (with partial session mutations) when the user sends +// "boom". Successful turns report [AgentFinishReasonStop]. +func defineLastGoodTestAgent(reg api.Registry, name string, opts ...AgentOption[testState]) *Agent[testStatus, testState] { + return DefineCustomAgent(reg, name, + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + text := input.Message.Content[0].Text + if text == "boom" { + // Partial mutations of the failing turn: these must not + // leak into the recovered last-good state. + sess.AddMessages(ai.NewModelTextMessage("partial reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter = 999 + return s + }) + return nil, core.NewError(core.UNAVAILABLE, "model timeout") + } + sess.AddMessages(ai.NewModelTextMessage("echo: " + text)) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return &TurnResult{FinishReason: AgentFinishReasonStop}, nil + }); err != nil { + return nil, err + } + return sess.Result(), nil + }, + opts..., + ) +} + +func TestAgent_FailedTurn_ClientManagedReturnsLastGoodState(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := defineLastGoodTestAgent(reg, "lastGoodClient") + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + for _, text := range []string{"one", "two", "boom"} { + if err := conn.SendText(text); err != nil { + t.Fatalf("SendText(%q): %v", text, err) + } + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) + } + if out.Error == nil { + t.Fatal("expected error on failed output") + } + if out.Error.Status != core.UNAVAILABLE { + t.Errorf("expected original status %q preserved, got %q", core.UNAVAILABLE, out.Error.Status) + } + if !strings.Contains(out.Error.Message, "model timeout") { + t.Errorf("expected error message to contain %q, got %q", "model timeout", out.Error.Message) + } + + // The last-good state holds both successful turns (user + echo each) + // and excludes the failed turn entirely: no "boom" user message, no + // partial reply, no counter clobber. + if out.State == nil { + t.Fatal("expected last-good state on failed output") + } + if got := len(out.State.Messages); got != 4 { + t.Fatalf("expected 4 messages in last-good state, got %d", got) + } + if got := out.State.Messages[3].Content[0].Text; got != "echo: two" { + t.Errorf("expected last message %q, got %q", "echo: two", got) + } + if got := out.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2 in last-good state, got %d", got) + } +} + +func TestAgent_FailedTurn_LastGoodStateIsResumable(t *testing.T) { + // On failure, the client resumes a fresh invocation from the + // last-good state in the failed output. + ctx := context.Background() + reg := newTestRegistry(t) + + af := defineLastGoodTestAgent(reg, "lastGoodResume") + + out, err := af.RunText(ctx, "boom", WithState(&SessionState[testState]{ + Messages: []*ai.Message{ + ai.NewUserTextMessage("one"), + ai.NewModelTextMessage("echo: one"), + }, + Custom: testState{Counter: 1}, + })) + if err != nil { + t.Fatalf("RunText: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Fatalf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) + } + // The failed output echoes back the state the failed turn started + // with: exactly what the client sent. + if got := len(out.State.Messages); got != 2 { + t.Fatalf("expected 2 messages in last-good state, got %d", got) + } + + retry, err := af.RunText(ctx, "two", WithState(out.State)) + if err != nil { + t.Fatalf("RunText(retry): %v", err) + } + if retry.FinishReason != AgentFinishReasonStop { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonStop, retry.FinishReason) + } + if retry.Error != nil { + t.Errorf("expected no error on retry, got %+v", retry.Error) + } + if got := len(retry.State.Messages); got != 4 { + t.Errorf("expected 4 messages after retry, got %d", got) + } + if got := retry.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2 after retry, got %d", got) + } +} + +func TestAgent_FailedTurn_RecoverySnapshotBypassesCallback(t *testing.T) { + // Server-managed agent with a selective snapshot callback that only + // persists the first turn. The second (successful) turn is skipped by + // the callback, so when the third turn fails, the runtime must write a + // retroactive recovery snapshot of the last-good state — bypassing the + // callback — or the skipped turn would be lost. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + // The callback runs on the fn goroutine; the assertions below run + // after Output() returns, which happens-after fn completes, so no + // locking is needed. + var cbEvents []SnapshotEvent + af := defineLastGoodTestAgent(reg, "recoverySnapshot", + WithSessionStore[testState](store), + WithSnapshotCallback(func(_ context.Context, sc *SnapshotContext[testState]) bool { + cbEvents = append(cbEvents, sc.Event) + return sc.Event == SnapshotEventTurnEnd && sc.TurnIndex == 0 + }), + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + + var turnEnds []*TurnEnd + for _, text := range []string{"one", "two"} { + if err := conn.SendText(text); err != nil { + t.Fatalf("SendText(%q): %v", text, err) + } + turnEnds = append(turnEnds, nextTurnEnd(t, conn)) + } + if turnEnds[0].SnapshotID == "" { + t.Fatal("expected turn 0 snapshot to be persisted") + } + if turnEnds[1].SnapshotID != "" { + t.Fatalf("expected turn 1 snapshot to be skipped by callback, got %q", turnEnds[1].SnapshotID) + } + + if err := conn.SendText("boom"); err != nil { + t.Fatalf("SendText(boom): %v", err) + } + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) + } + if out.Error == nil || out.Error.Status != core.UNAVAILABLE { + t.Fatalf("expected error with status %q, got %+v", core.UNAVAILABLE, out.Error) + } + if out.State != nil { + t.Errorf("server-managed failed output must not carry inline state, got %+v", out.State) + } + if out.SnapshotID == "" { + t.Fatal("expected recovery snapshot ID on failed output") + } + if out.SnapshotID == turnEnds[0].SnapshotID { + t.Fatal("expected a fresh recovery snapshot, got the turn-0 snapshot") + } + + snap, err := store.GetSnapshot(ctx, out.SnapshotID) + if err != nil || snap == nil { + t.Fatalf("GetSnapshot(%q): %v, %v", out.SnapshotID, snap, err) + } + if snap.Status != SnapshotStatusSucceeded { + t.Errorf("expected recovery snapshot status %q, got %q", SnapshotStatusSucceeded, snap.Status) + } + if snap.Event != SnapshotEventRecovery { + t.Errorf("expected recovery snapshot event %q, got %q", SnapshotEventRecovery, snap.Event) + } + if snap.FinishReason != AgentFinishReasonStop { + t.Errorf("expected recovery snapshot to carry the last good turn's reason %q, got %q", + AgentFinishReasonStop, snap.FinishReason) + } + if snap.ParentID != turnEnds[0].SnapshotID { + t.Errorf("expected recovery snapshot parent %q, got %q", turnEnds[0].SnapshotID, snap.ParentID) + } + // State through the last successful turn, excluding the failed turn. + if got := len(snap.State.Messages); got != 4 { + t.Fatalf("expected 4 messages in recovery snapshot, got %d", got) + } + if got := snap.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2 in recovery snapshot, got %d", got) + } + + // The recovery write bypassed the callback: it was consulted once per + // successful turn only (the failed turn and the recovery write never + // ask). + if want := []SnapshotEvent{SnapshotEventTurnEnd, SnapshotEventTurnEnd}; !slices.Equal(cbEvents, want) { + t.Errorf("expected callback events %v, got %v", want, cbEvents) + } +} + +func TestAgent_FailedTurn_LastGoodAlreadyPersisted_NoRecoveryWrite(t *testing.T) { + // With the default always-snapshot cadence, the last-good state is + // already in the store when a turn fails: the output reuses that + // snapshot's ID and no extra row is written. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + af := defineLastGoodTestAgent(reg, "recoveryDedup", WithSessionStore[testState](store)) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + if err := conn.SendText("one"); err != nil { + t.Fatalf("SendText: %v", err) + } + turn0 := nextTurnEnd(t, conn) + if turn0.SnapshotID == "" { + t.Fatal("expected turn 0 snapshot") + } + + if err := conn.SendText("boom"); err != nil { + t.Fatalf("SendText(boom): %v", err) + } + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) + } + if out.SnapshotID != turn0.SnapshotID { + t.Errorf("expected failed output to reuse the persisted last-good snapshot %q, got %q", + turn0.SnapshotID, out.SnapshotID) + } + if rows := store.snapshotCount(); rows != 1 { + t.Errorf("expected no recovery row when last-good is already persisted, got %d rows", rows) + } +} + +func TestAgent_FailedFirstTurn_AfterResume_ReturnsParentSnapshotID(t *testing.T) { + // Resuming from a snapshot and failing before any turn completes: + // the parent snapshot already captures the last-good state, so the + // failed output points back at it and no recovery row is written. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + parent, err := store.SaveSnapshot(ctx, "", + func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{ + Event: SnapshotEventInvocationEnd, + Status: SnapshotStatusSucceeded, + State: &SessionState[testState]{ + Messages: []*ai.Message{ + ai.NewUserTextMessage("one"), + ai.NewModelTextMessage("echo: one"), + }, + Custom: testState{Counter: 1}, + }, + }, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + + af := defineLastGoodTestAgent(reg, "resumeFailFirst", WithSessionStore[testState](store)) + + out, err := af.RunText(ctx, "boom", WithSnapshotID[testState](parent.SnapshotID)) + if err != nil { + t.Fatalf("RunText: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) + } + if out.SnapshotID != parent.SnapshotID { + t.Errorf("expected failed output to return parent snapshot %q, got %q", + parent.SnapshotID, out.SnapshotID) + } + if rows := store.snapshotCount(); rows != 1 { + t.Errorf("expected no new rows on first-turn failure after resume, got %d", rows) + } +} + +func TestAgent_FailedTurn_EmitsFailedTurnEnd(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + // Hold the agent fn open until the client has consumed the failed + // TurnEnd, so the chunk delivery is deterministic (the runtime stops + // forwarding chunks once fn returns with an error). + turnEndSeen := make(chan struct{}) + af := DefineCustomAgent(reg, "failedTurnEnd", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + return nil, fmt.Errorf("boom") + }) + select { + case <-turnEndSeen: + case <-time.After(2 * time.Second): + } + return nil, err + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText: %v", err) + } + + turnEnd := nextTurnEnd(t, conn) + close(turnEndSeen) + + if turnEnd.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected TurnEnd finish reason %q, got %q", AgentFinishReasonFailed, turnEnd.FinishReason) + } + if turnEnd.SnapshotID != "" { + t.Errorf("failed turn must not snapshot its partial state, got snapshot %q", turnEnd.SnapshotID) + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) + } +} + +func TestAgent_CustomAgentContinuesAfterFailedTurn(t *testing.T) { + // A custom agent may treat a turn failure as recoverable: swallow the + // error from Run and keep processing queued inputs. The intake must + // keep pacing inputs after a failed turn for this to work. + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "continueAfterFail", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + for { + err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + text := input.Message.Content[0].Text + if text == "boom" { + return nil, fmt.Errorf("recoverable failure") + } + sess.AddMessages(ai.NewModelTextMessage("echo: " + text)) + return &TurnResult{FinishReason: AgentFinishReasonStop}, nil + }) + if err == nil { + break // input channel closed + } + } + return sess.Result(), nil + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + for _, text := range []string{"one", "boom", "two"} { + if err := conn.SendText(text); err != nil { + t.Fatalf("SendText(%q): %v", text, err) + } + } + + // A hang here means intake pacing after a failed turn is broken. + out, err := outputWithin(t, conn, 2*time.Second) + if err != nil { + t.Fatalf("Output: %v", err) + } + // The agent recovered: the invocation succeeds with the live state + // (including the failed turn's user message, which the agent chose + // to keep by continuing). + if out.FinishReason != AgentFinishReasonStop { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonStop, out.FinishReason) + } + if out.Error != nil { + t.Errorf("expected no error on recovered invocation, got %+v", out.Error) + } + // user one, echo one, user boom, user two, echo two. + if got := len(out.State.Messages); got != 5 { + t.Errorf("expected 5 messages, got %d", got) + } +} + +func TestAgent_InitFailure_ResolvesFailedOutputWithStatus(t *testing.T) { + // Pre-turn precondition/validation failures resolve as failed outputs + // carrying the original error status, rather than failing the action. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + echo := func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + return nil, nil + }) + } + serverManaged := DefineCustomAgent(reg, "initFailServer", echo, WithSessionStore[testState](store)) + clientManaged := DefineCustomAgent(reg, "initFailClient", echo) + + tests := []struct { + name string + run func() (*AgentOutput[testState], error) + wantStatus core.StatusName + wantMsg string + }{ + { + name: "state rejected when server-managed", + run: func() (*AgentOutput[testState], error) { + return serverManaged.RunText(ctx, "hi", WithState(&SessionState[testState]{})) + }, + wantStatus: core.FAILED_PRECONDITION, + wantMsg: "session store", + }, + { + name: "snapshot ID rejected when client-managed", + run: func() (*AgentOutput[testState], error) { + return clientManaged.RunText(ctx, "hi", WithSnapshotID[testState]("some-id")) + }, + wantStatus: core.FAILED_PRECONDITION, + wantMsg: "no session store", + }, + { + name: "missing snapshot", + run: func() (*AgentOutput[testState], error) { + return serverManaged.RunText(ctx, "hi", WithSnapshotID[testState]("nope")) + }, + wantStatus: core.NOT_FOUND, + wantMsg: "not found", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + out, err := tc.run() + if err != nil { + t.Fatalf("expected graceful failed output, got error: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) + } + if out.Error == nil { + t.Fatal("expected error on output") + } + if out.Error.Status != tc.wantStatus { + t.Errorf("expected status %q, got %q", tc.wantStatus, out.Error.Status) + } + if !strings.Contains(out.Error.Message, tc.wantMsg) { + t.Errorf("expected error message to contain %q, got %q", tc.wantMsg, out.Error.Message) + } + if out.SnapshotID != "" || out.State != nil { + t.Errorf("init failures carry no state, got snapshotId=%q state=%+v", out.SnapshotID, out.State) + } + }) } } @@ -1379,17 +1882,26 @@ func TestPromptAgent_RejectsNonUserRole(t *testing.T) { ai.DefinePrompt(reg, "rejectRolePrompt", ai.WithModelName("test/echo")) af := DefineAgent[testState](reg, "rejectRolePrompt", FromPrompt()) - _, err := af.Run(ctx, &AgentInput{ + out, err := af.Run(ctx, &AgentInput{ Message: &ai.Message{ Role: ai.RoleModel, Content: []*ai.Part{ai.NewTextPart("hi")}, }, }) - if err == nil { - t.Fatal("expected error for non-user role, got nil") + if err != nil { + t.Fatalf("Run: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) } - if !strings.Contains(err.Error(), "role") { - t.Errorf("expected role-related error, got %v", err) + if out.Error == nil { + t.Fatal("expected output error for non-user role, got nil") + } + if out.Error.Status != core.INVALID_ARGUMENT { + t.Errorf("expected status %q, got %q", core.INVALID_ARGUMENT, out.Error.Status) + } + if !strings.Contains(out.Error.Message, "role") { + t.Errorf("expected role-related error, got %v", out.Error) } } @@ -1400,7 +1912,7 @@ func TestPromptAgent_RejectsToolRequestPart(t *testing.T) { ai.DefinePrompt(reg, "rejectToolReqPrompt", ai.WithModelName("test/echo")) af := DefineAgent[testState](reg, "rejectToolReqPrompt", FromPrompt()) - _, err := af.Run(ctx, &AgentInput{ + out, err := af.Run(ctx, &AgentInput{ Message: &ai.Message{ Role: ai.RoleUser, Content: []*ai.Part{ @@ -1409,11 +1921,14 @@ func TestPromptAgent_RejectsToolRequestPart(t *testing.T) { }, }, }) - if err == nil { - t.Fatal("expected error for tool request part, got nil") + if err != nil { + t.Fatalf("Run: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) } - if !strings.Contains(err.Error(), "tool request") { - t.Errorf("expected tool-request error, got %v", err) + if out.Error == nil || !strings.Contains(out.Error.Message, "tool request") { + t.Errorf("expected tool-request error, got %+v", out.Error) } } @@ -1424,7 +1939,7 @@ func TestPromptAgent_RejectsToolResponsePart(t *testing.T) { ai.DefinePrompt(reg, "rejectToolRespPrompt", ai.WithModelName("test/echo")) af := DefineAgent[testState](reg, "rejectToolRespPrompt", FromPrompt()) - _, err := af.Run(ctx, &AgentInput{ + out, err := af.Run(ctx, &AgentInput{ Message: &ai.Message{ Role: ai.RoleUser, Content: []*ai.Part{ @@ -1432,11 +1947,14 @@ func TestPromptAgent_RejectsToolResponsePart(t *testing.T) { }, }, }) - if err == nil { - t.Fatal("expected error for tool response part, got nil") + if err != nil { + t.Fatalf("Run: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) } - if !strings.Contains(err.Error(), "tool") { - t.Errorf("expected tool-related error, got %v", err) + if out.Error == nil || !strings.Contains(out.Error.Message, "tool") { + t.Errorf("expected tool-related error, got %+v", out.Error) } } @@ -1600,7 +2118,7 @@ func TestAgent_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { // TestAgent_FnPanicReturnsError verifies that a panic inside the agent // function is recovered and surfaced as an error, rather than crashing the // process or hanging the streaming goroutine. -func TestAgent_FnPanicReturnsError(t *testing.T) { +func TestAgent_FnPanicResolvesAsFailedOutput(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) @@ -1621,32 +2139,20 @@ func TestAgent_FnPanicReturnsError(t *testing.T) { t.Fatalf("SendText: %v", err) } - done := make(chan error, 1) - go func() { - for chunk, err := range conn.Receive() { - _ = chunk - if err != nil { - done <- err - return - } - } - _, outErr := conn.Output() - done <- outErr - }() - - select { - case err := <-done: - if err == nil { - t.Fatal("expected error from panicking fn") - } - if !strings.Contains(err.Error(), "panicked") { - t.Errorf("expected panic error, got: %v", err) - } - case <-time.After(2 * time.Second): - t.Fatal("Receive/Output hung; streaming goroutine likely leaked") + // A hang here means the streaming goroutine leaked. + out, err := outputWithin(t, conn, 2*time.Second) + if err != nil { + t.Fatalf("expected panic to resolve as failed output, got error: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) + } + if out.Error == nil || !strings.Contains(out.Error.Message, "panicked") { + t.Errorf("expected panic error on output, got: %+v", out.Error) + } + if out.Error != nil && out.Error.Status != core.INTERNAL { + t.Errorf("expected status %q, got %q", core.INTERNAL, out.Error.Status) } - - conn.Close() } // TestAgent_CancelDuringStreamReleasesGoroutine verifies that cancelling the @@ -1743,6 +2249,29 @@ func nextTurnEnd[Stream, State any](t *testing.T, conn *AgentConnection[Stream, return nil } +// outputWithin finalizes conn and returns its output, failing the test if +// finalization does not complete within d. Use it in tests where a +// regression would make Output hang rather than fail. +func outputWithin[Stream, State any](t *testing.T, conn *AgentConnection[Stream, State], d time.Duration) (*AgentOutput[State], error) { + t.Helper() + type outcome struct { + out *AgentOutput[State] + err error + } + done := make(chan outcome, 1) + go func() { + out, err := conn.Output() + done <- outcome{out, err} + }() + select { + case oc := <-done: + return oc.out, oc.err + case <-time.After(d): + t.Fatal("Output did not complete in time; the runtime likely hung") + return nil, nil + } +} + func TestAgent_TurnEnd_CarriesSnapshotID(t *testing.T) { // Sanity: each TurnEnd chunk carries the snapshot ID of the turn-end // snapshot, and the snapshots themselves are persisted. @@ -2009,12 +2538,18 @@ func TestAgent_Detach_RequiresStore(t *testing.T) { } conn.Close() - _, err = conn.Output() - if err == nil { - t.Fatal("expected error when detaching without a session store") + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) } - if !strings.Contains(err.Error(), "detach requires a session store") { - t.Errorf("unexpected error: %v", err) + if out.Error == nil || !strings.Contains(out.Error.Message, "detach requires a session store") { + t.Errorf("unexpected output error: %+v", out.Error) + } + if out.Error != nil && out.Error.Status != core.FAILED_PRECONDITION { + t.Errorf("expected status %q, got %q", core.FAILED_PRECONDITION, out.Error.Status) } } @@ -2250,13 +2785,17 @@ func TestAgent_Detach_FlowErrorsBecomesError(t *testing.T) { t.Errorf("expected snapshot.Error.Message to contain %q, got %+v", "kaboom", snap.Error) } - // Resuming from an errored detached snapshot is rejected. - _, err = af.RunText(context.Background(), "retry", WithSnapshotID[testState](out.SnapshotID)) - if err == nil { - t.Fatal("expected error when resuming from errored snapshot") + // Resuming from an errored detached snapshot is rejected; the + // rejection resolves as a failed output carrying the original error. + resumeOut, err := af.RunText(context.Background(), "retry", WithSnapshotID[testState](out.SnapshotID)) + if err != nil { + t.Fatalf("RunText: %v", err) + } + if resumeOut.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, resumeOut.FinishReason) } - if !strings.Contains(err.Error(), "kaboom") { - t.Errorf("unexpected resume error: %v", err) + if resumeOut.Error == nil || !strings.Contains(resumeOut.Error.Message, "kaboom") { + t.Errorf("unexpected resume error: %+v", resumeOut.Error) } } @@ -2470,12 +3009,21 @@ func TestAgent_ResumeFromErrorSnapshot_Rejected(t *testing.T) { WithSessionStore(store), ) - _, err := af.RunText(context.Background(), "hi", WithSnapshotID[testState](erroredID)) - if err == nil { - t.Fatal("expected error when resuming from errored snapshot") + out, err := af.RunText(context.Background(), "hi", WithSnapshotID[testState](erroredID)) + if err != nil { + t.Fatalf("RunText: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) + } + if out.Error == nil { + t.Fatal("expected output error when resuming from errored snapshot") + } + if out.Error.Status != core.FAILED_PRECONDITION { + t.Errorf("expected status %q, got %q", core.FAILED_PRECONDITION, out.Error.Status) } - if !strings.Contains(err.Error(), "underlying failure") { - t.Errorf("expected error to surface underlying failure, got: %v", err) + if !strings.Contains(out.Error.Message, "underlying failure") { + t.Errorf("expected error to surface underlying failure, got: %v", out.Error) } } diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index 261ce2c7c4..06d931ea78 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -150,18 +150,30 @@ type AgentMetadata struct { type AgentOutput[State any] struct { // Artifacts contains artifacts produced during the session. Artifacts []*Artifact `json:"artifacts,omitempty"` + // Error is the structured failure information when the invocation ended in + // failure (FinishReason is [AgentFinishReasonFailed]). Its Status preserves + // the original error category (e.g. INVALID_ARGUMENT, FAILED_PRECONDITION, + // INTERNAL) so callers can still branch on it. Nil otherwise. + Error *core.GenkitError `json:"error,omitempty"` // FinishReason is why the invocation finished. It is // [AgentFinishReasonDetached] when the client detached and the work continues - // in the background; otherwise it is the last turn's reason (or the value a - // custom agent set on its [AgentResult]). + // in the background, or [AgentFinishReasonFailed] when the invocation ended + // in failure (see [AgentOutput.Error]); otherwise it is the last turn's + // reason (or the value a custom agent set on its [AgentResult]). FinishReason AgentFinishReason `json:"finishReason,omitempty"` // Message is the last model response message from the conversation. Message *ai.Message `json:"message,omitempty"` // SnapshotID is the ID of the snapshot created at the end of this invocation. // Empty if no snapshot was created (callback returned false or no store configured). + // When FinishReason is [AgentFinishReasonFailed] (and a store is configured), + // it is the most recent snapshot capturing the last-good state: everything + // through the last successful turn (see [SnapshotEventRecovery]). SnapshotID string `json:"snapshotId,omitempty"` // State contains the final conversation state. // Only populated when state is client-managed (no store configured). + // When FinishReason is [AgentFinishReasonFailed], it is the last-good state: + // everything through the last successful turn, excluding the failed turn's + // partial mutations. State *SessionState[State] `json:"state,omitempty"` } @@ -278,12 +290,6 @@ type SessionSnapshot[State any] struct { // [SessionSnapshot.Status] (the persistence lifecycle) so a resumed or // background task can report how it ended without re-deriving it from the // messages. - // - // On an aborted snapshot this is best-effort and may briefly lag Status: - // AbortSnapshot flips Status to [SnapshotStatusAborted] immediately, while the - // [AgentFinishReasonAborted] finish reason is stamped by a subsequent finalizer - // write. If the process running the invocation never finalizes, the reason can - // remain empty even though Status is [SnapshotStatusAborted]. FinishReason AgentFinishReason `json:"finishReason,omitempty"` // ParentID is the ID of the previous snapshot in this timeline. ParentID string `json:"parentId,omitempty"` @@ -328,6 +334,13 @@ const ( // initially written with [SnapshotStatusPending] and rewritten with a // terminal status once the background work finishes. SnapshotEventDetach SnapshotEvent = "detach" + // Recovery indicates the snapshot was written retroactively by the failure + // path to preserve the last-good state (everything through the last + // successful turn) when a selective snapshot callback had skipped + // persisting it. It is a normal [SnapshotStatusSucceeded] row carrying the + // last good turn's finish reason, resumable like any other; the snapshot + // callback is bypassed and never sees this event. + SnapshotEventRecovery SnapshotEvent = "recovery" ) // SnapshotStatus describes the lifecycle state of a snapshot. Snapshots @@ -368,6 +381,11 @@ type TurnEnd struct { // [AgentFinishReasonInterrupted]). It lets a caller react to a turn boundary // (such as pausing on an interrupt) without scanning the message content. // Empty when the turn reported no reason. + // + // [AgentFinishReasonFailed] reports a failed turn; unless the agent + // recovers and keeps processing, the invocation then resolves with a failed + // [AgentOutput] carrying the error and the last-good state, and further + // sends fail with [core.ErrActionCompleted]. FinishReason AgentFinishReason `json:"finishReason,omitempty"` // SnapshotID is the ID of the snapshot persisted at the end of this turn. // Empty if no snapshot was created (callback returned false or no store diff --git a/go/ai/exp/teststore_test.go b/go/ai/exp/teststore_test.go index 94939303a3..4a0e014c32 100644 --- a/go/ai/exp/teststore_test.go +++ b/go/ai/exp/teststore_test.go @@ -50,6 +50,13 @@ func newTestInMemStore[State any]() *testInMemStore[State] { } } +// snapshotCount reports the number of stored snapshot rows. +func (s *testInMemStore[State]) snapshotCount() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.snapshots) +} + func (s *testInMemStore[State]) GetSnapshot(_ context.Context, snapshotID string) (*SessionSnapshot[State], error) { s.mu.RLock() defer s.mu.RUnlock() diff --git a/go/core/schemas.config b/go/core/schemas.config index 6eb306ad8a..60a4eb3573 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1325,11 +1325,17 @@ It wraps AgentResult with framework-managed fields. AgentOutput.snapshotId doc SnapshotID is the ID of the snapshot created at the end of this invocation. Empty if no snapshot was created (callback returned false or no store configured). +When FinishReason is [AgentFinishReasonFailed] (and a store is configured), +it is the most recent snapshot capturing the last-good state: everything +through the last successful turn (see [SnapshotEventRecovery]). . AgentOutput.state doc State contains the final conversation state. Only populated when state is client-managed (no store configured). +When FinishReason is [AgentFinishReasonFailed], it is the last-good state: +everything through the last successful turn, excluding the failed turn's +partial mutations. . AgentOutput.message type *ai.Message @@ -1344,10 +1350,24 @@ Artifacts contains artifacts produced during the session. AgentOutput.finishReason doc FinishReason is why the invocation finished. It is [AgentFinishReasonDetached] when the client detached and the work continues -in the background; otherwise it is the last turn's reason (or the value a -custom agent set on its [AgentResult]). +in the background, or [AgentFinishReasonFailed] when the invocation ended +in failure (see [AgentOutput.Error]); otherwise it is the last turn's +reason (or the value a custom agent set on its [AgentResult]). . +AgentOutput.error type *core.GenkitError +AgentOutput.error doc +Error is the structured failure information when the invocation ended in +failure (FinishReason is [AgentFinishReasonFailed]). Its Status preserves +the original error category (e.g. INVALID_ARGUMENT, FAILED_PRECONDITION, +INTERNAL) so callers can still branch on it. Nil otherwise. +. + +# AgentError mirrors the GenkitError wire shape. The fields that carry it +# (AgentOutput.error, SessionSnapshot.error) are overridden to +# *core.GenkitError, so the synthesized type is unused. +AgentError omit + # ---------------------------------------------------------------------------- # AgentStreamChunk # ---------------------------------------------------------------------------- @@ -1406,6 +1426,11 @@ FinishReason is why this turn finished (e.g. [AgentFinishReasonStop], [AgentFinishReasonInterrupted]). It lets a caller react to a turn boundary (such as pausing on an interrupt) without scanning the message content. Empty when the turn reported no reason. + +[AgentFinishReasonFailed] reports a failed turn; unless the agent +recovers and keeps processing, the invocation then resolves with a failed +[AgentOutput] carrying the error and the last-good state, and further +sends fail with [core.ErrActionCompleted]. . # ---------------------------------------------------------------------------- @@ -1486,12 +1511,6 @@ ended (e.g. [AgentFinishReasonStop], [AgentFinishReasonInterrupted], [SessionSnapshot.Status] (the persistence lifecycle) so a resumed or background task can report how it ended without re-deriving it from the messages. - -On an aborted snapshot this is best-effort and may briefly lag Status: -AbortSnapshot flips Status to [SnapshotStatusAborted] immediately, while the -[AgentFinishReasonAborted] finish reason is stamped by a subsequent finalizer -write. If the process running the invocation never finalizes, the reason can -remain empty even though Status is [SnapshotStatusAborted]. . SessionSnapshot.error type *core.GenkitError @@ -1500,11 +1519,6 @@ Error is the structured failure information for a snapshot in [SnapshotStatusFailed]. Nil otherwise. . -# The synthesized SessionSnapshotError type mirrors the GenkitError wire -# shape inline in the JSON schema. SessionSnapshot.error is overridden to -# *core.GenkitError above, so the synthesized stub is unused. -SessionSnapshotError omit - SessionSnapshot.state doc State is the conversation state captured at this point. Nil on a pending snapshot (the live state is not yet committed; the background @@ -1537,6 +1551,15 @@ initially written with [SnapshotStatusPending] and rewritten with a terminal status once the background work finishes. . +SnapshotEventRecovery doc +Recovery indicates the snapshot was written retroactively by the failure +path to preserve the last-good state (everything through the last +successful turn) when a selective snapshot callback had skipped +persisting it. It is a normal [SnapshotStatusSucceeded] row carrying the +last good turn's finish reason, resumable like any other; the snapshot +callback is bypassed and never sees this event. +. + # ---------------------------------------------------------------------------- # Snapshot lifecycle types # ---------------------------------------------------------------------------- diff --git a/go/samples/basic-agents/cli.go b/go/samples/basic-agents/cli.go index f746283060..986cde7405 100644 --- a/go/samples/basic-agents/cli.go +++ b/go/samples/basic-agents/cli.go @@ -363,6 +363,11 @@ repl: } fmt.Println() fmt.Println() + if chunk.TurnEnd.FinishReason == aix.AgentFinishReasonFailed { + // A failed turn ends the invocation; Output below + // reports the error and the last-good snapshot. + break repl + } break } } @@ -379,6 +384,16 @@ repl: fmt.Println("The agent keeps processing in the background. Pick this") fmt.Println("agent again from the list to wait for it to finalize and") fmt.Println("resume from the cumulative final state.") + case out != nil && out.FinishReason == aix.AgentFinishReasonFailed: + // A failed invocation resolves with the error and a last-good + // snapshot to resume from. + if out.Error != nil { + fmt.Fprintf(os.Stderr, "Agent failed (%s): %s\n", out.Error.Status, out.Error.Message) + } + if out.SnapshotID != "" { + fmt.Printf("Last-good snapshot: %s. Pick this agent again to resume from it.\n", + shortID(out.SnapshotID)) + } case out != nil && out.SnapshotID != "": fmt.Printf("Done (%s). Final snapshot: %s.\n", out.FinishReason, shortID(out.SnapshotID)) } diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index cfcf4a4a84..e0a8fc61c0 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -61,6 +61,7 @@ class SnapshotEvent(StrEnum): TURNEND = 'turnEnd' INVOCATIONEND = 'invocationEnd' DETACH = 'detach' + RECOVERY = 'recovery' class SnapshotStatus(StrEnum): @@ -132,6 +133,15 @@ class AbortSnapshotResponse(GenkitModel): status: SnapshotStatus | None = None +class AgentError(GenkitModel): + """Model for agenterror data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + status: str | None = None + message: str = Field(...) + details: Any | None = Field(default=None) + + class AgentInit(GenkitModel): """Model for agentinit data.""" @@ -166,6 +176,7 @@ class AgentOutput(GenkitModel): message: MessageData | None = None artifacts: list[Artifact] | None = None finish_reason: AgentFinishReason | None = None + error: AgentError | None = None class AgentResult(GenkitModel): @@ -227,7 +238,7 @@ class SessionSnapshot(GenkitModel): event: SnapshotEvent = Field(...) status: SnapshotStatus | None = None finish_reason: AgentFinishReason | None = None - error: Error | None = None + error: AgentError | None = None state: SessionState | None = None @@ -955,15 +966,6 @@ class Resume(GenkitModel): metadata: Metadata | None = None -class Error(GenkitModel): - """Model for error data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='allow', populate_by_name=True) - status: str | None = None - message: str = Field(...) - details: Any | None = Field(default=None) - - class Details(GenkitModel): """Model for details data.""" @@ -1003,6 +1005,13 @@ class Supports(GenkitModel): long_running: bool | None = None +class Error(GenkitModel): + """Model for error data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='allow', populate_by_name=True) + message: str = Field(...) + + class Resource(GenkitModel): """Model for resource data.""" From 67ae70427142b22ed15ee6989b2fbf820801d748 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 10 Jun 2026 10:16:43 -0700 Subject: [PATCH 090/141] fix(go/core): camelCase json tags for ActionDesc schema fields The schema-field rename left the json tags all-lowercase (inputschema, outputschema, ...), unlike every other camelCase field on the reflection wire. Align them: inputSchema, outputSchema, streamSchema, initSchema. --- go/core/api/action.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/go/core/api/action.go b/go/core/api/action.go index fdb8ed18cb..8fac087abb 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -75,9 +75,9 @@ type ActionDesc struct { Key string `json:"key"` // Key of the action. Name string `json:"name"` // Name of the action. Description string `json:"description"` // Description of the action. - InputSchema map[string]any `json:"inputschema"` // JSON schema to validate against the action's input. - OutputSchema map[string]any `json:"outputschema"` // JSON schema to validate against the action's output. - StreamSchema map[string]any `json:"streamschema,omitempty"` // JSON schema to validate against the action's outgoing streamed chunks. - InitSchema map[string]any `json:"initschema,omitempty"` // JSON schema to validate against the action's incoming stream messages (bidi only). + InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. + OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. + StreamSchema map[string]any `json:"streamSchema,omitempty"` // JSON schema to validate against the action's outgoing streamed chunks. + InitSchema map[string]any `json:"initSchema,omitempty"` // JSON schema to validate against the action's incoming stream messages (bidi only). Metadata map[string]any `json:"metadata"` // Metadata for the action. } From 059d9e4bf4cb2fffdd855824638302bf7f108def Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 10 Jun 2026 10:17:05 -0700 Subject: [PATCH 091/141] feat(go/exp): session IDs and resume-by-session via WithSessionID Conversations now have a stable identity. The runtime mints a session ID when a conversation's first invocation starts; every snapshot row carries it (SessionSnapshot.sessionId) and the persisted state blob mirrors it (SessionState.sessionId), so server-managed callers track a conversation by ID while client-managed callers get the identity inside the opaque state object they already round-trip, with no separate field to thread through. - WithSessionID / AgentInit.sessionId resumes a session from its latest snapshot: the most recently updated row that is not a failed/aborted dead end. No lineage traversal or linearity check; stores implement SnapshotReader.GetLatestSnapshot as a single status-filtered max-UpdatedAt query, and a forked history (an earlier snapshot resumed again) simply continues the most recently updated branch. ParentID remains as informational lineage only. - A pending latest row (a detached invocation still running) is surfaced, not skipped: the resume is rejected so it cannot race the background writer; wait for it to finalize or abort it. Failed and aborted rows are skipped, so a failed invocation resumes from its recovery snapshot (or the last good turn-end row). - Server-managed: fresh invocations mint an ID, resumes inherit the chain's (rows from before session IDs get a fresh one), and WithSessionID combined with WithSnapshotID asserts the snapshot belongs to that session. - Client-managed: the identity rides inside the state; an ID carried by the incoming state is kept, otherwise one is minted so the first output is already self-describing. AgentInit.state is now mutually exclusive with sessionId and snapshotId, and WithSessionID without a store is rejected, pointing at SessionState.SessionID. - Outbound state is re-stamped with the canonical ID after WithStateTransform runs (getSnapshot action and client-managed outputs), so a redaction transform cannot silently strip the conversation's identity. - SaveSnapshot contract: stores preserve a row's sessionId on rewrite and never mint or infer one (no parent-row inheritance). - The getSnapshot companion action returns the normalized SessionSnapshot row itself; the separate GetSnapshotResponse wire type is gone. - Session.SessionID() exposes the ID to agent code (including via SessionFromContext) as a stable key for external resources. - localstore: GetLatestSnapshot for both stores; the file store scans newest-first by mtime with an early exit, so resolving the most recently active session costs one read. - basic-agents sample resumes whole sessions by ID, falling back to the exact snapshot when the session cannot be resumed as a whole. - jsonschemagen preserves leading whitespace in schemas.config doc blocks so godoc lists survive generation; wire schema and generated artifacts (gen.go, genkit-schema.json, _typing.py) regenerated. --- genkit-tools/common/src/types/agent.ts | 125 +- genkit-tools/genkit-schema.json | 38 +- go/ai/exp/agent.go | 403 ++++--- go/ai/exp/agent_test.go | 1069 +++++++++++++++-- go/ai/exp/gen.go | 157 +-- go/ai/exp/localstore/file.go | 125 +- go/ai/exp/localstore/file_test.go | 43 + go/ai/exp/localstore/inmemory.go | 35 +- go/ai/exp/localstore/inmemory_test.go | 36 + go/ai/exp/localstore/store_test.go | 208 ++++ go/ai/exp/option.go | 69 +- go/ai/exp/session.go | 93 +- go/ai/exp/teststore_test.go | 27 + go/core/flow.go | 5 +- go/core/schemas.config | 186 +-- go/genkit/genkit.go | 3 +- .../cmd/jsonschemagen/jsonschemagen.go | 8 +- go/samples/basic-agents/cli.go | 21 +- .../genkit/src/genkit/_core/_typing.py | 17 +- 19 files changed, 2141 insertions(+), 527 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 300c2e7e90..90b77eda07 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -63,8 +63,9 @@ export type SnapshotEvent = z.infer; * Zod schema for a snapshot's lifecycle status. * * - `pending`: a detached invocation is still processing the queued inputs. - * The snapshot's state is empty until the flow exits, at which point it - * is rewritten with the cumulative final state and a terminal status. + * The snapshot's state is empty until the background work finishes, at + * which point it is rewritten with the cumulative final state and a + * terminal status. * - `succeeded`: the snapshot captures a settled state. * - `aborted`: the snapshot's invocation was aborted via the * `abortSnapshot` companion action while detached. @@ -119,6 +120,15 @@ export type AgentFinishReason = z.infer; * Zod schema for session state. */ export const SessionStateSchema = z.object({ + /** + * ID of the session (conversation) this state belongs to. + * Framework-owned: assigned when the conversation's first invocation + * starts and re-stamped on outbound state, so client-managed callers + * can round-trip the state object opaquely without tracking a separate + * identifier. For server-managed agents the snapshot row's `sessionId` + * is canonical and this field mirrors it. + */ + sessionId: z.string().optional(), /** Conversation history (user/model exchanges). */ messages: z.array(MessageSchema).optional(), /** User-defined state associated with this conversation. */ @@ -158,9 +168,35 @@ export type AgentInput = z.infer; * Zod schema for agent initialization. */ export const AgentInitSchema = z.object({ - /** Loads state from a persisted snapshot. Mutually exclusive with state. */ + /** + * Identifies the session (conversation) to resume. Only valid when the + * agent is server-managed (a session store is configured); mutually + * exclusive with state (a client-managed conversation carries its + * identity inside `state.sessionId`). Alone it resumes the session + * from its latest snapshot: the most recently updated one that is not + * a failed/aborted dead end. A pending latest snapshot (a detached + * invocation still running) rejects the resume rather than racing the + * background work; if the session's history was forked by re-resuming + * an earlier snapshot, the most recently updated branch wins, and + * snapshotId can pick a branch explicitly. Combined with snapshotId, + * it asserts which session the snapshot belongs to, and a mismatch is + * rejected. + */ + sessionId: z.string().optional(), + /** + * Loads state from a persisted snapshot (server-managed state only). + * May be combined with sessionId to validate that the snapshot belongs + * to that session. Mutually exclusive with state. + */ snapshotId: z.string().optional(), - /** Direct state for the invocation. Mutually exclusive with snapshotId. */ + /** + * Direct state for the invocation (client-managed state only). The + * conversation's identity rides inside it (`state.sessionId`): the + * framework mints one on the conversation's first invocation and + * echoes it on the output state, so resending the state object keeps + * the identity without tracking a separate field. Mutually exclusive + * with sessionId and snapshotId. + */ state: SessionStateSchema.optional(), }); export type AgentInit = z.infer; @@ -200,10 +236,24 @@ export type AgentError = z.infer; */ export const AgentOutputSchema = z.object({ /** - * ID of the snapshot created at the end of this invocation. When - * `finishReason` is `failed` (and a store is configured), this is the - * most recent snapshot capturing the last-good state: everything - * through the last successful turn (see the `recovery` snapshot event). + * ID of the session this invocation belongs to, assigned by the + * framework when the invocation starts. With server-managed state, a + * fresh invocation mints a new ID, resumed invocations inherit the + * chain's, and resuming a snapshot from before session IDs existed + * mints a fresh one. With client-managed state it echoes the ID + * carried inside the state object (`state.sessionId`), minting one on + * the conversation's first invocation; only a session with persisted + * snapshots can be resumed by this ID. + */ + sessionId: z.string().optional(), + /** + * ID of the newest snapshot capturing this invocation: the + * invocation-end snapshot, or the latest earlier snapshot when that + * write was skipped. Empty when no store is configured or the + * invocation persisted nothing. When `finishReason` is `detached` it + * is the pending detach snapshot; when `failed`, the most recent + * snapshot capturing the last-good state: everything through the last + * successful turn (see the `recovery` snapshot event). */ snapshotId: z.string().optional(), /** @@ -245,8 +295,9 @@ export type AgentOutput = z.infer; export const TurnEndSchema = z.object({ /** * ID of the snapshot persisted at the end of this turn. Empty if no - * snapshot was created (callback returned false, no store configured, or - * snapshots were suspended after detach). + * snapshot was written (no store configured, the callback declined, + * nothing changed since the last snapshot, or snapshots were suspended + * after detach). */ snapshotId: z.string().optional(), /** @@ -289,7 +340,20 @@ export type AgentStreamChunk = z.infer; export const SessionSnapshotSchema = z.object({ /** Unique identifier for this snapshot (UUID). */ snapshotId: z.string(), - /** ID of the previous snapshot in this timeline. */ + /** + * ID of the session this snapshot belongs to. Assigned by the agent + * framework when the conversation's first invocation starts and + * stamped on every later snapshot in the chain, including across + * resumed invocations. Stores preserve it across rewrites; rows + * written without one (data from before session IDs existed) belong + * to no session. + */ + sessionId: z.string().optional(), + /** + * ID of the previous snapshot in this timeline. Informational lineage + * (for debugging and UI history trees); plays no part in resolving a + * session's latest snapshot. + */ parentId: z.string().optional(), /** When the snapshot was first written (RFC 3339). */ createdAt: z.string(), @@ -319,8 +383,10 @@ export type SessionSnapshot = z.infer; /** * Zod schema for the input of an agent's `getSnapshot` companion action. - * The action is registered at `{agentName}/getSnapshot` when the agent - * is defined. + * The action is registered under the agent's name (action type + * `agent-snapshot`) when the agent has a session store configured. The + * action returns the stored `SessionSnapshot`, with any configured state + * transform applied to its state. */ export const GetSnapshotRequestSchema = z.object({ /** Identifies the snapshot to fetch. */ @@ -328,37 +394,6 @@ export const GetSnapshotRequestSchema = z.object({ }); export type GetSnapshotRequest = z.infer; -/** - * Zod schema for the output of the `getSnapshot` companion action. It is a - * client-facing view of the stored snapshot: identifying metadata plus the - * session state, with `WithStateTransform` applied if configured. - */ -export const GetSnapshotResponseSchema = z.object({ - /** Echoes the requested snapshot ID. */ - snapshotId: z.string(), - /** When the snapshot record was first written (RFC 3339). */ - createdAt: z.string().optional(), - /** When the snapshot record was last written (RFC 3339). */ - updatedAt: z.string().optional(), - /** Lifecycle state of the snapshot. */ - status: SnapshotStatusSchema.optional(), - /** - * Semantic reason the captured turn or invocation ended (e.g. `stop`, - * `interrupted`, `failed`, `aborted`). Lets a remote or background poller - * report how a detached/resumed invocation ended without re-deriving it. - * Empty on a pending snapshot. - */ - finishReason: AgentFinishReasonSchema.optional(), - /** Structured failure information; populated when status is `error`. */ - error: z.any().optional(), - /** - * Session state captured by the snapshot, after any configured transform. - * Empty when status is `pending` or `error`. - */ - state: SessionStateSchema.optional(), -}); -export type GetSnapshotResponse = z.infer; - /** * Zod schema for the input of the `abortSnapshot` companion action. */ @@ -376,7 +411,7 @@ export const AbortSnapshotResponseSchema = z.object({ snapshotId: z.string(), /** * Snapshot's status after the abort attempt. For a pending snapshot - * this is `canceled`. For an already-terminal snapshot this is the + * this is `aborted`. For an already-terminal snapshot this is the * existing terminal status (the abort is a no-op). */ status: SnapshotStatusSchema.optional(), diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 62e9127891..b3b09fce12 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -61,6 +61,9 @@ "AgentInit": { "type": "object", "properties": { + "sessionId": { + "type": "string" + }, "snapshotId": { "type": "string" }, @@ -119,6 +122,9 @@ "AgentOutput": { "type": "object", "properties": { + "sessionId": { + "type": "string" + }, "snapshotId": { "type": "string" }, @@ -218,38 +224,13 @@ ], "additionalProperties": false }, - "GetSnapshotResponse": { + "SessionSnapshot": { "type": "object", "properties": { "snapshotId": { "type": "string" }, - "createdAt": { - "type": "string" - }, - "updatedAt": { - "type": "string" - }, - "status": { - "$ref": "#/$defs/SnapshotStatus" - }, - "finishReason": { - "$ref": "#/$defs/AgentFinishReason" - }, - "error": {}, - "state": { - "$ref": "#/$defs/SessionState" - } - }, - "required": [ - "snapshotId" - ], - "additionalProperties": false - }, - "SessionSnapshot": { - "type": "object", - "properties": { - "snapshotId": { + "sessionId": { "type": "string" }, "parentId": { @@ -287,6 +268,9 @@ "SessionState": { "type": "object", "properties": { + "sessionId": { + "type": "string" + }, "messages": { "type": "array", "items": { diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 84c05f8962..99c3e5ff34 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -34,6 +34,7 @@ import ( "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" + "github.com/google/uuid" ) // --- SessionRunner --- @@ -54,11 +55,19 @@ type SessionRunner[State any] struct { // directly. TurnIndex int - snapshotCallback SnapshotCallback[State] - onEndTurn func(ctx context.Context) + snapshotCallback SnapshotCallback[State] + onEndTurn func(ctx context.Context) + collectTurnOutput func() any + + // snapMu serializes snapshot persistence with the detach handler's + // suspend-and-capture. lastSnapshot and lastSnapshotVersion are + // written under it; the terminal paths that read them without it + // (handleFnDone, failedOutput) run after fn completes, with a + // happens-before edge through the fnDone channel. + snapMu sync.Mutex + snapshotsSuspended bool lastSnapshot *SessionSnapshot[State] lastSnapshotVersion uint64 - collectTurnOutput func() any // lastTurnFinishReason is the finish reason reported by the most recent // turn (via the [TurnResult] its callback returned), or "" if the turn @@ -99,6 +108,21 @@ func (s *SessionRunner[State]) parentSnapshotID() string { return s.lastSnapshot.SnapshotID } +// suspendSnapshots stops all further snapshot persistence for this +// invocation and returns the ID of the newest persisted snapshot. Taking +// snapMu makes the two steps atomic with respect to an in-flight turn-end +// write: a write already inside maybeSnapshot completes first (so the +// returned parent is current, not stale), and any later turn end observes +// the suspension and skips its write. Called by the detach handler, after +// which the queued inputs roll into a single finalize rewrite of the +// pending row. +func (s *SessionRunner[State]) suspendSnapshots() (parentID string) { + s.snapMu.Lock() + defer s.snapMu.Unlock() + s.snapshotsSuspended = true + return s.parentSnapshotID() +} + // TurnResult is the optional return value of a [SessionRunner.Run] per-turn // callback. It lets a custom agent report how the turn ended; the framework // forwards the reason on the turn's [TurnEnd] chunk, persists it on the @@ -116,7 +140,8 @@ type TurnResult struct { // Run loops over the input channel, calling fn for each turn. Each turn is // wrapped in a trace span for observability. Input messages are automatically // added to the session before fn is called. After fn returns successfully, a -// TurnEnd chunk is sent and a snapshot check is triggered. +// snapshot check is triggered and a [TurnEnd] chunk (carrying any new +// snapshot's ID) is sent. // // fn may return a [TurnResult] to report how the turn ended (e.g. its finish // reason); returning nil reports nothing. The reason rides the turn's @@ -230,14 +255,25 @@ func (s *SessionRunner[State]) invocationReason(result *AgentResult) AgentFinish } // maybeSnapshot creates a snapshot if conditions are met (store configured, -// callback approves, state changed). Returns the snapshot ID or empty -// string. finishReason is recorded on the snapshot so a resumed or -// background task can report how the captured turn or invocation ended. +// snapshots not suspended by detach, callback approves, state changed). +// Returns the snapshot ID or empty string. finishReason is recorded on the +// snapshot so a resumed or background task can report how the captured turn +// or invocation ended. +// +// The body runs under snapMu so the detach handler's suspend-and-capture +// (suspendSnapshots) cannot interleave with a write: it either waits for +// this write to commit or suspends before it starts. func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent, finishReason AgentFinishReason) string { if s.store == nil { return "" } + s.snapMu.Lock() + defer s.snapMu.Unlock() + if s.snapshotsSuspended { + return "" + } + s.mu.RLock() currentVersion := s.version currentState := s.copyStateLocked() @@ -271,22 +307,24 @@ func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event Snapshot } } - return s.persistSnapshot(ctx, event, finishReason, ¤tState, currentVersion) + return s.persistSnapshotLocked(ctx, event, finishReason, ¤tState, currentVersion) } -// persistSnapshot writes a succeeded snapshot row capturing state (at the -// given session version), chained to the newest persisted snapshot, and +// persistSnapshotLocked writes a succeeded snapshot row capturing state (at +// the given session version), chained to the newest persisted snapshot, and // advances the lastSnapshot bookkeeping. Both the routine cadence // (maybeSnapshot) and the failure path (recoverySnapshotID) funnel through -// here so the row shape and bookkeeping live in one place. Persistence is -// best-effort: a store failure must not kill the in-flight turn, so it is -// logged and "" is returned. -func (s *SessionRunner[State]) persistSnapshot(ctx context.Context, event SnapshotEvent, finishReason AgentFinishReason, state *SessionState[State], version uint64) string { +// here so the row shape and bookkeeping live in one place. Caller must hold +// snapMu. Persistence is best-effort: a store failure must not kill the +// in-flight turn, so it is logged and "" is returned. +func (s *SessionRunner[State]) persistSnapshotLocked(ctx context.Context, event SnapshotEvent, finishReason AgentFinishReason, state *SessionState[State], version uint64) string { parentID := s.parentSnapshotID() + sessionID := s.SessionID() saved, err := s.store.SaveSnapshot(ctx, "", func(_ *SessionSnapshot[State]) (*SessionSnapshot[State], error) { return &SessionSnapshot[State]{ + SessionID: sessionID, ParentID: parentID, Event: event, Status: SnapshotStatusSucceeded, @@ -311,9 +349,10 @@ func (s *SessionRunner[State]) persistSnapshot(ctx context.Context, event Snapsh // state, writing one (event [SnapshotEventRecovery]) when the newest // persisted snapshot is behind it. The write uses the captured // lastGoodState, never the live state (which may hold the failed turn's -// partial mutations), and intentionally bypasses the snapshot callback so -// a selective cadence cannot lose the conversation. If the write fails, -// the newest persisted snapshot's ID is returned instead. +// partial mutations), and intentionally bypasses both the snapshot +// callback and the post-detach suspension, so neither a selective cadence +// nor a dying detach can lose the conversation. If the write fails, the +// newest persisted snapshot's ID is returned instead. // // Returns "" when no store is configured or there is nothing to recover // (no snapshot exists and no turn ever changed state). @@ -321,6 +360,8 @@ func (s *SessionRunner[State]) recoverySnapshotID(ctx context.Context) string { if s.store == nil { return "" } + s.snapMu.Lock() + defer s.snapMu.Unlock() // The newest snapshot already captures exactly the last-good state. if s.lastSnapshot != nil && s.lastGoodVersion == s.lastSnapshotVersion { return s.lastSnapshot.SnapshotID @@ -329,7 +370,7 @@ func (s *SessionRunner[State]) recoverySnapshotID(ctx context.Context) string { return "" } - if id := s.persistSnapshot(ctx, SnapshotEventRecovery, s.lastGoodFinishReason, s.lastGoodState, s.lastGoodVersion); id != "" { + if id := s.persistSnapshotLocked(ctx, SnapshotEventRecovery, s.lastGoodFinishReason, s.lastGoodState, s.lastGoodVersion); id != "" { return id } return s.parentSnapshotID() @@ -444,12 +485,12 @@ func DefineAgent[State any]( } } -// DefineCustomAgent defines an agent with full control over the conversation -// loop and registers it with the registry. The underlying action is created -// via [core.DefineBidiAction] (rather than [core.DefineBidiFlow]) so the -// agent capability metadata can be set at construction time — actions -// must be immutable once registered. The flow-context wrapping that makes -// [core.Run] work inside fn is preserved via [core.WithFlowContext]. +// DefineCustomAgent defines an agent with full control over the +// conversation loop and registers it with the registry. fn receives a +// [Responder] for streaming output and a [SessionRunner] for turn and +// state management; call [SessionRunner.Run] to enter the per-turn loop. +// +// For agents backed by a prompt, use [DefineAgent] instead. func DefineCustomAgent[Stream, State any]( r api.Registry, name string, @@ -463,6 +504,10 @@ func DefineCustomAgent[Stream, State any]( } } + // Built on DefineBidiAction (rather than DefineBidiFlow) so the agent + // capability metadata can be set at construction time; actions must be + // immutable once registered. WithFlowContext below preserves the + // flow-context wrapping that makes core.Run work inside fn. action := core.DefineBidiAction(r, name, api.ActionTypeFlow, &core.ActionOptions{ Metadata: map[string]any{"agent": agentMetadataFor(cfg.store)}, @@ -476,13 +521,13 @@ func DefineCustomAgent[Stream, State any]( ctx = core.WithFlowContext(ctx, name) rt, err := newAgentRuntime(ctx, name, cfg, in, inCh, outCh) if err != nil { - // Init validation failures resolve as a failed output - // too. No state is attached; the invocation never - // started, so the caller lost nothing. - return &AgentOutput[State]{ - FinishReason: AgentFinishReasonFailed, - Error: core.AsGenkitError(err), - }, nil + // Init failures (a rejected init payload, a failed + // snapshot load) fail the action outright rather than + // resolving as a failed output: the invocation never + // reached the input phase of its lifecycle, so there is + // no conversation state to hand back and nothing to + // snapshot. + return nil, err } return rt.run(ctx, fn) }) @@ -547,6 +592,31 @@ func newAgentRuntime[Stream, State any]( return nil, err } + // The session ID is settled up front, before the agent function runs, + // so it exists for the whole invocation regardless of when (or + // whether) the first snapshot is written. It lives inside the session + // state ([SessionState.SessionID]), so it rides along wherever the + // state goes: every persisted snapshot's state, and the state + // returned to (and resent by) client-managed callers. + if cfg.store != nil { + // Server-managed: the store row is canonical. Inherit the resumed + // chain's ID, overriding whatever the loaded state blob claims (a + // third-party writer could have let them drift), or mint one for a + // fresh conversation (including one resumed from a snapshot that + // predates session IDs). + if parent != nil && parent.SessionID != "" { + session.state.SessionID = parent.SessionID + } else { + session.state.SessionID = uuid.New().String() + } + } else if session.state.SessionID == "" { + // Client-managed: the state object is canonical; keep the ID it + // carried. Mint one when absent (a fresh conversation) so the + // output state is self-describing from the first turn and the + // client can round-trip it without tracking a separate field. + session.state.SessionID = uuid.New().String() + } + rt := &agentRuntime[Stream, State]{ name: name, cfg: cfg, @@ -577,14 +647,15 @@ func newAgentRuntime[Stream, State any]( // through the router so clients see it on the output stream. // // The snapshot is skipped when the turn failed (the live state holds the -// turn's partial mutations) and when detach has landed (the pending row -// already captures the invocation; a single finalize rewrite records the -// cumulative state once the queued inputs drain). +// turn's partial mutations) and when detach has landed (maybeSnapshot +// observes the suspension under snapMu; the pending row already captures +// the invocation and a single finalize rewrite records the cumulative +// state once the queued inputs drain). func (rt *agentRuntime[Stream, State]) emitTurnEnd(ctx context.Context) { - suspended := rt.intake.beginTurnEnd() + rt.intake.releaseForward() reason := rt.sess.lastTurnFinishReason var snapshotID string - if !rt.sess.lastTurnFailed && !suspended { + if !rt.sess.lastTurnFailed { snapshotID = rt.sess.maybeSnapshot(ctx, SnapshotEventTurnEnd, reason) } rt.router.sendChunk(ctx, &AgentStreamChunk[Stream]{TurnEnd: &TurnEnd{ @@ -743,7 +814,11 @@ func (rt *agentRuntime[Stream, State]) handleFnDone( snapshotID = rt.sess.lastSnapshot.SnapshotID } - out := &AgentOutput[State]{SnapshotID: snapshotID, FinishReason: invocationReason} + out := &AgentOutput[State]{ + SessionID: rt.session.SessionID(), + SnapshotID: snapshotID, + FinishReason: invocationReason, + } if res.result != nil { // Deep-copy at the framework boundary so the caller cannot // mutate session contents through the returned output, even @@ -753,11 +828,23 @@ func (rt *agentRuntime[Stream, State]) handleFnDone( out.Artifacts = cloneArtifacts(res.result.Artifacts) } if rt.cfg.store == nil { - out.State = applyTransform(ctx, rt.cfg.transform, rt.session.State()) + out.State = rt.outboundState(ctx, rt.session.State()) } return out, nil } +// outboundState applies the configured state transform and re-stamps the +// framework-owned SessionID, so the state handed to a client-managed +// caller always carries the conversation's identity even if a transform +// rewrote or dropped it. Returns nil if state is nil. +func (rt *agentRuntime[Stream, State]) outboundState(ctx context.Context, state *SessionState[State]) *SessionState[State] { + out := applyTransform(ctx, rt.cfg.transform, state) + if out != nil { + out.SessionID = rt.session.SessionID() + } + return out +} + // failedOutput assembles the output for an invocation that ended in // failure: [AgentFinishReasonFailed], the error with its original status, // and the last-good state (inline when client-managed, behind a recovery @@ -769,10 +856,11 @@ func (rt *agentRuntime[Stream, State]) failedOutput(ctx context.Context, cause e Error: core.AsGenkitError(cause), } if rt.cfg.store == nil { - out.State = applyTransform(ctx, rt.cfg.transform, rt.sess.lastGoodState) + out.State = rt.outboundState(ctx, rt.sess.lastGoodState) } else { out.SnapshotID = rt.sess.recoverySnapshotID(ctx) } + out.SessionID = rt.session.SessionID() return out } @@ -792,9 +880,12 @@ func (rt *agentRuntime[Stream, State]) handleDetach( // fn completion can cancel workCtx. markDetached() - rt.intake.suspend() - - parentID := rt.sess.parentSnapshotID() + // Atomically suspend per-turn snapshots and capture the chain tip: a + // turn-end write already in flight commits first (so the pending row + // chains off the real tip instead of becoming its sibling), and any + // later turn end skips its write. + parentID := rt.sess.suspendSnapshots() + sessionID := rt.session.SessionID() // Detach intends to outlive the client connection. If clientCtx was // already cancelled (or cancels mid-write), we still want the pending @@ -802,9 +893,10 @@ func (rt *agentRuntime[Stream, State]) handleDetach( pending, err := rt.cfg.store.SaveSnapshot(context.WithoutCancel(clientCtx), "", func(_ *SessionSnapshot[State]) (*SessionSnapshot[State], error) { return &SessionSnapshot[State]{ - ParentID: parentID, - Event: SnapshotEventDetach, - Status: SnapshotStatusPending, + SessionID: sessionID, + ParentID: parentID, + Event: SnapshotEventDetach, + Status: SnapshotStatusPending, }, nil }) if err != nil { @@ -812,7 +904,6 @@ func (rt *agentRuntime[Stream, State]) handleDetach( return rt.failedOutput(clientCtx, core.NewError(core.INTERNAL, "agent %q: detach: save pending snapshot: %v", rt.name, err)), nil } - // The router can no longer write to outCh once we return; the bidi // framework closes it shortly after. The router stops writing and // trashes any further chunks. @@ -846,6 +937,7 @@ func (rt *agentRuntime[Stream, State]) handleDetach( // pending snapshot is finalized later with how the background work // actually ended (see finalizePendingSnapshot). return &AgentOutput[State]{ + SessionID: pending.SessionID, SnapshotID: pending.SnapshotID, FinishReason: AgentFinishReasonDetached, }, nil @@ -907,6 +999,7 @@ func (rt *agentRuntime[Stream, State]) finalizePendingSnapshot( } return &SessionSnapshot[State]{ + SessionID: pending.SessionID, ParentID: pending.ParentID, Event: SnapshotEventDetach, Status: status, @@ -922,8 +1015,15 @@ func (rt *agentRuntime[Stream, State]) finalizePendingSnapshot( } // loadSession constructs a Session from the invocation's init payload, -// loading from the store when a snapshot ID is provided. Returns the -// snapshot too so the runtime can chain ParentID off it. +// loading from the store when a snapshot or session ID is provided. +// Returns the loaded snapshot too so the runtime can chain ParentID (and +// carry the session ID) off it. +// +// State is mutually exclusive with both SessionID and SnapshotID: it is +// the client-managed conversation source and carries its own identity +// ([SessionState.SessionID]), while the two IDs resolve against a store. +// SessionID and SnapshotID compose: the snapshot picks the exact resume +// point and the session ID is asserted against it. func loadSession[State any]( ctx context.Context, init *AgentInit[State], @@ -934,32 +1034,68 @@ func loadSession[State any]( return s, nil, nil } - if init.SnapshotID != "" && init.State != nil { - return nil, nil, core.NewError(core.INVALID_ARGUMENT, "snapshot ID and state are mutually exclusive") + if init.State != nil && (init.SessionID != "" || init.SnapshotID != "") { + return nil, nil, core.NewError(core.INVALID_ARGUMENT, + "state is mutually exclusive with session ID and snapshot ID; a client-managed conversation's identity rides inside the state (SessionState.SessionID)") } - if init.SnapshotID == "" { - if init.State != nil { - if store != nil { - return nil, nil, core.NewError(core.FAILED_PRECONDITION, - "state provided but agent has a session store configured (server-managed state); use snapshot ID instead") - } - s.state = *init.State + switch { + case init.State != nil: + if store != nil { + return nil, nil, core.NewError(core.FAILED_PRECONDITION, + "state provided but agent has a session store configured (server-managed state); use snapshot ID instead") } + s.state = *init.State return s, nil, nil - } - if store == nil { - return nil, nil, core.NewError(core.FAILED_PRECONDITION, - "snapshot ID %q provided but agent has no session store configured (client-managed state); use state instead", init.SnapshotID) - } - snap, err := store.GetSnapshot(ctx, init.SnapshotID) - if err != nil { - return nil, nil, core.NewError(core.INTERNAL, "failed to load snapshot %q: %v", init.SnapshotID, err) - } - if snap == nil { - return nil, nil, core.NewError(core.NOT_FOUND, "snapshot %q not found", init.SnapshotID) + case init.SnapshotID != "": + if store == nil { + return nil, nil, core.NewError(core.FAILED_PRECONDITION, + "snapshot ID %q provided but agent has no session store configured (client-managed state); use state instead", init.SnapshotID) + } + snap, err := store.GetSnapshot(ctx, init.SnapshotID) + if err != nil { + return nil, nil, core.NewError(core.INTERNAL, "failed to load snapshot %q: %v", init.SnapshotID, err) + } + if snap == nil { + return nil, nil, core.NewError(core.NOT_FOUND, "snapshot %q not found", init.SnapshotID) + } + // A session ID sent alongside the snapshot ID asserts which + // conversation the snapshot belongs to; a mismatch means the + // caller would silently continue the wrong conversation. + if init.SessionID != "" && snap.SessionID != init.SessionID { + return nil, nil, core.NewError(core.INVALID_ARGUMENT, + "snapshot %q does not belong to session %q (snapshot's session: %q)", init.SnapshotID, init.SessionID, snap.SessionID) + } + return resumeSessionFrom(s, snap) + + case init.SessionID != "": + if store == nil { + return nil, nil, core.NewError(core.FAILED_PRECONDITION, + "session ID %q provided but agent has no session store configured (client-managed state); the conversation's identity rides inside the state object (SessionState.SessionID)", init.SessionID) + } + snap, err := store.GetLatestSnapshot(ctx, init.SessionID) + if err != nil { + return nil, nil, core.NewError(core.INTERNAL, "failed to resolve latest snapshot for session %q: %v", init.SessionID, err) + } + if snap == nil { + return nil, nil, core.NewError(core.NOT_FOUND, "no resumable snapshot found for session %q", init.SessionID) + } + if snap.SessionID != init.SessionID { + return nil, nil, core.NewError(core.INTERNAL, + "store resolved session %q to snapshot %q, which belongs to session %q; the store violates the GetLatestSnapshot contract", init.SessionID, snap.SnapshotID, snap.SessionID) + } + return resumeSessionFrom(s, snap) } + return s, nil, nil +} + +// resumeSessionFrom validates that snap is in a resumable status and loads +// its state into s. Shared by the snapshot-ID and session-ID init paths; +// the session-ID path can only hit the pending case (a conforming store's +// GetLatestSnapshot never resolves to failed/aborted dead ends), but the +// full switch stays as a defense against non-conforming stores. +func resumeSessionFrom[State any](s *Session[State], snap *SessionSnapshot[State]) (*Session[State], *SessionSnapshot[State], error) { switch snap.Status { case SnapshotStatusFailed: msg := "snapshot recorded an error" @@ -967,13 +1103,13 @@ func loadSession[State any]( msg = snap.Error.Message } return nil, nil, core.NewError(core.FAILED_PRECONDITION, - "snapshot %q terminated with error: %s", init.SnapshotID, msg) + "snapshot %q terminated with error: %s", snap.SnapshotID, msg) case SnapshotStatusPending: return nil, nil, core.NewError(core.FAILED_PRECONDITION, - "snapshot %q is still pending; wait for it to finalize before resuming", init.SnapshotID) + "snapshot %q is still pending: its detached invocation is still running; wait for it to finalize or abort it before resuming", snap.SnapshotID) case SnapshotStatusAborted: return nil, nil, core.NewError(core.FAILED_PRECONDITION, - "snapshot %q was aborted", init.SnapshotID) + "snapshot %q was aborted", snap.SnapshotID) } if snap.State != nil { s.state = *snap.State @@ -1126,29 +1262,23 @@ func (r *chunkRouter[Stream, State]) close() { // The forwarder goroutine pops the queue and writes to dst, blocking on // the runner via turnDone so it stays in step with turn pacing. // -// The runtime's emitTurnEnd asks beginTurnEnd at the end of each turn: -// if suspended (detach has landed), it skips the turn-end snapshot — the -// pending row already captures the invocation and a single finalize -// will rewrite it with the cumulative state once the queued inputs -// drain. If not suspended, a normal turn-end snapshot is written. -// -// suspend is called once by the detach handler under the same mutex -// that beginTurnEnd reads from, ensuring memory ordering: any -// beginTurnEnd that returns after suspend completes sees suspended=true. +// Snapshot suspension after detach is not the intake's concern: the +// runner gates writes itself (see SessionRunner.suspendSnapshots), so a +// detach can atomically wait out an in-flight turn-end write. The intake +// only owns input pacing. type detachIntake struct { src <-chan *AgentInput dst chan *AgentInput notify chan struct{} // buffered size 1; wakes forwarder when queue grows - // turnDone is signaled by beginTurnEnd to release the forwarder so it - // may pop the next input. Initialized with one token so the very + // turnDone is signaled at each turn end to release the forwarder so + // it may pop the next input. Initialized with one token so the very // first turn can start without a preceding turn end. turnDone chan struct{} - mu sync.Mutex - suspended bool - queue []*AgentInput + mu sync.Mutex + queue []*AgentInput readDone atomic.Bool detachCh chan struct{} // signaled by reader when detach observed @@ -1231,8 +1361,8 @@ func (i *detachIntake) enqueue(input *AgentInput) { } // handleDetach drains any buffered src inputs into the queue and signals -// the detach handler. The detach handler then calls suspend to halt -// turn-end snapshots while the queued inputs finish processing. +// the detach handler. The detach handler then suspends turn-end snapshots +// (via the runner) while the queued inputs finish processing. // // A pure detach signal (no Messages, no Resume payload) is dropped // rather than enqueued: it carries no payload to process, so it would @@ -1287,9 +1417,9 @@ func hasInputPayload(in *AgentInput) bool { } // forward pops the queue and writes to dst at the runner's pace. The -// runner signals turnDone via beginTurnEnd when it's ready for the next -// input; until then the forwarder waits, so it never gets ahead of the -// runner. +// runtime signals turnDone via releaseForward when it's ready for the +// next input; until then the forwarder waits, so it never gets ahead of +// the runner. func (i *detachIntake) forward() { for { // Wait for the previous turn to release us (initial credit lets @@ -1339,8 +1469,8 @@ func (i *detachIntake) awaitInput() *AgentInput { } // releaseForward releases the forwarder so it can pop the next input. -// Must be called from beginTurnEnd (and only there) so the forwarder -// stays in step with the runner's turn pacing. +// Called by the runtime's emitTurnEnd at each turn end (and only there) +// so the forwarder stays in step with the runner's turn pacing. func (i *detachIntake) releaseForward() { select { case i.turnDone <- struct{}{}: @@ -1356,31 +1486,6 @@ func (i *detachIntake) detachSignal() <-chan struct{} { return i.detachCh } -// beginTurnEnd is called by the runtime's emitTurnEnd at each turn end, -// before any turn-end snapshot. If the intake has been suspended (detach -// landed), it returns suspended=true and the caller skips the snapshot. -// -// In all cases (including suspended) the forwarder is released so it can -// pop the next queued input — suspension stops snapshot writing, not -// processing. -func (i *detachIntake) beginTurnEnd() (suspended bool) { - i.mu.Lock() - suspended = i.suspended - i.mu.Unlock() - i.releaseForward() - return suspended -} - -// suspend is called once by the detach handler. It flips suspended=true -// under the mutex so subsequent beginTurnEnd calls observe the change -// and skip their turn-end snapshot writes; the queued inputs roll into -// a single finalize rewrite of the pending row instead. -func (i *detachIntake) suspend() { - i.mu.Lock() - i.suspended = true - i.mu.Unlock() -} - // stopAndWait forces the intake to exit and waits for both reader and // forwarder goroutines. func (i *detachIntake) stopAndWait() { @@ -1422,7 +1527,7 @@ func validateUserMessage(m *ai.Message) error { // with streaming, and updates the session. // // defaultInput is the prompt input passed to Render on every turn. It is -// nil for [DefineAgent], where the inline-defined prompt has no per-turn +// nil for inline-defined prompts ([FromInline]), which take no per-turn // input. func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) AgentFunc[any, State] { return func(ctx context.Context, resp Responder[any], sess *SessionRunner[State]) (*AgentResult, error) { @@ -1532,8 +1637,9 @@ func (a *Agent[Stream, State]) StreamBidi( // It sends the input, waits for the agent to complete, and returns the output. // For multi-turn interactions or streaming, use StreamBidi instead. // -// In-band failures (a failed turn, a rejected init payload) resolve as a -// failed [AgentOutput] rather than an error; see [AgentConnection.Output]. +// In-band failures (e.g. a failed turn) resolve as a failed [AgentOutput] +// rather than an error; a rejected init payload fails with an error, since +// the invocation never starts. See [AgentConnection.Output]. func (a *Agent[Stream, State]) Run( ctx context.Context, input *AgentInput, @@ -1544,8 +1650,8 @@ func (a *Agent[Stream, State]) Run( return nil, err } // The invocation may resolve before consuming the input (e.g. an init - // validation failure resolves as a failed output); the outcome is on - // Output regardless. + // validation failure errors out before the first turn); the outcome, + // whether output or error, is on Output regardless. if err := conn.Send(input); err != nil && !errors.Is(err, core.ErrActionCompleted) { return nil, err } @@ -1566,6 +1672,11 @@ func (a *Agent[Stream, State]) RunText( } // resolveOptions applies invocation options and returns the init struct. +// Mutual exclusivity is checked here, once, after all options are merged: +// WithState excludes both WithSessionID and WithSnapshotID (a +// client-managed conversation's identity rides inside the state itself), +// while WithSessionID and WithSnapshotID compose as an assertion. +// Per-option duplicate checks live in applyInvocation. func (a *Agent[Stream, State]) resolveOptions(opts []InvocationOption[State]) (*AgentInit[State], error) { invOpts := &invocationOptions[State]{} for _, opt := range opts { @@ -1574,7 +1685,15 @@ func (a *Agent[Stream, State]) resolveOptions(opts []InvocationOption[State]) (* } } + if invOpts.state != nil && invOpts.snapshotID != "" { + return nil, fmt.Errorf("Agent %q: WithState and WithSnapshotID are mutually exclusive", a.action.Name()) + } + if invOpts.state != nil && invOpts.sessionIDSet { + return nil, fmt.Errorf("Agent %q: WithState and WithSessionID are mutually exclusive; the conversation's identity rides inside the state (SessionState.SessionID)", a.action.Name()) + } + return &AgentInit[State]{ + SessionID: invOpts.sessionID, SnapshotID: invOpts.snapshotID, State: invOpts.state, }, nil @@ -1655,32 +1774,26 @@ func (c *AgentConnection[Stream, State]) Receive() iter.Seq2[*AgentStreamChunk[S // Output finalizes the connection and returns the agent's result. // // Output is the single "I'm done" call: it implicitly closes the input -// side and drains any chunks the caller did not consume via Receive, -// then blocks until the agent finalizes. The drain is required because -// the underlying stream buffer is shallow; without it, a producing fn -// would block on chunk sends and never reach completion. Calling Close -// before Output is allowed but redundant; both are idempotent. -// -// In-band failures resolve rather than error: a failed turn or a -// rejected request returns an [AgentOutput] with -// [AgentFinishReasonFailed], the error on [AgentOutput.Error] (original -// status intact), and the last-good state on [AgentOutput.State] -// (client-managed) or behind [AgentOutput.SnapshotID] (server-managed), -// so a failure costs the caller only the failed turn, never the -// session. A non-nil error here means the invocation itself could not -// run to a result (e.g. the connection's context was cancelled). -// -// Output is itself idempotent: subsequent calls return the same -// (*AgentOutput, error) from cache. The returned pointer is shared -// across calls; treat it as read-only. +// side, drains any chunks the caller did not consume via Receive, and +// blocks until the agent finalizes. Calling Close first is allowed but +// redundant. Output is idempotent: subsequent calls return the same +// (*AgentOutput, error); the returned pointer is shared across calls, +// so treat it as read-only. // -// Detach: when the client sends Detach, the agent function returns -// promptly with a pending snapshot ID. Output returns that output -// rather than a context cancellation error. +// In-band failures resolve rather than error: a failed turn returns an +// [AgentOutput] with [AgentFinishReasonFailed], the error on +// [AgentOutput.Error] (original status intact), and the last-good state +// on [AgentOutput.State] (client-managed) or behind +// [AgentOutput.SnapshotID] (server-managed), so a failure costs the +// caller only the failed turn, never the session. A detached invocation +// resolves with the pending snapshot ID rather than a cancellation +// error. A non-nil error here means the invocation never started (a +// rejected init payload) or could not run to a result (e.g. the +// connection's context was cancelled). // -// Do not call Output concurrently with a goroutine iterating Receive — +// Do not call Output concurrently with a goroutine iterating Receive; // both consume from the same stream and chunks would be split between -// them. Sequence the calls: finish Receive first, then call Output. +// them. Finish Receive first, then call Output. func (c *AgentConnection[Stream, State]) Output() (*AgentOutput[State], error) { _ = c.conn.Close() diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index e275dc69d3..b9a31a9d79 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -22,6 +22,7 @@ import ( "fmt" "slices" "strings" + "sync" "testing" "time" @@ -1018,9 +1019,11 @@ func TestAgent_CustomAgentContinuesAfterFailedTurn(t *testing.T) { } } -func TestAgent_InitFailure_ResolvesFailedOutputWithStatus(t *testing.T) { - // Pre-turn precondition/validation failures resolve as failed outputs - // carrying the original error status, rather than failing the action. +func TestAgent_InitFailure_FailsActionWithStatus(t *testing.T) { + // Pre-turn precondition/validation failures fail the action outright + // (no failed-AgentOutput conversion, no snapshot): the invocation never + // reached the input phase, so there is no conversation state to hand + // back. The error keeps its original status. ctx := context.Background() reg := newTestRegistry(t) store := newTestInMemStore[testState]() @@ -1067,23 +1070,18 @@ func TestAgent_InitFailure_ResolvesFailedOutputWithStatus(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { out, err := tc.run() - if err != nil { - t.Fatalf("expected graceful failed output, got error: %v", err) - } - if out.FinishReason != AgentFinishReasonFailed { - t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) - } - if out.Error == nil { - t.Fatal("expected error on output") + if err == nil { + t.Fatalf("expected error, got output: %+v", out) } - if out.Error.Status != tc.wantStatus { - t.Errorf("expected status %q, got %q", tc.wantStatus, out.Error.Status) + if out != nil { + t.Errorf("expected nil output on init failure, got %+v", out) } - if !strings.Contains(out.Error.Message, tc.wantMsg) { - t.Errorf("expected error message to contain %q, got %q", tc.wantMsg, out.Error.Message) + ge := core.AsGenkitError(err) + if ge.Status != tc.wantStatus { + t.Errorf("expected status %q, got %q (err: %v)", tc.wantStatus, ge.Status, err) } - if out.SnapshotID != "" || out.State != nil { - t.Errorf("init failures carry no state, got snapshotId=%q state=%+v", out.SnapshotID, out.State) + if !strings.Contains(ge.Message, tc.wantMsg) { + t.Errorf("expected error message to contain %q, got %q", tc.wantMsg, ge.Message) } }) } @@ -2785,17 +2783,14 @@ func TestAgent_Detach_FlowErrorsBecomesError(t *testing.T) { t.Errorf("expected snapshot.Error.Message to contain %q, got %+v", "kaboom", snap.Error) } - // Resuming from an errored detached snapshot is rejected; the - // rejection resolves as a failed output carrying the original error. + // Resuming from an errored detached snapshot is rejected before the + // invocation starts, so the action fails with the original error. resumeOut, err := af.RunText(context.Background(), "retry", WithSnapshotID[testState](out.SnapshotID)) - if err != nil { - t.Fatalf("RunText: %v", err) - } - if resumeOut.FinishReason != AgentFinishReasonFailed { - t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, resumeOut.FinishReason) + if err == nil { + t.Fatalf("expected error resuming errored snapshot, got output: %+v", resumeOut) } - if resumeOut.Error == nil || !strings.Contains(resumeOut.Error.Message, "kaboom") { - t.Errorf("unexpected resume error: %+v", resumeOut.Error) + if !strings.Contains(err.Error(), "kaboom") { + t.Errorf("expected resume error to surface the original failure, got: %v", err) } } @@ -3010,20 +3005,15 @@ func TestAgent_ResumeFromErrorSnapshot_Rejected(t *testing.T) { ) out, err := af.RunText(context.Background(), "hi", WithSnapshotID[testState](erroredID)) - if err != nil { - t.Fatalf("RunText: %v", err) - } - if out.FinishReason != AgentFinishReasonFailed { - t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) - } - if out.Error == nil { - t.Fatal("expected output error when resuming from errored snapshot") + if err == nil { + t.Fatalf("expected error when resuming from errored snapshot, got output: %+v", out) } - if out.Error.Status != core.FAILED_PRECONDITION { - t.Errorf("expected status %q, got %q", core.FAILED_PRECONDITION, out.Error.Status) + ge := core.AsGenkitError(err) + if ge.Status != core.FAILED_PRECONDITION { + t.Errorf("expected status %q, got %q", core.FAILED_PRECONDITION, ge.Status) } - if !strings.Contains(out.Error.Message, "underlying failure") { - t.Errorf("expected error to surface underlying failure, got: %v", out.Error) + if !strings.Contains(ge.Message, "underlying failure") { + t.Errorf("expected error to surface underlying failure, got: %v", err) } } @@ -3031,7 +3021,9 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { reg := newTestRegistry(t) store := newTestInMemStore[testState]() - // Transform that scrubs a specific word from all messages. + // Transform that scrubs a specific word from all messages. It also + // (incorrectly) drops the framework-owned session ID, which the + // action must re-stamp on the way out. transform := func(_ context.Context, s *SessionState[testState]) *SessionState[testState] { for _, msg := range s.Messages { for _, p := range msg.Content { @@ -3040,6 +3032,7 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { } } } + s.SessionID = "" return s } @@ -3062,7 +3055,7 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { // Transform is action-layer behavior: invoke the registered action // directly the way a non-Go client would. - action := core.ResolveActionFor[*GetSnapshotRequest, *GetSnapshotResponse[testState], struct{}, struct{}]( + action := core.ResolveActionFor[*GetSnapshotRequest, *SessionSnapshot[testState], struct{}, struct{}]( reg, api.ActionTypeAgentSnapshot, "transformedFlow") if action == nil { t.Fatal("getSnapshot action not registered") @@ -3088,6 +3081,11 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { } } } + // The transform dropped the state-carried session ID; the action + // re-stamps it from the row so outbound state stays self-describing. + if resp.State.SessionID != resp.SessionID { + t.Errorf("state-carried session ID = %q, want re-stamped %q", resp.State.SessionID, resp.SessionID) + } // The stored snapshot must remain untransformed so the flow can be // resumed faithfully. @@ -3132,7 +3130,7 @@ func TestAgent_GetSnapshotAction_ReturnsFinishReason(t *testing.T) { t.Fatalf("RunText: %v", err) } - action := core.ResolveActionFor[*GetSnapshotRequest, *GetSnapshotResponse[testState], struct{}, struct{}]( + action := core.ResolveActionFor[*GetSnapshotRequest, *SessionSnapshot[testState], struct{}, struct{}]( reg, api.ActionTypeAgentSnapshot, "finishReasonActionFlow") if action == nil { t.Fatal("getSnapshot action not registered") @@ -3142,7 +3140,7 @@ func TestAgent_GetSnapshotAction_ReturnsFinishReason(t *testing.T) { t.Fatalf("getSnapshot action: %v", err) } if resp.FinishReason != AgentFinishReasonStop { - t.Errorf("GetSnapshotResponse.FinishReason = %q, want %q", resp.FinishReason, AgentFinishReasonStop) + t.Errorf("getSnapshot FinishReason = %q, want %q", resp.FinishReason, AgentFinishReasonStop) } } @@ -3169,7 +3167,7 @@ func TestAgent_GetSnapshotAction_NoStore(t *testing.T) { }, ) - getAction := core.ResolveActionFor[*GetSnapshotRequest, *GetSnapshotResponse[testState], struct{}, struct{}]( + getAction := core.ResolveActionFor[*GetSnapshotRequest, *SessionSnapshot[testState], struct{}, struct{}]( reg, api.ActionTypeAgentSnapshot, "noStoreFlow") if getAction != nil { t.Error("getSnapshot action should NOT be registered without a store") @@ -3183,14 +3181,30 @@ func TestAgent_GetSnapshotAction_NoStore(t *testing.T) { func TestLoadSession_AgentInitValidation(t *testing.T) { // loadSession enforces the AgentInit invariants: - // - snapshotId and state are mutually exclusive, - // - snapshotId requires a store (server-managed state), - // - state requires the absence of a store (client-managed state). + // - state is mutually exclusive with sessionId and snapshotId (a + // client-managed conversation's identity rides inside the state), + // - sessionId and snapshotId require a store (server-managed state), + // - state requires the absence of a store (client-managed state), + // - sessionId composes with snapshotId as an integrity assertion on + // the loaded snapshot. ctx := context.Background() store := newTestInMemStore[testState]() state := &SessionState[testState]{Custom: testState{Counter: 1}} - cases := []struct { + // A persisted snapshot belonging to session "sess-1", for the + // sessionId+snapshotId match and mismatch cases. + saved, err := store.SaveSnapshot(ctx, "", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{ + SessionID: "sess-1", + Event: SnapshotEventInvocationEnd, + State: state, + }, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + + errCases := []struct { name string init *AgentInit[testState] store SessionStore[testState] @@ -3198,7 +3212,7 @@ func TestLoadSession_AgentInitValidation(t *testing.T) { }{ { name: "both snapshotId and state set", - init: &AgentInit[testState]{SnapshotID: "snap-1", State: state}, + init: &AgentInit[testState]{SnapshotID: saved.SnapshotID, State: state}, store: store, wantErr: "mutually exclusive", }, @@ -3208,6 +3222,18 @@ func TestLoadSession_AgentInitValidation(t *testing.T) { store: nil, wantErr: "mutually exclusive", }, + { + name: "sessionId and state set", + init: &AgentInit[testState]{SessionID: "sess-1", State: state}, + store: nil, + wantErr: "mutually exclusive", + }, + { + name: "all three set", + init: &AgentInit[testState]{SessionID: "sess-1", SnapshotID: saved.SnapshotID, State: state}, + store: store, + wantErr: "mutually exclusive", + }, { name: "state with server-managed agent", init: &AgentInit[testState]{State: state}, @@ -3220,9 +3246,27 @@ func TestLoadSession_AgentInitValidation(t *testing.T) { store: nil, wantErr: "client-managed state", }, + { + name: "sessionId with client-managed agent", + init: &AgentInit[testState]{SessionID: "sess-1"}, + store: nil, + wantErr: "client-managed state", + }, + { + name: "sessionId with no matching snapshots", + init: &AgentInit[testState]{SessionID: "sess-unknown"}, + store: store, + wantErr: "no resumable snapshot", + }, + { + name: "sessionId mismatching the loaded snapshot", + init: &AgentInit[testState]{SessionID: "sess-other", SnapshotID: saved.SnapshotID}, + store: store, + wantErr: "does not belong to session", + }, } - for _, tc := range cases { + for _, tc := range errCases { t.Run(tc.name, func(t *testing.T) { _, _, err := loadSession(ctx, tc.init, tc.store) if err == nil { @@ -3234,33 +3278,78 @@ func TestLoadSession_AgentInitValidation(t *testing.T) { }) } - t.Run("empty init with server store is allowed", func(t *testing.T) { - sess, snap, err := loadSession(ctx, &AgentInit[testState]{}, store) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if sess == nil { - t.Fatal("expected session, got nil") - } - if snap != nil { - t.Errorf("expected no snapshot, got %+v", snap) - } - }) + okCases := []struct { + name string + init *AgentInit[testState] + store SessionStore[testState] + wantSnap bool + }{ + { + name: "empty init with server store", + init: &AgentInit[testState]{}, + store: store, + }, + { + name: "empty init with no store", + init: &AgentInit[testState]{}, + }, + { + name: "state carrying its session ID with client-managed agent", + init: &AgentInit[testState]{State: &SessionState[testState]{SessionID: "client-sess", Custom: testState{Counter: 1}}}, + }, + { + name: "sessionId matching the loaded snapshot", + init: &AgentInit[testState]{SessionID: "sess-1", SnapshotID: saved.SnapshotID}, + store: store, + wantSnap: true, + }, + } - t.Run("empty init with no store is allowed", func(t *testing.T) { - sess, snap, err := loadSession(ctx, &AgentInit[testState]{}, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if sess == nil { - t.Fatal("expected session, got nil") + for _, tc := range okCases { + t.Run(tc.name, func(t *testing.T) { + sess, snap, err := loadSession(ctx, tc.init, tc.store) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if sess == nil { + t.Fatal("expected session, got nil") + } + if tc.wantSnap != (snap != nil) { + t.Errorf("snapshot presence = %v, want %v (snap: %+v)", snap != nil, tc.wantSnap, snap) + } + if tc.init.State != nil && sess.State().Custom.Counter != tc.init.State.Custom.Counter { + t.Errorf("state not loaded: got %+v", sess.State()) + } + }) + } + + t.Run("latest-snapshot lookup validates the store's answer", func(t *testing.T) { + // A non-conforming store that resolves a session to a snapshot from + // a different session must be caught rather than silently resuming + // the wrong conversation. + _, _, err := loadSession(ctx, &AgentInit[testState]{SessionID: "sess-lied-about"}, + wrongSessionStore[testState]{SessionStore: store, snapshotID: saved.SnapshotID}) + if err == nil { + t.Fatal("expected error, got nil") } - if snap != nil { - t.Errorf("expected no snapshot, got %+v", snap) + if !strings.Contains(err.Error(), "violates the GetLatestSnapshot contract") { + t.Errorf("error %q does not mention the contract violation", err.Error()) } }) } +// wrongSessionStore wraps a SessionStore but resolves every +// GetLatestSnapshot call to a fixed snapshot, regardless of the requested +// session: a deliberately non-conforming implementation. +type wrongSessionStore[State any] struct { + SessionStore[State] + snapshotID string +} + +func (s wrongSessionStore[State]) GetLatestSnapshot(ctx context.Context, sessionID string) (*SessionSnapshot[State], error) { + return s.SessionStore.GetSnapshot(ctx, s.snapshotID) +} + // minimalStore is a SessionStore that does NOT implement SnapshotAborter. // Used to verify the abort action stays unregistered for stores that // lack the capability. @@ -3269,6 +3358,9 @@ type minimalStore[State any] struct{} func (minimalStore[State]) GetSnapshot(context.Context, string) (*SessionSnapshot[State], error) { return nil, nil } +func (minimalStore[State]) GetLatestSnapshot(context.Context, string) (*SessionSnapshot[State], error) { + return nil, nil +} func (minimalStore[State]) SaveSnapshot( context.Context, string, func(*SessionSnapshot[State]) (*SessionSnapshot[State], error), @@ -3361,7 +3453,7 @@ func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { }, WithSessionStore(store), ) - getAction := core.ResolveActionFor[*GetSnapshotRequest, *GetSnapshotResponse[testState], struct{}, struct{}]( + getAction := core.ResolveActionFor[*GetSnapshotRequest, *SessionSnapshot[testState], struct{}, struct{}]( reg, api.ActionTypeAgentSnapshot, "fullCaps") if getAction == nil { t.Error("getSnapshot action should be registered") @@ -3381,7 +3473,7 @@ func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { }, WithSessionStore[testState](minimalStore[testState]{}), ) - getAction := core.ResolveActionFor[*GetSnapshotRequest, *GetSnapshotResponse[testState], struct{}, struct{}]( + getAction := core.ResolveActionFor[*GetSnapshotRequest, *SessionSnapshot[testState], struct{}, struct{}]( reg, api.ActionTypeAgentSnapshot, "minCaps") if getAction == nil { t.Error("getSnapshot action should be registered even when store lacks SnapshotAborter") @@ -4545,3 +4637,836 @@ func TestAgent_Detach_SucceededHonorsResultOverride(t *testing.T) { t.Errorf("finalized snapshot.FinishReason = %q, want %q (AgentResult override)", snap.FinishReason, AgentFinishReasonOther) } } + +// --- Session ID tests --- + +func TestAgent_SessionID_AssignedAndStable(t *testing.T) { + // The runtime assigns the session ID when the invocation starts; every + // snapshot the invocation persists carries it and the output reports it. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "sessionAssignFlow", WithSessionStore(store)) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + + var snapshotIDs []string + for _, text := range []string{"turn one", "turn two"} { + if err := conn.SendText(text); err != nil { + t.Fatalf("SendText: %v", err) + } + te := nextTurnEnd(t, conn) + if te.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, te.SnapshotID) + } + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.SessionID == "" { + t.Fatal("expected session ID on output") + } + if len(snapshotIDs) != 2 { + t.Fatalf("expected 2 turn-end snapshots, got %d", len(snapshotIDs)) + } + + for _, id := range append(snapshotIDs, out.SnapshotID) { + snap, err := store.GetSnapshot(ctx, id) + if err != nil { + t.Fatalf("GetSnapshot(%q): %v", id, err) + } + if snap.SessionID != out.SessionID { + t.Errorf("snapshot %q SessionID = %q, want %q", id, snap.SessionID, out.SessionID) + } + // The persisted state blob is self-describing: it mirrors the + // row's session ID. + if snap.State == nil || snap.State.SessionID != out.SessionID { + t.Errorf("snapshot %q state-carried session ID = %v, want %q", id, snap.State, out.SessionID) + } + } +} + +func TestAgent_SessionID_AssignedBeforeFirstSnapshot(t *testing.T) { + // The session ID exists from invocation start, not from the first + // snapshot write: an invocation whose callback declines every write + // still reports the session it belongs to and exposes it to the agent + // fn, with no snapshot to show for it yet. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + var fnSawSessionID, ctxSawSessionID string + af := DefineCustomAgent(reg, "sessionAlwaysAssigned", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + fnSawSessionID = sess.SessionID() + // The ID lives on the session itself, so code holding only the + // context-carried session (e.g. a tool) can read it too. + if s := SessionFromContext[testState](ctx); s != nil { + ctxSawSessionID = s.SessionID() + } + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("reply")) + return nil, nil + }) + }, + WithSessionStore(store), + WithSnapshotCallback(func(context.Context, *SnapshotContext[testState]) bool { return false }), + ) + + out, err := af.RunText(ctx, "hi") + if err != nil { + t.Fatalf("RunText: %v", err) + } + if out.FinishReason == AgentFinishReasonFailed { + t.Fatalf("invocation failed: %+v", out.Error) + } + if out.SessionID == "" { + t.Fatal("expected session ID assigned at invocation start") + } + if fnSawSessionID != out.SessionID { + t.Errorf("fn saw session ID %q, output reports %q", fnSawSessionID, out.SessionID) + } + if ctxSawSessionID != out.SessionID { + t.Errorf("context-carried session saw ID %q, output reports %q", ctxSawSessionID, out.SessionID) + } + if out.SnapshotID != "" { + t.Errorf("expected no snapshot (callback declined every write), got %q", out.SnapshotID) + } + + // A session with no persisted snapshots cannot be resumed by its ID; + // the init-level rejection fails the action. + out2, err := af.RunText(ctx, "again", WithSessionID[testState](out.SessionID)) + if err == nil { + t.Fatalf("expected NOT_FOUND error for snapshot-less session, got output: %+v", out2) + } + if ge := core.AsGenkitError(err); ge.Status != core.NOT_FOUND { + t.Fatalf("expected NOT_FOUND, got %q (err: %v)", ge.Status, err) + } +} + +func TestAgent_SessionID_StableAcrossSnapshotResume(t *testing.T) { + // Resuming from a snapshot keeps extending the same session: rows + // written by the resumed invocation inherit the chain's session ID. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "sessionStableFlow", WithSessionStore(store)) + + out1, err := af.RunText(ctx, "first") + if err != nil { + t.Fatalf("RunText: %v", err) + } + if out1.SessionID == "" { + t.Fatal("expected session ID from first invocation") + } + + out2, err := af.RunText(ctx, "second", WithSnapshotID[testState](out1.SnapshotID)) + if err != nil { + t.Fatalf("RunText resume: %v", err) + } + if out2.SessionID != out1.SessionID { + t.Errorf("resumed invocation SessionID = %q, want %q", out2.SessionID, out1.SessionID) + } + snap2, err := store.GetSnapshot(ctx, out2.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap2.SessionID != out1.SessionID { + t.Errorf("resumed snapshot SessionID = %q, want %q", snap2.SessionID, out1.SessionID) + } +} + +func TestAgent_ResumeFromSessionID(t *testing.T) { + // WithSessionID resolves the session's latest snapshot and continues + // the conversation from it. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "sessionResumeFlow", WithSessionStore(store)) + + out1, err := af.RunText(ctx, "first") + if err != nil { + t.Fatalf("RunText: %v", err) + } + + out2, err := af.RunText(ctx, "second", WithSessionID[testState](out1.SessionID)) + if err != nil { + t.Fatalf("RunText session resume: %v", err) + } + if out2.FinishReason == AgentFinishReasonFailed { + t.Fatalf("session resume failed: %+v", out2.Error) + } + if out2.SessionID != out1.SessionID { + t.Errorf("SessionID = %q, want %q", out2.SessionID, out1.SessionID) + } + + snap2, err := store.GetSnapshot(ctx, out2.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + // Continued the conversation: both invocations' messages and both + // counter increments, chained off the first invocation's snapshot. + if got := len(snap2.State.Messages); got != 4 { + t.Errorf("expected 4 messages after session resume, got %d", got) + } + if got := snap2.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2 after session resume, got %d", got) + } + if snap2.ParentID != out1.SnapshotID { + t.Errorf("resumed snapshot ParentID = %q, want %q", snap2.ParentID, out1.SnapshotID) + } +} + +func TestAgent_ResumeFromSessionID_ForkContinuesLatestBranch(t *testing.T) { + // Re-invoking the agent from a non-tip snapshot forks the session. + // Session-ID init does not care: the most recently updated branch + // wins, so the conversation continues where activity last happened. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "sessionForkFlow", WithSessionStore(store)) + + out1, err := af.RunText(ctx, "first") + if err != nil { + t.Fatalf("RunText: %v", err) + } + // Two sibling branches off the same parent: a fork. + var branches []*AgentOutput[testState] + for _, text := range []string{"branch b", "branch c"} { + time.Sleep(2 * time.Millisecond) // order branches unambiguously by UpdatedAt + out, err := af.RunText(ctx, text, WithSnapshotID[testState](out1.SnapshotID)) + if err != nil { + t.Fatalf("RunText branch: %v", err) + } + if out.SessionID != out1.SessionID { + t.Fatalf("branch SessionID = %q, want %q", out.SessionID, out1.SessionID) + } + branches = append(branches, out) + } + + out, err := af.RunText(ctx, "which branch?", WithSessionID[testState](out1.SessionID)) + if err != nil { + t.Fatalf("RunText session resume: %v", err) + } + if out.FinishReason == AgentFinishReasonFailed { + t.Fatalf("session resume failed: %+v", out.Error) + } + snap, err := store.GetSnapshot(ctx, out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if want := branches[1].SnapshotID; snap.ParentID != want { + t.Errorf("resumed snapshot ParentID = %q, want most recent branch %q", snap.ParentID, want) + } +} + +func TestAgent_ResumeFromSessionID_SkipsDeadEnds(t *testing.T) { + // A failed (or aborted) row is a permanent dead end: even as the + // session's newest row it never blocks session-ID init. The session + // resumes from the last good snapshot. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "sessionDeadEndFlow", WithSessionStore(store)) + + out1, err := af.RunText(ctx, "first") + if err != nil { + t.Fatalf("RunText: %v", err) + } + // A failed detach-style row chained off the tip, as a background + // invocation that failed would leave behind. + if _, err := store.SaveSnapshot(ctx, "", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{ + SessionID: out1.SessionID, + ParentID: out1.SnapshotID, + Event: SnapshotEventDetach, + Status: SnapshotStatusFailed, + FinishReason: AgentFinishReasonFailed, + }, nil + }); err != nil { + t.Fatalf("SaveSnapshot failed row: %v", err) + } + + out2, err := af.RunText(ctx, "second", WithSessionID[testState](out1.SessionID)) + if err != nil { + t.Fatalf("RunText session resume: %v", err) + } + if out2.FinishReason == AgentFinishReasonFailed { + t.Fatalf("session resume failed: %+v", out2.Error) + } + snap2, err := store.GetSnapshot(ctx, out2.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap2.ParentID != out1.SnapshotID { + t.Errorf("resumed snapshot ParentID = %q, want last good %q", snap2.ParentID, out1.SnapshotID) + } + if got := snap2.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2 (resumed from last good state), got %d", got) + } + + // The dead end keeps being skipped on subsequent resumes. + out3, err := af.RunText(ctx, "third", WithSessionID[testState](out1.SessionID)) + if err != nil { + t.Fatalf("RunText second session resume: %v", err) + } + if out3.FinishReason == AgentFinishReasonFailed { + t.Fatalf("second session resume failed: %+v", out3.Error) + } +} + +func TestAgent_ResumeFromSessionID_AfterFailureResumesRecovery(t *testing.T) { + // After an invocation fails, the session's newest non-dead-end row is + // the recovery snapshot (a normal succeeded row, event=recovery) + // holding the last-good state. Resuming by session ID continues from + // it like any other snapshot. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "sessionRecoveryFlow", + WithSessionStore[testState](store), + // Persist only the first turn so the failure path must write a + // genuine recovery-event row (rather than reusing a turn-end row). + WithSnapshotCallback(func(_ context.Context, sc *SnapshotContext[testState]) bool { + return sc.Event == SnapshotEventTurnEnd && sc.TurnIndex == 0 + }), + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + for _, text := range []string{"one", "two", "boom"} { + if err := conn.SendText(text); err != nil { + t.Fatalf("SendText(%q): %v", text, err) + } + } + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Fatalf("expected failed invocation, got %q", out.FinishReason) + } + recovery, err := store.GetSnapshot(ctx, out.SnapshotID) + if err != nil || recovery == nil { + t.Fatalf("GetSnapshot(%q): %v, %v", out.SnapshotID, recovery, err) + } + if recovery.Event != SnapshotEventRecovery { + t.Fatalf("expected recovery snapshot, got event %q", recovery.Event) + } + + out2, err := af.RunText(ctx, "three", WithSessionID[testState](out.SessionID)) + if err != nil { + t.Fatalf("RunText session resume: %v", err) + } + if out2.FinishReason == AgentFinishReasonFailed { + t.Fatalf("session resume failed: %+v", out2.Error) + } + snap2, err := store.GetSnapshot(ctx, out2.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap2.ParentID != recovery.SnapshotID { + t.Errorf("resumed snapshot ParentID = %q, want recovery row %q", snap2.ParentID, recovery.SnapshotID) + } + // Last-good state (two successful turns, counter=2) plus the resumed + // turn: the failed turn's partial mutations never made it in. + if got := snap2.State.Custom.Counter; got != 3 { + t.Errorf("expected counter=3 after resuming recovery state, got %d", got) + } + if got := len(snap2.State.Messages); got != 6 { + t.Errorf("expected 6 messages after resuming recovery state, got %d", got) + } +} + +func TestAgent_ResumeFromSessionID_PendingTipRejected(t *testing.T) { + // A pending tip means a detached invocation is still running; resuming + // the session would race the background writer, so it is rejected + // until the row finalizes. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "sessionPendingFlow", WithSessionStore(store)) + + out1, err := af.RunText(ctx, "first") + if err != nil { + t.Fatalf("RunText: %v", err) + } + if _, err := store.SaveSnapshot(ctx, "", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{ + SessionID: out1.SessionID, + ParentID: out1.SnapshotID, + Event: SnapshotEventDetach, + Status: SnapshotStatusPending, + }, nil + }); err != nil { + t.Fatalf("SaveSnapshot pending row: %v", err) + } + + out, err := af.RunText(ctx, "second", WithSessionID[testState](out1.SessionID)) + if err == nil { + t.Fatalf("expected error for pending tip, got output: %+v", out) + } + ge := core.AsGenkitError(err) + if ge.Status != core.FAILED_PRECONDITION { + t.Fatalf("expected FAILED_PRECONDITION, got %q (err: %v)", ge.Status, err) + } + if !strings.Contains(ge.Message, "still pending") { + t.Errorf("expected error message to mention pending, got %q", ge.Message) + } +} + +func TestAgent_ClientManagedState_MintsSessionID(t *testing.T) { + // With no store configured and a state object that carries no session + // ID, the framework mints one and stamps it inside the output state, + // so the client's opaque round-trip picks up a stable conversation + // identity without tracking a separate field. + ctx := context.Background() + reg := newTestRegistry(t) + af := defineLastGoodTestAgent(reg, "clientSessionFlow") + + out, err := af.RunText(ctx, "hi", WithState(&SessionState[testState]{Custom: testState{Counter: 1}})) + if err != nil { + t.Fatalf("RunText: %v", err) + } + if out.SessionID == "" { + t.Fatal("expected minted SessionID for client-managed agent") + } + if out.State == nil || out.State.Custom.Counter != 2 { + t.Fatalf("expected state passthrough with counter=2, got %+v", out.State) + } + if out.State.SessionID != out.SessionID { + t.Errorf("state-carried session ID %q, want output's %q", out.State.SessionID, out.SessionID) + } +} + +func TestAgent_ClientManagedState_SessionIDRoundTrip(t *testing.T) { + // With no store, the conversation's identity rides inside the state + // object: an ID carried on SessionState is kept, the fn can read it, + // and the output echoes it both top-level and inside the state, so + // resending the state object preserves it across invocations. + ctx := context.Background() + reg := newTestRegistry(t) + + var fnSawSessionID string + af := DefineCustomAgent(reg, "clientPassthroughFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + fnSawSessionID = sess.SessionID() + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil, nil + }) + }, + ) + + out, err := af.RunText(ctx, "hi", + WithState(&SessionState[testState]{SessionID: "client-chosen-id", Custom: testState{Counter: 1}})) + if err != nil { + t.Fatalf("RunText: %v", err) + } + if out.SessionID != "client-chosen-id" { + t.Errorf("output SessionID = %q, want %q", out.SessionID, "client-chosen-id") + } + if fnSawSessionID != "client-chosen-id" { + t.Errorf("fn saw session ID %q, want %q", fnSawSessionID, "client-chosen-id") + } + if out.State == nil || out.State.Custom.Counter != 2 { + t.Fatalf("expected state passthrough with counter=2, got %+v", out.State) + } + if out.State.SessionID != "client-chosen-id" { + t.Errorf("state-carried session ID = %q, want %q", out.State.SessionID, "client-chosen-id") + } + + // Resending the output state opaquely keeps the identity. + out2, err := af.RunText(ctx, "again", WithState(out.State)) + if err != nil { + t.Fatalf("RunText: %v", err) + } + if out2.SessionID != "client-chosen-id" { + t.Errorf("round-tripped SessionID = %q, want %q", out2.SessionID, "client-chosen-id") + } + if out2.State == nil || out2.State.Custom.Counter != 3 { + t.Errorf("expected continued state with counter=3, got %+v", out2.State) + } +} + +func TestAgent_ClientManagedState_WithSessionIDRejected(t *testing.T) { + // WithSessionID is a store lookup; a client-managed agent has no store + // and carries the conversation's identity inside the state object, so + // the option is rejected up front. + ctx := context.Background() + reg := newTestRegistry(t) + af := defineLastGoodTestAgent(reg, "clientNoSessionLookupFlow") + + _, err := af.RunText(ctx, "hi", WithSessionID[testState]("some-id")) + if err == nil { + t.Fatal("expected error for WithSessionID on a client-managed agent") + } + if !strings.Contains(err.Error(), "SessionState.SessionID") { + t.Errorf("expected error to point at SessionState.SessionID, got %q", err.Error()) + } +} + +func TestAgent_ResumeFromSnapshotID_WithSessionID(t *testing.T) { + // A session ID sent alongside a snapshot ID asserts which conversation + // the snapshot belongs to: a match resumes normally, a mismatch is + // rejected before the invocation starts. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "snapshotWithSessionFlow", WithSessionStore(store)) + + out1, err := af.RunText(ctx, "first") + if err != nil { + t.Fatalf("RunText: %v", err) + } + + out2, err := af.RunText(ctx, "second", + WithSnapshotID[testState](out1.SnapshotID), + WithSessionID[testState](out1.SessionID)) + if err != nil { + t.Fatalf("RunText resume with matching session ID: %v", err) + } + if out2.SessionID != out1.SessionID { + t.Errorf("SessionID = %q, want %q", out2.SessionID, out1.SessionID) + } + + out3, err := af.RunText(ctx, "third", + WithSnapshotID[testState](out2.SnapshotID), + WithSessionID[testState]("not-the-right-session")) + if err == nil { + t.Fatalf("expected mismatch error, got output: %+v", out3) + } + ge := core.AsGenkitError(err) + if ge.Status != core.INVALID_ARGUMENT { + t.Errorf("expected status %q, got %q (err: %v)", core.INVALID_ARGUMENT, ge.Status, err) + } + if !strings.Contains(ge.Message, "does not belong to session") { + t.Errorf("expected mismatch message, got %q", ge.Message) + } +} + +func TestAgent_Detach_AssignsSessionID(t *testing.T) { + // A detach on a fresh conversation carries the runtime-assigned session + // ID on the pending row; the detached output reports it and the + // finalized row keeps it. + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + release := make(chan struct{}) + entered := make(chan struct{}) + af := DefineCustomAgent(reg, "detachSessionFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + select { + case entered <- struct{}{}: + case <-ctx.Done(): + } + <-release + sess.AddMessages(ai.NewModelTextMessage("finished")) + return nil, nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + if err := conn.SendText("go"); err != nil { + t.Fatalf("SendText: %v", err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + <-entered + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.SessionID == "" { + t.Fatal("expected session ID on detached output") + } + pending, err := store.GetSnapshot(context.Background(), out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if pending.SessionID != out.SessionID { + t.Errorf("pending row SessionID = %q, want %q", pending.SessionID, out.SessionID) + } + + close(release) + final := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusSucceeded + }) + if final.SessionID != out.SessionID { + t.Errorf("finalized row SessionID = %q, want %q", final.SessionID, out.SessionID) + } +} + +// blockingSaveStore wraps testInMemStore so a test can hold the first +// SaveSnapshot call open: the agent's turn-end write blocks until release +// is closed, modeling a slow store at the exact moment a detach lands. +type blockingSaveStore[State any] struct { + *testInMemStore[State] + entered chan struct{} + release chan struct{} + once sync.Once +} + +func (s *blockingSaveStore[State]) SaveSnapshot( + ctx context.Context, + id string, + fn func(existing *SessionSnapshot[State]) (*SessionSnapshot[State], error), +) (*SessionSnapshot[State], error) { + blocked := false + s.once.Do(func() { blocked = true }) + if blocked { + close(s.entered) + <-s.release + } + return s.testInMemStore.SaveSnapshot(ctx, id, fn) +} + +func TestAgent_Detach_WaitsForInFlightTurnSnapshot(t *testing.T) { + // A detach that lands while a turn-end snapshot write is still in + // flight must wait for it: the pending row chains off the just-written + // row instead of becoming its sibling (which would fork the session and + // permanently break resume-by-session-ID), and the conversation stays + // in one session. + reg := newTestRegistry(t) + store := &blockingSaveStore[testState]{ + testInMemStore: newTestInMemStore[testState](), + entered: make(chan struct{}), + release: make(chan struct{}), + } + af := defineLastGoodTestAgent(reg, "detachMidWrite", WithSessionStore[testState](store)) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + if err := conn.SendText("one"); err != nil { + t.Fatalf("SendText: %v", err) + } + // Wait until the turn-end snapshot write is in flight, then detach + // while it is still blocked inside the store. + <-store.entered + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + close(store.release) + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonDetached { + t.Fatalf("FinishReason = %q, want %q", out.FinishReason, AgentFinishReasonDetached) + } + + final := waitForSnapshot[testState](t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusSucceeded + }) + + // Find the turn-end row (the only row besides the detach row). + var turnRowID, turnRowSession string + others := 0 + store.testInMemStore.mu.RLock() + for _, r := range store.testInMemStore.snapshots { + if r.SnapshotID == out.SnapshotID { + continue + } + others++ + turnRowID, turnRowSession = r.SnapshotID, r.SessionID + } + store.testInMemStore.mu.RUnlock() + if others != 1 { + t.Fatalf("expected exactly one turn-end row besides the detach row, got %d", others) + } + if final.ParentID != turnRowID { + t.Errorf("detach row ParentID = %q, want turn-end row %q (a sibling row forks the session)", final.ParentID, turnRowID) + } + if final.SessionID != out.SessionID || turnRowSession != out.SessionID { + t.Errorf("conversation split across sessions: turn=%q detach=%q output=%q", turnRowSession, final.SessionID, out.SessionID) + } + // The session stays linear and resolves to the finalized detach row. + tip, err := store.GetLatestSnapshot(context.Background(), out.SessionID) + if err != nil { + t.Fatalf("GetLatestSnapshot: %v", err) + } + if tip == nil || tip.SnapshotID != out.SnapshotID { + t.Errorf("session tip = %+v, want %q", tip, out.SnapshotID) + } +} + +func TestAgent_FailedTurn_OutputCarriesSessionID(t *testing.T) { + // A failed invocation still reports the session it belongs to, next to + // the last-good snapshot ID. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "failedSessionFlow", WithSessionStore(store)) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + if err := conn.SendText("turn one"); err != nil { + t.Fatalf("SendText: %v", err) + } + nextTurnEnd(t, conn) + if err := conn.SendText("boom"); err != nil && !errors.Is(err, core.ErrActionCompleted) { + t.Fatalf("SendText: %v", err) + } + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Fatalf("expected failed output, got %q", out.FinishReason) + } + if out.SessionID == "" { + t.Fatal("expected session ID on failed output") + } + snap, err := store.GetSnapshot(ctx, out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap.SessionID != out.SessionID { + t.Errorf("last-good snapshot SessionID = %q, want %q", snap.SessionID, out.SessionID) + } +} + +func TestAgent_ResumeFromLegacySnapshot_MintsFreshSessionID(t *testing.T) { + // Snapshots written before session IDs existed have none; resuming one + // starts a fresh session, assigned at invocation start and stamped on + // every new row. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "legacySessionFlow", WithSessionStore(store)) + + legacy := &SessionSnapshot[testState]{ + SnapshotID: "legacy-1", + Event: SnapshotEventInvocationEnd, + Status: SnapshotStatusSucceeded, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + State: &SessionState[testState]{Custom: testState{Counter: 5}}, + } + store.mu.Lock() + store.snapshots[legacy.SnapshotID] = legacy + store.mu.Unlock() + + out, err := af.RunText(ctx, "continue", WithSnapshotID[testState]("legacy-1")) + if err != nil { + t.Fatalf("RunText: %v", err) + } + if out.FinishReason == AgentFinishReasonFailed { + t.Fatalf("resume failed: %+v", out.Error) + } + if out.SessionID == "" { + t.Fatal("expected freshly minted session ID after legacy resume") + } + snap, err := store.GetSnapshot(ctx, out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap.SessionID != out.SessionID { + t.Errorf("snapshot SessionID = %q, want %q", snap.SessionID, out.SessionID) + } + if got := snap.State.Custom.Counter; got != 6 { + t.Errorf("expected counter=6 (legacy state + 1), got %d", got) + } +} + +func TestAgent_WithSessionID_OptionValidation(t *testing.T) { + // The invocation-option layer rejects empty session IDs, duplicate + // options, and conflicting state sources before the action is ever + // invoked. WithSessionID composes with either state source, so those + // combinations pass the option layer and the init-level checks take + // over from there. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "sessionOptFlow", WithSessionStore(store)) + + if _, err := af.StreamBidi(ctx, WithState(&SessionState[testState]{}), WithSnapshotID[testState]("x")); err == nil || + !strings.Contains(err.Error(), "mutually exclusive") { + t.Errorf("WithState+WithSnapshotID: expected mutual-exclusion error, got %v", err) + } + if _, err := af.StreamBidi(ctx, WithSessionID[testState]("s"), WithSessionID[testState]("s2")); err == nil || + !strings.Contains(err.Error(), "more than once") { + t.Errorf("WithSessionID twice: expected duplicate-option error, got %v", err) + } + // An empty session ID is an explicit error, not a silent no-op: a + // pipelined AgentOutput.SessionID from a storeless invocation must not + // quietly start a fresh conversation. + if _, err := af.StreamBidi(ctx, WithSessionID[testState]("")); err == nil || + !strings.Contains(err.Error(), "session ID is empty") { + t.Errorf("WithSessionID(\"\"): expected empty-ID error, got %v", err) + } + // WithSessionID composes with WithSnapshotID: the option layer accepts + // the pair and the init-level checks (here: unknown snapshot) decide. + conn, err := af.StreamBidi(ctx, WithSessionID[testState]("s"), WithSnapshotID[testState]("x")) + if err != nil { + t.Fatalf("WithSessionID+WithSnapshotID: expected option layer to accept, got %v", err) + } + if _, err := conn.Output(); err == nil { + t.Error("expected init-level error for unknown snapshot, got nil") + } +} + +func TestAgent_GetSnapshotAction_ReturnsSessionID(t *testing.T) { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "sessionActionFlow", WithSessionStore(store)) + + ctx := context.Background() + out, err := af.RunText(ctx, "hi") + if err != nil { + t.Fatalf("RunText: %v", err) + } + if out.SessionID == "" { + t.Fatal("expected session ID on output") + } + + action := core.ResolveActionFor[*GetSnapshotRequest, *SessionSnapshot[testState], struct{}, struct{}]( + reg, api.ActionTypeAgentSnapshot, "sessionActionFlow") + if action == nil { + t.Fatal("getSnapshot action not registered") + } + resp, err := action.Run(ctx, &GetSnapshotRequest{SnapshotID: out.SnapshotID}, nil) + if err != nil { + t.Fatalf("getSnapshot action: %v", err) + } + if resp.SessionID != out.SessionID { + t.Errorf("getSnapshot SessionID = %q, want %q", resp.SessionID, out.SessionID) + } +} diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index 06d931ea78..24fa3f976d 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -84,21 +84,43 @@ const ( ) // AgentInit is the input for starting an agent invocation. -// Exactly one of SnapshotID or State may be set, and the choice must match -// the agent's state management: -// - Server-managed state (a session store is configured): callers must -// use SnapshotID; sending State is rejected. -// - Client-managed state (no session store): callers must use State; -// sending SnapshotID is rejected. -// Sending both fields is always rejected. Sending neither starts a fresh -// invocation with empty state. +// SessionID, SnapshotID, and State are competing conversation sources, +// and which are valid depends on the agent's state management: +// - Server-managed state (a session store is configured): callers may +// use SessionID (resume the session's latest snapshot), SnapshotID +// (resume a specific snapshot), or both (the session ID is asserted +// against the snapshot); sending State is rejected. +// - Client-managed state (no session store): callers may use State, +// which carries the conversation's identity inside it +// ([SessionState.SessionID]); sending SessionID or SnapshotID is +// rejected. +// +// Sending no fields starts a fresh invocation with empty state. type AgentInit[State any] struct { + // SessionID identifies the session (conversation) to resume. Only valid + // when the agent is server-managed (a session store is configured); + // mutually exclusive with State (a client-managed conversation carries + // its identity inside [SessionState.SessionID]). Alone, it resumes the + // session from its latest snapshot: the most recently updated one that + // is not a failed/aborted dead end. A pending latest snapshot (a + // detached invocation still running) rejects the resume rather than + // racing the background work; if the session's history was forked by + // resuming an earlier snapshot again, the most recently updated branch + // wins, and SnapshotID can pick a branch explicitly. Combined with + // SnapshotID, it asserts which session the snapshot belongs to, and a + // mismatch is rejected. + SessionID string `json:"sessionId,omitempty"` // SnapshotID loads state from a persisted snapshot. Only valid when the - // agent is server-managed (a session store is configured). Mutually - // exclusive with State. + // agent is server-managed (a session store is configured). May be + // combined with SessionID to validate that the snapshot belongs to that + // session. Mutually exclusive with State. SnapshotID string `json:"snapshotId,omitempty"` // State provides direct state for the invocation. Only valid when the - // agent is client-managed (no session store). Mutually exclusive with + // agent is client-managed (no session store). The conversation's + // identity rides inside it ([SessionState.SessionID]): the framework + // mints one on the conversation's first invocation and echoes it on the + // output state, so resending the state object keeps the identity without + // tracking a separate field. Mutually exclusive with SessionID and // SnapshotID. State *SessionState[State] `json:"state,omitempty"` } @@ -110,8 +132,8 @@ type AgentInput struct { // state), returns [AgentOutput] with that snapshot ID, and continues // processing any already-buffered inputs in a background context. The // pending snapshot is finalized with the cumulative final state once all - // queued inputs are processed (or the snapshot is cancelled via - // cancelSnapshot). + // queued inputs are processed (or the invocation is aborted via the + // abortSnapshot companion action). Detach bool `json:"detach,omitempty"` // Message is the user's input for this turn. Message *ai.Message `json:"message,omitempty"` @@ -163,11 +185,23 @@ type AgentOutput[State any] struct { FinishReason AgentFinishReason `json:"finishReason,omitempty"` // Message is the last model response message from the conversation. Message *ai.Message `json:"message,omitempty"` - // SnapshotID is the ID of the snapshot created at the end of this invocation. - // Empty if no snapshot was created (callback returned false or no store configured). - // When FinishReason is [AgentFinishReasonFailed] (and a store is configured), - // it is the most recent snapshot capturing the last-good state: everything - // through the last successful turn (see [SnapshotEventRecovery]). + // SessionID is the ID of the session this invocation belongs to, + // assigned by the framework when the invocation starts. With + // server-managed state, a fresh invocation mints a new ID, resumed + // invocations inherit the chain's, and resuming a snapshot from before + // session IDs existed mints a fresh one. With client-managed state it + // echoes the ID carried inside the state object + // ([SessionState.SessionID]), minting one on the conversation's first + // invocation; only a session with persisted snapshots can be resumed by + // this ID. + SessionID string `json:"sessionId,omitempty"` + // SnapshotID is the ID of the newest snapshot capturing this invocation: + // the invocation-end snapshot, or the latest earlier snapshot when that + // write was skipped. Empty when no store is configured or the invocation + // persisted nothing. When FinishReason is [AgentFinishReasonDetached] it + // is the pending detach snapshot; when [AgentFinishReasonFailed], the most + // recent snapshot capturing the last-good state: everything through the + // last successful turn (see [SnapshotEventRecovery]). SnapshotID string `json:"snapshotId,omitempty"` // State contains the final conversation state. // Only populated when state is client-managed (no store configured). @@ -235,45 +269,16 @@ type Artifact struct { } // GetSnapshotRequest is the input for an agent's getSnapshot companion -// action. The action is registered at `{agentName}/getSnapshot` when the -// agent is defined and is intended for Dev UI and client-side reconnect -// flows. +// action, registered under the agent's name (action type agent-snapshot) +// when the agent has a session store configured. The action is intended +// for Dev UI and client-side reconnect flows. It returns the stored +// [SessionSnapshot], with [WithStateTransform] applied to its state if +// configured. type GetSnapshotRequest struct { // SnapshotID identifies the snapshot to fetch. SnapshotID string `json:"snapshotId"` } -// GetSnapshotResponse is the output of the getSnapshot companion action. It -// is a client-facing view of the stored snapshot: identifying metadata plus -// the session state, with [WithStateTransform] applied if configured. -// -// Unlike the raw [SessionSnapshot], this response intentionally omits -// internal fields (parent ID, event) and does not leak the snapshot -// envelope beyond what callers need to repopulate a UI. -type GetSnapshotResponse[State any] struct { - // CreatedAt is when the snapshot record was first written. - CreatedAt time.Time `json:"createdAt,omitempty"` - // Error is the structured failure information; populated when Status - // is [SnapshotStatusFailed]. - Error *core.GenkitError `json:"error,omitempty"` - // FinishReason is the semantic reason the captured turn or invocation ended - // (e.g. [AgentFinishReasonStop], [AgentFinishReasonInterrupted], - // [AgentFinishReasonFailed], [AgentFinishReasonAborted]). It lets a remote or - // background poller report how a detached or resumed invocation ended without - // re-deriving it. Empty on a pending snapshot. - FinishReason AgentFinishReason `json:"finishReason,omitempty"` - // SnapshotID echoes the requested snapshot ID. - SnapshotID string `json:"snapshotId"` - // State is the session state captured by the snapshot, after any - // configured transform. Empty when Status is pending or error. - State *SessionState[State] `json:"state,omitempty"` - // Status is the lifecycle state of the snapshot. See [SnapshotStatus]. - Status SnapshotStatus `json:"status,omitempty"` - // UpdatedAt is when the snapshot record was last written. Equals - // CreatedAt for snapshots that have not been rewritten. - UpdatedAt time.Time `json:"updatedAt,omitempty"` -} - // SessionSnapshot is a persisted point-in-time capture of session state. It // is the canonical record written to and read from a [SessionStore]. type SessionSnapshot[State any] struct { @@ -291,8 +296,17 @@ type SessionSnapshot[State any] struct { // background task can report how it ended without re-deriving it from the // messages. FinishReason AgentFinishReason `json:"finishReason,omitempty"` - // ParentID is the ID of the previous snapshot in this timeline. + // ParentID is the ID of the previous snapshot in this timeline. It is + // informational lineage (for debugging and UI history trees) and plays + // no part in resolving a session's latest snapshot. ParentID string `json:"parentId,omitempty"` + // SessionID is the ID of the session this snapshot belongs to. Assigned + // by the agent framework when the conversation's first invocation starts + // and stamped on every later snapshot in the chain, including across + // resumed invocations. Stores preserve it across rewrites; rows written + // without one (data from before session IDs existed) belong to no + // session. + SessionID string `json:"sessionId,omitempty"` // SnapshotID is the unique identifier for this snapshot (UUID). SnapshotID string `json:"snapshotId"` // State is the conversation state captured at this point. Nil on a @@ -319,24 +333,32 @@ type SessionState[State any] struct { // Messages is the conversation history (user/model exchanges). // Does NOT include prompt-rendered messages — those are rendered fresh each turn. Messages []*ai.Message `json:"messages,omitempty"` + // SessionID is the ID of the session (conversation) this state belongs to. + // Framework-owned: assigned when the conversation's first invocation + // starts and re-stamped on outbound state, so client-managed callers can + // round-trip the state object opaquely without tracking a separate + // identifier. For server-managed agents the snapshot row's + // [SessionSnapshot.SessionID] is canonical and this field mirrors it. + SessionID string `json:"sessionId,omitempty"` } // SnapshotEvent identifies what triggered a snapshot. type SnapshotEvent string const ( - // TurnEnd indicates the snapshot was triggered at the end of a turn. + // SnapshotEventTurnEnd indicates the snapshot was triggered at the end of a turn. SnapshotEventTurnEnd SnapshotEvent = "turnEnd" - // InvocationEnd indicates the snapshot was triggered at the end of the invocation. + // SnapshotEventInvocationEnd indicates the snapshot was triggered at the end + // of the invocation. SnapshotEventInvocationEnd SnapshotEvent = "invocationEnd" - // Detach indicates the snapshot was created when the client detached the - // invocation and the flow continues in the background. The snapshot is - // initially written with [SnapshotStatusPending] and rewritten with a - // terminal status once the background work finishes. + // SnapshotEventDetach indicates the snapshot was created when the client + // detached the invocation and the work continues in the background. The + // snapshot is initially written with [SnapshotStatusPending] and rewritten + // with a terminal status once the background work finishes. SnapshotEventDetach SnapshotEvent = "detach" - // Recovery indicates the snapshot was written retroactively by the failure - // path to preserve the last-good state (everything through the last - // successful turn) when a selective snapshot callback had skipped + // SnapshotEventRecovery indicates the snapshot was written retroactively + // by the failure path to preserve the last-good state (everything through + // the last successful turn) when a selective snapshot callback had skipped // persisting it. It is a normal [SnapshotStatusSucceeded] row carrying the // last good turn's finish reason, resumable like any other; the snapshot // callback is bypassed and never sees this event. @@ -360,12 +382,12 @@ type SnapshotStatus string const ( // SnapshotStatusPending indicates a detached invocation is still // processing the queued inputs. The snapshot will be rewritten with a - // terminal status once the flow exits. + // terminal status once the background work finishes. SnapshotStatusPending SnapshotStatus = "pending" // SnapshotStatusSucceeded indicates the snapshot captures a settled state. SnapshotStatusSucceeded SnapshotStatus = "succeeded" - // SnapshotStatusAborted indicates the snapshot's invocation was - // aborted via the abortSnapshot companion action while detached. + // SnapshotStatusAborted indicates the snapshot's invocation was aborted + // while detached (e.g. via the abortSnapshot companion action). SnapshotStatusAborted SnapshotStatus = "aborted" // SnapshotStatusFailed indicates the invocation terminated with an error. // The snapshot's Error field describes the failure and resume is @@ -388,7 +410,8 @@ type TurnEnd struct { // sends fail with [core.ErrActionCompleted]. FinishReason AgentFinishReason `json:"finishReason,omitempty"` // SnapshotID is the ID of the snapshot persisted at the end of this turn. - // Empty if no snapshot was created (callback returned false or no store - // configured, or snapshots were suspended after detach). + // Empty if no snapshot was written (no store configured, the callback + // declined, nothing changed since the last snapshot, or snapshots were + // suspended after detach). SnapshotID string `json:"snapshotId,omitempty"` } diff --git a/go/ai/exp/localstore/file.go b/go/ai/exp/localstore/file.go index 64e5587574..19eb303183 100644 --- a/go/ai/exp/localstore/file.go +++ b/go/ai/exp/localstore/file.go @@ -119,6 +119,9 @@ func (s *FileSessionStore[State]) SaveSnapshot( now := time.Now() if existing != nil { next.CreatedAt = existing.CreatedAt + if existing.SessionID != "" { + next.SessionID = existing.SessionID // a row's session never changes + } } else { next.CreatedAt = now } @@ -136,31 +139,69 @@ func (s *FileSessionStore[State]) SaveSnapshot( return next, nil } -// LatestSnapshot returns the snapshot whose backing file has the most -// recent on-disk modification time, or nil if the directory has no -// snapshots yet. -// -// Selecting by file mtime (rather than parsing every file to compare -// [SessionSnapshot.UpdatedAt]) makes the operation O(N) stats plus a -// single read of the winner in the common case, rather than O(N) -// reads + JSON parses. For snapshots written by this package mtime -// and UpdatedAt advance together (each save creates a fresh temp -// file and renames it into place), so the result is identical to a -// sort by UpdatedAt. If a snapshot file is touched externally, mtime -// wins. -// -// Files that fail to stat (e.g. removed between the directory read -// and the stat) or fail to parse are skipped silently and the scan -// falls back to the next-newest candidate, so a single corrupted -// file does not poison the result. The directory listing is not -// atomic with respect to concurrent writes — a snapshot that appears -// or disappears mid-scan may or may not be observed. +// snapshotHeader is the subset of snapshot fields needed to decide +// whether a row resolves a session resume. Decoding only these avoids +// materializing every row's full conversation state during the scan. +type snapshotHeader struct { + SessionID string `json:"sessionId"` + Status exp.SnapshotStatus `json:"status"` +} + +// GetLatestSnapshot returns the session's most recently updated snapshot +// that is not a failed/aborted dead end, per the +// [exp.SnapshotReader.GetLatestSnapshot] contract. // -// This is not part of the [exp.SessionStore] interface; it is a -// FileSessionStore-specific convenience for UIs and CLIs that need to -// surface "where did I leave off" without indexing the directory -// themselves. -func (s *FileSessionStore[State]) LatestSnapshot(ctx context.Context) (*exp.SessionSnapshot[State], error) { +// Recency is judged by file mtime, which for snapshots written by this +// package advances with [exp.SessionSnapshot.UpdatedAt] (each save +// creates a fresh temp file and renames it into place); if a file is +// touched externally, mtime wins. The scan walks files newest first and +// stops at the first row that matches, so resolving the most recently +// active session costs one read in the common case. Only header fields +// are decoded per candidate (the winner is the only full parse), the +// store lock is held per file rather than across the whole scan, and a +// file that vanishes mid-scan or fails to parse is skipped so one +// corrupted row cannot poison every session in the directory. +func (s *FileSessionStore[State]) GetLatestSnapshot(_ context.Context, sessionID string) (*exp.SessionSnapshot[State], error) { + if sessionID == "" { + return nil, errors.New("FileSessionStore: session ID is empty") + } + names, err := s.snapshotFilesNewestFirst() + if err != nil { + return nil, err + } + for _, name := range names { + s.mu.Lock() + data, err := os.ReadFile(filepath.Join(s.dir, name)) + s.mu.Unlock() + if err != nil { + continue + } + var h snapshotHeader + if err := json.Unmarshal(data, &h); err != nil { + continue + } + if h.SessionID != sessionID || + h.Status == exp.SnapshotStatusFailed || h.Status == exp.SnapshotStatusAborted { + continue + } + var snap exp.SessionSnapshot[State] + if err := json.Unmarshal(data, &snap); err != nil { + continue + } + return &snap, nil + } + return nil, nil +} + +// snapshotFilesNewestFirst returns the names of the directory's snapshot +// files (non-directory *.json entries; writeLocked's ".*.tmp" temp +// files never match) sorted by modification time, newest first, with +// name as a deterministic tie-break. Entries that vanish between the +// directory read and the stat are skipped. Returns nil if the directory +// does not exist. The listing is not atomic with respect to concurrent +// writes; a snapshot that appears or disappears mid-scan may or may not +// be observed. +func (s *FileSessionStore[State]) snapshotFilesNewestFirst() ([]string, error) { entries, err := os.ReadDir(s.dir) if err != nil { if errors.Is(err, os.ErrNotExist) { @@ -168,7 +209,6 @@ func (s *FileSessionStore[State]) LatestSnapshot(ctx context.Context) (*exp.Sess } return nil, fmt.Errorf("FileSessionStore: list dir: %w", err) } - type candidate struct { name string modTime time.Time @@ -180,18 +220,41 @@ func (s *FileSessionStore[State]) LatestSnapshot(ctx context.Context) (*exp.Sess } info, err := e.Info() if err != nil { - // Entry vanished between ReadDir and Info; ignore and keep - // scanning. Any caller-visible inconsistency is bounded to - // "a snapshot disappeared mid-scan" which the doc allows. continue } cands = append(cands, candidate{e.Name(), info.ModTime()}) } slices.SortFunc(cands, func(a, b candidate) int { - return b.modTime.Compare(a.modTime) // newest first + if c := b.modTime.Compare(a.modTime); c != 0 { // newest first + return c + } + return strings.Compare(b.name, a.name) }) - for _, c := range cands { - snap, err := s.GetSnapshot(ctx, strings.TrimSuffix(c.name, ".json")) + names := make([]string, len(cands)) + for i, c := range cands { + names[i] = c.name + } + return names, nil +} + +// LatestSnapshot returns the snapshot whose backing file has the most +// recent on-disk modification time, or nil if the directory has no +// snapshots yet. It is not part of the [exp.SessionStore] interface; it +// is a convenience for UIs and CLIs that need to surface "where did I +// leave off" without indexing the directory themselves. +// +// Selecting by mtime avoids parsing every file: for snapshots written by +// this package, mtime and [exp.SessionSnapshot.UpdatedAt] advance +// together, so the result matches a sort by UpdatedAt; if a file is +// touched externally, mtime wins. Files that fail to stat or parse are +// skipped and the scan falls back to the next-newest candidate. +func (s *FileSessionStore[State]) LatestSnapshot(ctx context.Context) (*exp.SessionSnapshot[State], error) { + names, err := s.snapshotFilesNewestFirst() + if err != nil { + return nil, err + } + for _, name := range names { + snap, err := s.GetSnapshot(ctx, strings.TrimSuffix(name, ".json")) if err != nil || snap == nil { continue } diff --git a/go/ai/exp/localstore/file_test.go b/go/ai/exp/localstore/file_test.go index a5b5cd02cb..b864cd67ab 100644 --- a/go/ai/exp/localstore/file_test.go +++ b/go/ai/exp/localstore/file_test.go @@ -513,3 +513,46 @@ func TestFileSessionStore_FinishReasonPersistsAcrossReopen(t *testing.T) { t.Errorf("FinishReason = %q, want %q", got.FinishReason, exp.AgentFinishReasonInterrupted) } } + +func TestFileSessionStore_SessionIDs(t *testing.T) { + runSessionIDStoreTests(t, func(t *testing.T) exp.SessionStore[testState] { + store, err := NewFileSessionStore[testState](t.TempDir()) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + return store + }) +} + +func TestFileSessionStore_GetLatestSnapshot_SkipsUnparseableFiles(t *testing.T) { + // A stray unparseable .json file (crash artifact, partial copy, + // hand-edited row) must not poison session resolution for the healthy + // rows in the directory; it is skipped like the dead end it is. + dir := t.TempDir() + store, err := NewFileSessionStore[testState](dir) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + ctx := context.Background() + if _, err := store.SaveSnapshot(ctx, "a", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{ + SessionID: "sess-1", + Event: exp.SnapshotEventTurnEnd, + Status: exp.SnapshotStatusSucceeded, + }, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + if err := os.WriteFile(filepath.Join(dir, "junk.json"), []byte("{not json"), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + tip, err := store.GetLatestSnapshot(ctx, "sess-1") + if err != nil { + t.Fatalf("GetLatestSnapshot: %v", err) + } + if tip == nil || tip.SnapshotID != "a" { + t.Errorf("expected the healthy row as tip, got %+v", tip) + } +} diff --git a/go/ai/exp/localstore/inmemory.go b/go/ai/exp/localstore/inmemory.go index 7718abd607..100eda5a70 100644 --- a/go/ai/exp/localstore/inmemory.go +++ b/go/ai/exp/localstore/inmemory.go @@ -23,6 +23,7 @@ package localstore import ( "context" "encoding/json" + "errors" "fmt" "slices" "sync" @@ -65,8 +66,37 @@ func (s *InMemorySessionStore[State]) GetSnapshot(_ context.Context, snapshotID return copySnapshot(snap) } +// GetLatestSnapshot returns the session's most recently updated snapshot +// that is not a failed/aborted dead end, per the +// [exp.SnapshotReader.GetLatestSnapshot] contract. Ties on UpdatedAt are +// broken by SnapshotID so resolution is deterministic. The scan runs +// under the read lock, so the stored rows (which other calls mutate in +// place) never escape it; the winner is returned as a deep copy. +func (s *InMemorySessionStore[State]) GetLatestSnapshot(_ context.Context, sessionID string) (*exp.SessionSnapshot[State], error) { + if sessionID == "" { + return nil, errors.New("InMemorySessionStore: session ID is empty") + } + s.mu.RLock() + defer s.mu.RUnlock() + var latest *exp.SessionSnapshot[State] + for _, snap := range s.snapshots { + if snap.SessionID != sessionID || + snap.Status == exp.SnapshotStatusFailed || snap.Status == exp.SnapshotStatusAborted { + continue + } + if latest == nil || snap.UpdatedAt.After(latest.UpdatedAt) || + (snap.UpdatedAt.Equal(latest.UpdatedAt) && snap.SnapshotID > latest.SnapshotID) { + latest = snap + } + } + if latest == nil { + return nil, nil + } + return copySnapshot(latest) +} + // LatestSnapshot returns the snapshot with the most recent -// [SessionSnapshot.UpdatedAt] in the store, or nil if there are none. +// [exp.SessionSnapshot.UpdatedAt] in the store, or nil if there are none. // // This is not part of the [exp.SessionStore] interface; it is an // InMemorySessionStore-specific convenience that mirrors @@ -142,6 +172,9 @@ func (s *InMemorySessionStore[State]) SaveSnapshot( now := time.Now() if existing != nil { next.CreatedAt = existing.CreatedAt + if existing.SessionID != "" { + next.SessionID = existing.SessionID // a row's session never changes + } } else { next.CreatedAt = now } diff --git a/go/ai/exp/localstore/inmemory_test.go b/go/ai/exp/localstore/inmemory_test.go index 312eaa2d7f..5bdc4604b1 100644 --- a/go/ai/exp/localstore/inmemory_test.go +++ b/go/ai/exp/localstore/inmemory_test.go @@ -210,3 +210,39 @@ func TestInMemorySessionStore(t *testing.T) { } }) } + +func TestInMemorySessionStore_SessionIDs(t *testing.T) { + runSessionIDStoreTests(t, func(t *testing.T) exp.SessionStore[testState] { + return NewInMemorySessionStore[testState]() + }) + + t.Run("GetLatestSnapshotReturnsCopy", func(t *testing.T) { + store := NewInMemorySessionStore[testState]() + if _, err := store.SaveSnapshot(context.Background(), "a", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{ + SessionID: "sess-1", + Status: exp.SnapshotStatusSucceeded, + State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, + }, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + first, err := store.GetLatestSnapshot(context.Background(), "sess-1") + if err != nil { + t.Fatalf("GetLatestSnapshot: %v", err) + } + if first == nil || first.State == nil { + t.Fatalf("expected full tip row, got %+v", first) + } + // Mutating the returned row must not leak into the store. + first.State.Custom.Counter = 999 + second, err := store.GetLatestSnapshot(context.Background(), "sess-1") + if err != nil { + t.Fatalf("GetLatestSnapshot: %v", err) + } + if second == nil || second.State == nil || second.State.Custom.Counter != 1 { + t.Errorf("expected isolated copy with counter=1, got %+v", second) + } + }) +} diff --git a/go/ai/exp/localstore/store_test.go b/go/ai/exp/localstore/store_test.go index c009e3be8f..bc8981cddf 100644 --- a/go/ai/exp/localstore/store_test.go +++ b/go/ai/exp/localstore/store_test.go @@ -16,8 +16,216 @@ package localstore +import ( + "context" + "testing" + "time" + + "github.com/firebase/genkit/go/ai/exp" +) + // testState is the custom-state type used by store unit tests. type testState struct { Counter int `json:"counter"` Topics []string `json:"topics,omitempty"` } + +// runSessionIDStoreTests exercises the store-owned SessionID settle rules +// and the GetLatestSnapshot contract against any [exp.SessionStore] +// implementation. Both the in-memory and file store test files invoke it +// so the two stores stay behaviorally aligned. +func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.SessionStore[testState]) { + ctx := context.Background() + + saveRow := func(t *testing.T, store exp.SessionStore[testState], id, sessionID, parentID string, status exp.SnapshotStatus) *exp.SessionSnapshot[testState] { + t.Helper() + saved, err := store.SaveSnapshot(ctx, id, + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{ + SessionID: sessionID, + ParentID: parentID, + Event: exp.SnapshotEventTurnEnd, + Status: status, + State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, + }, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot(%q): %v", id, err) + } + return saved + } + + // tick spaces consecutive writes far enough apart that UpdatedAt (and + // the file store's mtimes) order them unambiguously even on coarse + // clocks. + tick := func() { time.Sleep(2 * time.Millisecond) } + + t.Run("SessionIDKeptWhenProvided", func(t *testing.T) { + store := newStore(t) + saved := saveRow(t, store, "a", "sess-keep", "", exp.SnapshotStatusSucceeded) + if saved.SessionID != "sess-keep" { + t.Errorf("SessionID = %q, want provided %q", saved.SessionID, "sess-keep") + } + stored, err := store.GetSnapshot(ctx, "a") + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if stored.SessionID != "sess-keep" { + t.Errorf("stored row SessionID = %q, want %q", stored.SessionID, "sess-keep") + } + }) + + t.Run("SessionIDPreservedOnUpdate", func(t *testing.T) { + store := newStore(t) + saveRow(t, store, "a", "sess-orig", "", exp.SnapshotStatusPending) + // Finalize-style rewrite that omits (or even contradicts) the + // session ID: the existing row's ID wins, a row's session never + // changes. + for _, rewrite := range []string{"", "sess-other"} { + updated, err := store.SaveSnapshot(ctx, "a", + func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + if existing == nil { + t.Fatal("expected existing row on update") + } + return &exp.SessionSnapshot[testState]{ + SessionID: rewrite, + ParentID: existing.ParentID, + Event: existing.Event, + Status: exp.SnapshotStatusSucceeded, + State: &exp.SessionState[testState]{Custom: testState{Counter: 2}}, + }, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot update: %v", err) + } + if updated.SessionID != "sess-orig" { + t.Errorf("updated SessionID = %q, want preserved %q", updated.SessionID, "sess-orig") + } + } + }) + + t.Run("NoSessionWithoutID", func(t *testing.T) { + // Stores never mint or infer session IDs (the agent runtime assigns + // them at invocation start and stamps every row it writes); a row + // written without one is session-less even when its parent has one. + store := newStore(t) + saveRow(t, store, "parent", "sess-1", "", exp.SnapshotStatusSucceeded) + child := saveRow(t, store, "child", "", "parent", exp.SnapshotStatusSucceeded) + if child.SessionID != "" { + t.Errorf("expected session-less row, got SessionID %q", child.SessionID) + } + }) + + t.Run("GetLatestSnapshotPicksMostRecent", func(t *testing.T) { + // IDs deliberately sort against write order so a recency bug (or an + // accidental reliance on the tie-break) cannot pass by luck. + store := newStore(t) + saveRow(t, store, "z", "sess-1", "", exp.SnapshotStatusSucceeded) + tick() + saveRow(t, store, "m", "sess-1", "z", exp.SnapshotStatusSucceeded) + tick() + saveRow(t, store, "a", "sess-1", "m", exp.SnapshotStatusSucceeded) + tick() + saveRow(t, store, "x", "sess-other", "", exp.SnapshotStatusSucceeded) + + latest, err := store.GetLatestSnapshot(ctx, "sess-1") + if err != nil { + t.Fatalf("GetLatestSnapshot: %v", err) + } + if latest == nil || latest.SnapshotID != "a" { + t.Fatalf("latest = %+v, want most recently written snapshot a", latest) + } + // The contract returns the full row, not a header: the runtime + // loads its state to resume. + if latest.State == nil || latest.State.Custom.Counter != 1 { + t.Errorf("latest state = %+v, want full row with counter=1", latest.State) + } + }) + + t.Run("GetLatestSnapshotUpdateWins", func(t *testing.T) { + // Recency is judged by UpdatedAt, not creation order: rewriting a + // row (e.g. a detach finalize landing after other branches were + // written) moves it to the front. + store := newStore(t) + saveRow(t, store, "root", "sess-1", "", exp.SnapshotStatusSucceeded) + tick() + saveRow(t, store, "b1", "sess-1", "root", exp.SnapshotStatusPending) + tick() + saveRow(t, store, "b2", "sess-1", "root", exp.SnapshotStatusSucceeded) + tick() + if _, err := store.SaveSnapshot(ctx, "b1", + func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + rewritten := *existing + rewritten.Status = exp.SnapshotStatusSucceeded + rewritten.State = &exp.SessionState[testState]{Custom: testState{Counter: 2}} + return &rewritten, nil + }); err != nil { + t.Fatalf("SaveSnapshot finalize: %v", err) + } + + latest, err := store.GetLatestSnapshot(ctx, "sess-1") + if err != nil { + t.Fatalf("GetLatestSnapshot: %v", err) + } + if latest == nil || latest.SnapshotID != "b1" { + t.Errorf("latest = %+v, want freshly finalized snapshot b1", latest) + } + }) + + t.Run("GetLatestSnapshotSkipsDeadEnds", func(t *testing.T) { + // Failed and aborted rows are dead ends: even when newest, they + // never hide the session's last good snapshot. + store := newStore(t) + saveRow(t, store, "a", "sess-1", "", exp.SnapshotStatusSucceeded) + tick() + saveRow(t, store, "b", "sess-1", "a", exp.SnapshotStatusFailed) + tick() + saveRow(t, store, "c", "sess-1", "a", exp.SnapshotStatusAborted) + + latest, err := store.GetLatestSnapshot(ctx, "sess-1") + if err != nil { + t.Fatalf("GetLatestSnapshot: %v", err) + } + if latest == nil || latest.SnapshotID != "a" { + t.Errorf("latest = %+v, want last good snapshot a", latest) + } + }) + + t.Run("GetLatestSnapshotPendingReturned", func(t *testing.T) { + // A pending row is not skipped: it marks a detached invocation + // that is still running, and the runtime needs to see it to + // reject the resume instead of silently racing the background + // work. + store := newStore(t) + saveRow(t, store, "a", "sess-1", "", exp.SnapshotStatusSucceeded) + tick() + saveRow(t, store, "b", "sess-1", "a", exp.SnapshotStatusPending) + + latest, err := store.GetLatestSnapshot(ctx, "sess-1") + if err != nil { + t.Fatalf("GetLatestSnapshot: %v", err) + } + if latest == nil || latest.SnapshotID != "b" || latest.Status != exp.SnapshotStatusPending { + t.Errorf("latest = %+v, want pending snapshot b", latest) + } + }) + + t.Run("GetLatestSnapshotUnknownSession", func(t *testing.T) { + store := newStore(t) + saveRow(t, store, "a", "sess-1", "", exp.SnapshotStatusSucceeded) + latest, err := store.GetLatestSnapshot(ctx, "sess-unknown") + if err != nil { + t.Fatalf("GetLatestSnapshot: %v", err) + } + if latest != nil { + t.Errorf("expected nil for unknown session, got %+v", latest) + } + }) + + t.Run("GetLatestSnapshotEmptySessionID", func(t *testing.T) { + store := newStore(t) + if _, err := store.GetLatestSnapshot(ctx, ""); err == nil { + t.Error("expected error for empty session ID") + } + }) +} diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index 1cecbcb39f..1fe4958e61 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -32,7 +32,7 @@ type AgentOption[State any] interface { // StateTransform rewrites session state on its way out to a client. It // is applied to the State returned by the getSnapshot companion action -// and to [AgentResult.State] when state is client-managed (no store). +// and to [AgentOutput.State] when state is client-managed (no store). // It is not applied to state persisted in the store or to state passed // to the user agent function. // @@ -104,7 +104,7 @@ func WithSnapshotOn[State any](events ...SnapshotEvent) AgentOption[State] { // WithStateTransform registers a transform applied to session state on // its way out to a client via the getSnapshot companion action or via -// [AgentResult.State] when state is client-managed. Typical use is PII +// [AgentOutput.State] when state is client-managed. Typical use is PII // redaction or stripping secrets. The transform is not applied to state // persisted in the store or to state passed to the user agent function. func WithStateTransform[State any](transform StateTransform[State]) AgentOption[State] { @@ -121,38 +121,91 @@ type InvocationOption[State any] interface { type invocationOptions[State any] struct { state *SessionState[State] snapshotID string + sessionID string + // sessionIDSet records that WithSessionID was used, independent of the + // value: an empty session ID is rejected rather than silently ignored, + // since it usually means an [AgentOutput.SessionID] from a client-managed + // (storeless) invocation was piped through, and treating it as "no + // option" would silently start a fresh conversation. + sessionIDSet bool } +// applyInvocation merges o into opts, rejecting duplicate options. +// Mutual exclusivity (WithState versus WithSessionID/WithSnapshotID) is +// checked once, after all options are applied, in +// [Agent.resolveOptions]. func (o *invocationOptions[State]) applyInvocation(opts *invocationOptions[State]) error { if o.state != nil { if opts.state != nil { return errors.New("cannot set state more than once (WithState)") } - if opts.snapshotID != "" { - return errors.New("WithState and WithSnapshotID are mutually exclusive") - } opts.state = o.state } if o.snapshotID != "" { if opts.snapshotID != "" { return errors.New("cannot set snapshot ID more than once (WithSnapshotID)") } - if opts.state != nil { - return errors.New("WithSnapshotID and WithState are mutually exclusive") - } opts.snapshotID = o.snapshotID } + if o.sessionIDSet { + if o.sessionID == "" { + return errors.New("session ID is empty (WithSessionID); an empty AgentOutput.SessionID means the invocation had no session, check before resuming") + } + if opts.sessionIDSet { + return errors.New("cannot set session ID more than once (WithSessionID)") + } + opts.sessionID = o.sessionID + opts.sessionIDSet = true + } return nil } // WithState sets the initial state for the invocation. // Use this for client-managed state where the client sends state directly. +// The conversation's identity rides inside the state object +// ([SessionState.SessionID]): the framework mints one on the +// conversation's first invocation and echoes it on the output state, so +// resending the state keeps the identity without tracking a separate +// field. Mutually exclusive with [WithSessionID] and [WithSnapshotID]. func WithState[State any](state *SessionState[State]) InvocationOption[State] { return &invocationOptions[State]{state: state} } // WithSnapshotID loads state from a persisted snapshot by ID. // Use this for server-managed state where snapshots are stored. +// Combine with [WithSessionID] to assert which session the snapshot +// belongs to; a mismatch is rejected. Mutually exclusive with +// [WithState]. func WithSnapshotID[State any](id string) InvocationOption[State] { return &invocationOptions[State]{snapshotID: id} } + +// WithSessionID resumes the session (conversation) with the given ID +// from its latest snapshot: the most recently updated one that is not a +// failed/aborted dead end (see [SnapshotReader.GetLatestSnapshot]). Use +// this when the caller tracks the conversation rather than individual +// snapshots; the session ID is assigned when the conversation's first +// invocation starts (see [AgentOutput.SessionID]) and stays stable +// across resumed invocations. +// +// Only valid when the agent is server-managed (a session store is +// configured) and therefore mutually exclusive with [WithState]: a +// client-managed conversation carries its identity inside the state +// object ([SessionState.SessionID]) instead. Combined with +// [WithSnapshotID], the snapshot picks the exact resume point and the +// session ID is validated against it, so an invocation never silently +// continues a conversation other than the one named. +// +// A pending latest snapshot means a detached invocation is still +// running; the resume is rejected so it cannot race the background +// work. Wait for the snapshot to finalize, or abort it. If the +// session's history was forked (an earlier snapshot was resumed again, +// or two invocations resumed the session concurrently), the most +// recently updated branch wins; use [WithSnapshotID] to continue a +// specific branch instead. +// +// Passing an empty ID is an error rather than a no-op, so an unset +// [AgentOutput.SessionID] cannot silently start a fresh conversation. +func WithSessionID[State any](id string) InvocationOption[State] { + return &invocationOptions[State]{sessionID: id, sessionIDSet: true} +} diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 21a66d2dd6..51d7390ba9 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -62,6 +62,30 @@ func applyTransform[State any](ctx context.Context, t StateTransform[State], sta type SnapshotReader[State any] interface { // GetSnapshot retrieves a snapshot by ID. Returns nil if not found. GetSnapshot(ctx context.Context, snapshotID string) (*SessionSnapshot[State], error) + + // GetLatestSnapshot resolves the snapshot a session would resume from: + // the session's most recently updated row that is not a dead end, as a + // full row (the runtime loads its state to resume). Returns nil if the + // session has no such row (unknown session ID, or every row is a dead + // end), and an error if sessionID is empty. + // + // "Most recently updated" means the greatest + // [SessionSnapshot.UpdatedAt], falling back to CreatedAt on rows that + // lack one; ties may be broken arbitrarily but deterministically (e.g. + // by SnapshotID). Failed and aborted rows are dead ends (resuming from + // them is rejected), so they are skipped and a branch that died never + // hides the session's last good snapshot. A pending row is NOT + // skipped: it marks a detached invocation that is still running, and + // surfacing it lets the agent runtime reject the resume instead of + // silently racing the background work. + // + // The contract is a status-filtered max-timestamp lookup, so stores + // can implement it as a single indexed query (e.g. WHERE sessionId = ? + // AND status NOT IN ('failed', 'aborted') ORDER BY updatedAt DESC + // LIMIT 1). ParentID is informational lineage and plays no part in + // resolution: when a session's history was forked by re-resuming an + // earlier snapshot, the most recently updated branch simply wins. + GetLatestSnapshot(ctx context.Context, sessionID string) (*SessionSnapshot[State], error) } // SnapshotWriter persists snapshots. The minimum any session store must @@ -74,6 +98,12 @@ type SnapshotWriter[State any] interface { // - SnapshotID: if id is empty, the store generates a fresh ID; // otherwise the store uses id (any SnapshotID populated by fn is // overridden). + // - SessionID: the ID of the session (chain of snapshots) the row + // belongs to: preserved from the existing row on update (a row's + // session never changes once set), otherwise taken from fn's row + // as-is. Stores never mint or infer session IDs; the agent + // runtime assigns one when an invocation starts and stamps it on + // every row it writes. // - CreatedAt: stamped to the wall clock on first write; preserved // from the existing row on update. // - UpdatedAt: stamped to the wall clock on every commit. @@ -146,7 +176,7 @@ type SnapshotAborter interface { // SessionStore is the minimum store interface required by // [WithSessionStore]. The abort lifecycle is layered as the optional // [SnapshotAborter] capability and checked at runtime: a store wired -// into a flow that intends to support detach must also implement +// into an agent that intends to support detach must also implement // [SnapshotAborter], or the runtime will reject detach attempts. type SessionStore[State any] interface { SnapshotReader[State] @@ -212,7 +242,7 @@ func registerSnapshotActions[State any]( return } core.DefineAction(r, agentName, api.ActionTypeAgentSnapshot, nil, nil, - func(ctx context.Context, req *GetSnapshotRequest) (*GetSnapshotResponse[State], error) { + func(ctx context.Context, req *GetSnapshotRequest) (*SessionSnapshot[State], error) { if req == nil || req.SnapshotID == "" { return nil, core.NewError(core.INVALID_ARGUMENT, "getSnapshot: snapshotId is required") } @@ -224,27 +254,27 @@ func registerSnapshotActions[State any]( return nil, core.NewError(core.NOT_FOUND, "getSnapshot: snapshot %q not found", req.SnapshotID) } - status := snap.Status - if status == "" { - status = SnapshotStatusSucceeded + // Return a normalized copy: the documented defaults (empty + // status means succeeded, zero UpdatedAt means CreatedAt) are + // resolved server-side so remote clients don't reimplement + // them, and the state transform shapes what leaves the server. + // A failed snapshot's state is its last-good state, so it is + // returned like any other. + resp := *snap + if resp.Status == "" { + resp.Status = SnapshotStatusSucceeded } - updatedAt := snap.UpdatedAt - if updatedAt.IsZero() { - updatedAt = snap.CreatedAt + if resp.UpdatedAt.IsZero() { + resp.UpdatedAt = resp.CreatedAt } - - resp := &GetSnapshotResponse[State]{ - SnapshotID: snap.SnapshotID, - CreatedAt: snap.CreatedAt, - UpdatedAt: updatedAt, - Status: status, - FinishReason: snap.FinishReason, - Error: snap.Error, + resp.State = applyTransform(ctx, transform, snap.State) + if resp.State != nil { + // SessionID is framework identity, not user data: re-stamp + // it from the row after the transform so outbound state + // always agrees with the snapshot it came from. + resp.State.SessionID = resp.SessionID } - if status != SnapshotStatusFailed && status != SnapshotStatusPending { - resp.State = applyTransform(ctx, transform, snap.State) - } - return resp, nil + return &resp, nil }) aborter, ok := store.(SnapshotAborter) @@ -271,8 +301,8 @@ func registerSnapshotActions[State any]( // --- Session --- -// Session holds conversation state and provides thread-safe read/write access to messages, -// input variables, custom state, and artifacts. +// Session holds conversation state and provides thread-safe read/write +// access to messages, custom state, and artifacts. type Session[State any] struct { mu sync.RWMutex state SessionState[State] @@ -280,6 +310,25 @@ type Session[State any] struct { version uint64 // incremented on every mutation; used to skip redundant snapshots } +// SessionID returns the ID of the session this conversation belongs to. +// The agent runtime settles it before the agent function runs: a fresh +// invocation mints a new ID, server-managed resumes inherit the chain's +// (a snapshot from before session IDs existed gets a fresh one), and a +// client-managed invocation keeps the ID carried inside the state object +// it was given ([SessionState.SessionID]), minting one if absent. +// +// The ID is stable for the lifetime of the invocation; it lives in +// [SessionState.SessionID], so it is stamped on every snapshot the +// invocation persists and rides inside the state returned to +// client-managed callers. It is safe to use as a key for external +// resources tied to the conversation, including from code that +// retrieves the session via [SessionFromContext]. +func (s *Session[State]) SessionID() string { + // Written once at construction, before fn runs and before the session + // is shared, then never mutated; safe to read without holding mu. + return s.state.SessionID +} + // State returns a copy of the current state. func (s *Session[State]) State() *SessionState[State] { s.mu.RLock() diff --git a/go/ai/exp/teststore_test.go b/go/ai/exp/teststore_test.go index 4a0e014c32..079dd0f2ed 100644 --- a/go/ai/exp/teststore_test.go +++ b/go/ai/exp/teststore_test.go @@ -26,6 +26,7 @@ package exp import ( "context" "encoding/json" + "errors" "fmt" "slices" "sync" @@ -67,6 +68,29 @@ func (s *testInMemStore[State]) GetSnapshot(_ context.Context, snapshotID string return testCopySnapshot(snap) } +func (s *testInMemStore[State]) GetLatestSnapshot(_ context.Context, sessionID string) (*SessionSnapshot[State], error) { + if sessionID == "" { + return nil, errors.New("testInMemStore: session ID is empty") + } + s.mu.RLock() + defer s.mu.RUnlock() + var latest *SessionSnapshot[State] + for _, snap := range s.snapshots { + if snap.SessionID != sessionID || + snap.Status == SnapshotStatusFailed || snap.Status == SnapshotStatusAborted { + continue + } + if latest == nil || snap.UpdatedAt.After(latest.UpdatedAt) || + (snap.UpdatedAt.Equal(latest.UpdatedAt) && snap.SnapshotID > latest.SnapshotID) { + latest = snap + } + } + if latest == nil { + return nil, nil + } + return testCopySnapshot(latest) +} + func (s *testInMemStore[State]) AbortSnapshot(_ context.Context, snapshotID string) (SnapshotStatus, error) { s.mu.Lock() defer s.mu.Unlock() @@ -115,6 +139,9 @@ func (s *testInMemStore[State]) SaveSnapshot( now := time.Now() if existing != nil { next.CreatedAt = existing.CreatedAt + if existing.SessionID != "" { + next.SessionID = existing.SessionID // a row's session never changes + } } else { next.CreatedAt = now } diff --git a/go/core/flow.go b/go/core/flow.go index 35d7c13766..b626d03a8e 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -197,9 +197,8 @@ func FlowNameFromContext(ctx context.Context) string { // call [Run] for sub-step tracking and see the flow name in spans — // without going through [NewBidiFlow] / [DefineBidiFlow]. // -// The Define*Flow constructors call this internally; direct callers -// only need it when bypassing those constructors to set custom -// [ActionOptions]. +// The flow constructors attach this context themselves; direct callers +// only need it when bypassing them, e.g. to set custom [ActionOptions]. func WithFlowContext(ctx context.Context, flowName string) context.Context { return flowContextKey.NewContext(ctx, &flowContext{flowName: flowName}) } diff --git a/go/core/schemas.config b/go/core/schemas.config index 60a4eb3573..7d34954c36 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1215,8 +1215,8 @@ accepted. The server writes a single pending snapshot (with empty state), returns [AgentOutput] with that snapshot ID, and continues processing any already-buffered inputs in a background context. The pending snapshot is finalized with the cumulative final state once all -queued inputs are processed (or the snapshot is cancelled via -cancelSnapshot). +queued inputs are processed (or the invocation is aborted via the +abortSnapshot companion action). . AgentInput.message type *ai.Message @@ -1262,25 +1262,48 @@ AgentInit typeparams [State any] AgentInit doc AgentInit is the input for starting an agent invocation. -Exactly one of SnapshotID or State may be set, and the choice must match -the agent's state management: - - Server-managed state (a session store is configured): callers must - use SnapshotID; sending State is rejected. - - Client-managed state (no session store): callers must use State; - sending SnapshotID is rejected. -Sending both fields is always rejected. Sending neither starts a fresh -invocation with empty state. +SessionID, SnapshotID, and State are competing conversation sources, +and which are valid depends on the agent's state management: + - Server-managed state (a session store is configured): callers may + use SessionID (resume the session's latest snapshot), SnapshotID + (resume a specific snapshot), or both (the session ID is asserted + against the snapshot); sending State is rejected. + - Client-managed state (no session store): callers may use State, + which carries the conversation's identity inside it + ([SessionState.SessionID]); sending SessionID or SnapshotID is + rejected. +Sending no fields starts a fresh invocation with empty state. +. + +AgentInit.sessionId doc +SessionID identifies the session (conversation) to resume. Only valid +when the agent is server-managed (a session store is configured); +mutually exclusive with State (a client-managed conversation carries +its identity inside [SessionState.SessionID]). Alone, it resumes the +session from its latest snapshot: the most recently updated one that +is not a failed/aborted dead end. A pending latest snapshot (a +detached invocation still running) rejects the resume rather than +racing the background work; if the session's history was forked by +resuming an earlier snapshot again, the most recently updated branch +wins, and SnapshotID can pick a branch explicitly. Combined with +SnapshotID, it asserts which session the snapshot belongs to, and a +mismatch is rejected. . AgentInit.snapshotId doc SnapshotID loads state from a persisted snapshot. Only valid when the -agent is server-managed (a session store is configured). Mutually -exclusive with State. +agent is server-managed (a session store is configured). May be +combined with SessionID to validate that the snapshot belongs to that +session. Mutually exclusive with State. . AgentInit.state doc State provides direct state for the invocation. Only valid when the -agent is client-managed (no session store). Mutually exclusive with +agent is client-managed (no session store). The conversation's +identity rides inside it ([SessionState.SessionID]): the framework +mints one on the conversation's first invocation and echoes it on the +output state, so resending the state object keeps the identity without +tracking a separate field. Mutually exclusive with SessionID and SnapshotID. . @@ -1322,12 +1345,26 @@ AgentOutput is the output when an agent invocation completes. It wraps AgentResult with framework-managed fields. . +AgentOutput.sessionId doc +SessionID is the ID of the session this invocation belongs to, +assigned by the framework when the invocation starts. With +server-managed state, a fresh invocation mints a new ID, resumed +invocations inherit the chain's, and resuming a snapshot from before +session IDs existed mints a fresh one. With client-managed state it +echoes the ID carried inside the state object +([SessionState.SessionID]), minting one on the conversation's first +invocation; only a session with persisted snapshots can be resumed by +this ID. +. + AgentOutput.snapshotId doc -SnapshotID is the ID of the snapshot created at the end of this invocation. -Empty if no snapshot was created (callback returned false or no store configured). -When FinishReason is [AgentFinishReasonFailed] (and a store is configured), -it is the most recent snapshot capturing the last-good state: everything -through the last successful turn (see [SnapshotEventRecovery]). +SnapshotID is the ID of the newest snapshot capturing this invocation: +the invocation-end snapshot, or the latest earlier snapshot when that +write was skipped. Empty when no store is configured or the invocation +persisted nothing. When FinishReason is [AgentFinishReasonDetached] it +is the pending detach snapshot; when [AgentFinishReasonFailed], the most +recent snapshot capturing the last-good state: everything through the +last successful turn (see [SnapshotEventRecovery]). . AgentOutput.state doc @@ -1417,8 +1454,9 @@ snapshot was persisted. TurnEnd.snapshotId doc SnapshotID is the ID of the snapshot persisted at the end of this turn. -Empty if no snapshot was created (callback returned false or no store -configured, or snapshots were suspended after detach). +Empty if no snapshot was written (no store configured, the callback +declined, nothing changed since the last snapshot, or snapshots were +suspended after detach). . TurnEnd.finishReason doc @@ -1445,6 +1483,15 @@ SessionState is the portable conversation state that flows between client and server. It contains only the data needed for conversation continuity. . +SessionState.sessionId doc +SessionID is the ID of the session (conversation) this state belongs to. +Framework-owned: assigned when the conversation's first invocation +starts and re-stamped on outbound state, so client-managed callers can +round-trip the state object opaquely without tracking a separate +identifier. For server-managed agents the snapshot row's +[SessionSnapshot.SessionID] is canonical and this field mirrors it. +. + SessionState.messages type []*ai.Message SessionState.messages doc Messages is the conversation history (user/model exchanges). @@ -1477,8 +1524,19 @@ SessionSnapshot.snapshotId doc SnapshotID is the unique identifier for this snapshot (UUID). . +SessionSnapshot.sessionId doc +SessionID is the ID of the session this snapshot belongs to. Assigned +by the agent framework when the conversation's first invocation starts +and stamped on every later snapshot in the chain, including across +resumed invocations. Stores preserve it across rewrites; rows written +without one (data from before session IDs existed) belong to no +session. +. + SessionSnapshot.parentId doc -ParentID is the ID of the previous snapshot in this timeline. +ParentID is the ID of the previous snapshot in this timeline. It is +informational lineage (for debugging and UI history trees) and plays +no part in resolving a session's latest snapshot. . SessionSnapshot.createdAt type time.Time @@ -1537,24 +1595,25 @@ SnapshotEvent identifies what triggered a snapshot. . SnapshotEventTurnEnd doc -TurnEnd indicates the snapshot was triggered at the end of a turn. +SnapshotEventTurnEnd indicates the snapshot was triggered at the end of a turn. . SnapshotEventInvocationEnd doc -InvocationEnd indicates the snapshot was triggered at the end of the invocation. +SnapshotEventInvocationEnd indicates the snapshot was triggered at the end +of the invocation. . SnapshotEventDetach doc -Detach indicates the snapshot was created when the client detached the -invocation and the flow continues in the background. The snapshot is -initially written with [SnapshotStatusPending] and rewritten with a -terminal status once the background work finishes. +SnapshotEventDetach indicates the snapshot was created when the client +detached the invocation and the work continues in the background. The +snapshot is initially written with [SnapshotStatusPending] and rewritten +with a terminal status once the background work finishes. . SnapshotEventRecovery doc -Recovery indicates the snapshot was written retroactively by the failure -path to preserve the last-good state (everything through the last -successful turn) when a selective snapshot callback had skipped +SnapshotEventRecovery indicates the snapshot was written retroactively +by the failure path to preserve the last-good state (everything through +the last successful turn) when a selective snapshot callback had skipped persisting it. It is a normal [SnapshotStatusSucceeded] row carrying the last good turn's finish reason, resumable like any other; the snapshot callback is bypassed and never sees this event. @@ -1585,7 +1644,7 @@ meantime. SnapshotStatusPending doc SnapshotStatusPending indicates a detached invocation is still processing the queued inputs. The snapshot will be rewritten with a -terminal status once the flow exits. +terminal status once the background work finishes. . SnapshotStatusSucceeded doc @@ -1593,8 +1652,8 @@ SnapshotStatusSucceeded indicates the snapshot captures a settled state. . SnapshotStatusAborted doc -SnapshotStatusAborted indicates the snapshot's invocation was -aborted via the abortSnapshot companion action while detached. +SnapshotStatusAborted indicates the snapshot's invocation was aborted +while detached (e.g. via the abortSnapshot companion action). . SnapshotStatusFailed doc @@ -1672,9 +1731,11 @@ GetSnapshotRequest pkg ai/exp GetSnapshotRequest doc GetSnapshotRequest is the input for an agent's getSnapshot companion -action. The action is registered at `{agentName}/getSnapshot` when the -agent is defined and is intended for Dev UI and client-side reconnect -flows. +action, registered under the agent's name (action type agent-snapshot) +when the agent has a session store configured. The action is intended +for Dev UI and client-side reconnect flows. It returns the stored +[SessionSnapshot], with [WithStateTransform] applied to its state if +configured. . GetSnapshotRequest.snapshotId noomitempty @@ -1682,59 +1743,6 @@ GetSnapshotRequest.snapshotId doc SnapshotID identifies the snapshot to fetch. . -# GetSnapshotResponse -GetSnapshotResponse pkg ai/exp -GetSnapshotResponse typeparams [State any] - -GetSnapshotResponse doc -GetSnapshotResponse is the output of the getSnapshot companion action. It -is a client-facing view of the stored snapshot: identifying metadata plus -the session state, with [WithStateTransform] applied if configured. - -Unlike the raw [SessionSnapshot], this response intentionally omits -internal fields (parent ID, event) and does not leak the snapshot -envelope beyond what callers need to repopulate a UI. -. - -GetSnapshotResponse.snapshotId noomitempty -GetSnapshotResponse.snapshotId doc -SnapshotID echoes the requested snapshot ID. -. - -GetSnapshotResponse.createdAt type time.Time -GetSnapshotResponse.createdAt doc -CreatedAt is when the snapshot record was first written. -. - -GetSnapshotResponse.updatedAt type time.Time -GetSnapshotResponse.updatedAt doc -UpdatedAt is when the snapshot record was last written. Equals -CreatedAt for snapshots that have not been rewritten. -. - -GetSnapshotResponse.status doc -Status is the lifecycle state of the snapshot. See [SnapshotStatus]. -. - -GetSnapshotResponse.finishReason doc -FinishReason is the semantic reason the captured turn or invocation ended -(e.g. [AgentFinishReasonStop], [AgentFinishReasonInterrupted], -[AgentFinishReasonFailed], [AgentFinishReasonAborted]). It lets a remote or -background poller report how a detached or resumed invocation ended without -re-deriving it. Empty on a pending snapshot. -. - -GetSnapshotResponse.error type *core.GenkitError -GetSnapshotResponse.error doc -Error is the structured failure information; populated when Status -is [SnapshotStatusFailed]. -. - -GetSnapshotResponse.state doc -State is the session state captured by the snapshot, after any -configured transform. Empty when Status is pending or error. -. - # AbortSnapshotRequest AbortSnapshotRequest pkg ai/exp diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 96c1a8e4f5..01b5c5fd22 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -509,9 +509,10 @@ func DefineAgent[State any]( // // # Options // -// - [aix.WithSessionStore]: Enable snapshot persistence with a [aix.SessionStore] +// - [aix.WithSessionStore]: Enable snapshot persistence // - [aix.WithSnapshotCallback]: Control when snapshots are created // - [aix.WithSnapshotOn]: Create snapshots only for specific [aix.SnapshotEvent] types +// - [aix.WithStateTransform]: Rewrite session state on its way out to the client // // Type parameters: // - Stream: Type for custom status updates sent via [aix.Responder.SendStatus] diff --git a/go/internal/cmd/jsonschemagen/jsonschemagen.go b/go/internal/cmd/jsonschemagen/jsonschemagen.go index 8d14dd4801..77b0cbb6bb 100644 --- a/go/internal/cmd/jsonschemagen/jsonschemagen.go +++ b/go/internal/cmd/jsonschemagen/jsonschemagen.go @@ -751,7 +751,8 @@ type extraField struct { // type EXPR // use EXPR for the type expression (for fields only) // doc -// doc is following lines until the line "." +// doc is following lines until the line "."; leading whitespace +// is preserved so godoc lists survive generation // pkg // package path, relative to outdir (last component is package name) // import @@ -790,7 +791,10 @@ func parseConfigFile(filename string) (config, error) { if line == "." { docItem = nil } else { - docItem.docLines = append(docItem.docLines, line) + // Keep leading whitespace: doc blocks may contain godoc + // lists, whose items and continuation lines must stay + // indented to render as lists. + docItem.docLines = append(docItem.docLines, strings.TrimRight(string(ln), " \t")) } continue } diff --git a/go/samples/basic-agents/cli.go b/go/samples/basic-agents/cli.go index 986cde7405..4f403a5ea0 100644 --- a/go/samples/basic-agents/cli.go +++ b/go/samples/basic-agents/cli.go @@ -264,6 +264,25 @@ func pickSession(ctx context.Context, inputCh <-chan string, a sampleAgent, late } } +// resumeOption picks how to resume the chosen snapshot. Resuming by +// session ID is the canonical "continue this conversation" flow, so the +// session is validated against the store first: if it does not resolve +// to a resumable snapshot (a detached invocation still pending, legacy +// rows with no session ID, or a store error), fall back to the exact +// snapshot the user was just shown. Validating up front keeps the chat +// from opening on a connection whose invocation already failed, which +// would surface the error only after the user types a message. +func resumeOption(ctx context.Context, a sampleAgent, resume *aix.SessionSnapshot[any]) aix.InvocationOption[any] { + if resume.SessionID != "" { + tip, err := a.Store.GetLatestSnapshot(ctx, resume.SessionID) + if err == nil && tip != nil && tip.Status != aix.SnapshotStatusPending { + return aix.WithSessionID[any](resume.SessionID) + } + fmt.Println("(this conversation's session can't be resumed as a whole; continuing from the selected snapshot)") + } + return aix.WithSnapshotID[any](resume.SnapshotID) +} + // runChat opens the bidi connection (optionally resuming from a // snapshot) and runs the per-turn REPL. When resuming, the prior // conversation is replayed first so the user sees the context they're @@ -287,7 +306,7 @@ func runChat(ctx context.Context, inputCh <-chan string, a sampleAgent, resume * var opts []aix.InvocationOption[any] if resume != nil { - opts = append(opts, aix.WithSnapshotID[any](resume.SnapshotID)) + opts = append(opts, resumeOption(ctx, a, resume)) } conn, err := a.StreamBidi(ctx, opts...) if err != nil { diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index e0a8fc61c0..bb3b97e795 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -146,6 +146,7 @@ class AgentInit(GenkitModel): """Model for agentinit data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + session_id: str | None = None snapshot_id: str | None = None state: SessionState | None = None @@ -171,6 +172,7 @@ class AgentOutput(GenkitModel): """Model for agentoutput data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + session_id: str | None = None snapshot_id: str | None = None state: SessionState | None = None message: MessageData | None = None @@ -214,24 +216,12 @@ class GetSnapshotRequest(GenkitModel): snapshot_id: str = Field(...) -class GetSnapshotResponse(GenkitModel): - """Model for getsnapshotresponse data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - snapshot_id: str = Field(...) - created_at: str | None = None - updated_at: str | None = None - status: SnapshotStatus | None = None - finish_reason: AgentFinishReason | None = None - error: Any | None = Field(default=None) - state: SessionState | None = None - - class SessionSnapshot(GenkitModel): """Model for sessionsnapshot data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) snapshot_id: str = Field(...) + session_id: str | None = None parent_id: str | None = None created_at: str = Field(...) updated_at: str | None = None @@ -246,6 +236,7 @@ class SessionState(GenkitModel): """Model for sessionstate data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + session_id: str | None = None messages: list[MessageData] | None = None custom: Any | None = Field(default=None) artifacts: list[Artifact] | None = None From 165293caa9d860ebbf315d8addc2b9a628a0d8dc Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 10 Jun 2026 15:54:38 -0700 Subject: [PATCH 092/141] fix(go/exp): prevent agent process crashes from shared prompt metadata and nil inputs agentLoop tagged prompt-rendered messages by writing into Metadata maps that Render can alias from the registered prompt's stored config (e.g. messages registered via ai.WithMessages): concurrent invocations died on fatal concurrent map writes, and the _genkit_prompt tag leaked into the prompt config. Tag cloned copies of the base messages instead. detachIntake dereferenced nil inputs in its unrecovered reader goroutine, so conn.Send(nil) (or a transport-decoded JSON null) crashed the process. Reject nil at AgentConnection.Send and skip nils in the intake read and detach-drain paths, where an enqueued nil also ended the input stream early and dropped queued inputs. --- go/ai/exp/agent.go | 39 ++++++++-- go/ai/exp/agent_test.go | 159 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 191 insertions(+), 7 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 99c3e5ff34..935a0a0a83 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -25,6 +25,7 @@ import ( "errors" "fmt" "iter" + "maps" "runtime/debug" "sync" "sync/atomic" @@ -1342,6 +1343,13 @@ func (i *detachIntake) read() { if !ok { return } + if input == nil { + // A nil input (e.g. a JSON null decoded by a transport) + // carries nothing to process; dropping it here also keeps + // nil out of the queue, where the forwarder would read it + // as end-of-input. + continue + } if input.Detach { i.handleDetach(input) return @@ -1381,7 +1389,9 @@ drainLoop: if !ok { break drainLoop } - drained = append(drained, more) + if more != nil { + drained = append(drained, more) + } default: break drainLoop } @@ -1542,16 +1552,28 @@ func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) Ag } // Tag base messages so they can be filtered out of session - // history after generation. + // history after generation. Tag copies rather than the + // rendered messages themselves: Render can alias message + // metadata from shared prompt config (e.g. messages + // registered via [ai.WithMessages]), so tagging in place + // would leak the tag into the registered prompt and race + // with concurrent invocations. + base := make([]*ai.Message, 0, len(actionOpts.Messages)) for _, m := range actionOpts.Messages { - if m.Metadata == nil { - m.Metadata = make(map[string]any) + if m == nil { + continue } - m.Metadata[promptMessageKey] = true + tagged := *m + tagged.Metadata = maps.Clone(tagged.Metadata) + if tagged.Metadata == nil { + tagged.Metadata = make(map[string]any, 1) + } + tagged.Metadata[promptMessageKey] = true + base = append(base, &tagged) } // Append conversation history after the base messages. - actionOpts.Messages = append(actionOpts.Messages, sess.Messages()...) + actionOpts.Messages = append(base, sess.Messages()...) // If a resume payload was provided, forward it to the // generate call so handleResumeOption re-executes the @@ -1709,13 +1731,16 @@ type AgentConnection[Stream, State any] struct { conn *core.BidiConnection[*AgentInput, *AgentStreamChunk[Stream], *AgentOutput[State]] } -// Send sends an AgentInput to the agent. +// Send sends an AgentInput to the agent. The input must not be nil. // // Once the invocation has resolved (e.g. a failed turn ended it), Send // fails with an error matching [core.ErrActionCompleted]; the outcome is // on [AgentConnection.Output]. The same applies to the SendMessage, // SendText, SendResume, and Detach helpers. func (c *AgentConnection[Stream, State]) Send(input *AgentInput) error { + if input == nil { + return core.NewError(core.INVALID_ARGUMENT, "agent input must not be nil") + } return c.conn.Send(input) } diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index b9a31a9d79..2ea1b162d4 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -5470,3 +5470,162 @@ func TestAgent_GetSnapshotAction_ReturnsSessionID(t *testing.T) { t.Errorf("getSnapshot SessionID = %q, want %q", resp.SessionID, out.SessionID) } } + +func TestPromptAgent_InlineMessages_DoesNotMutateSharedMetadata(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + // Render aliases the metadata map of messages registered via + // WithMessages, so the agent loop must not tag the rendered + // messages in place. + shared := ai.NewModelTextMessage("inline context message") + shared.Metadata = map[string]any{"origin": "config"} + + af := DefineAgent[testState](reg, "inlineMetaPrompt", FromInline( + ai.WithModelName("test/echo"), + ai.WithMessages(shared), + )) + + response, err := af.RunText(ctx, "hello") + if err != nil { + t.Fatalf("RunText failed: %v", err) + } + + if _, ok := shared.Metadata[promptMessageKey]; ok { + t.Errorf("prompt message tag leaked into shared config message metadata: %v", shared.Metadata) + } + // The base message must still be filtered out of session history: + // 1 user message + 1 model reply = 2. + if got := len(response.State.Messages); got != 2 { + t.Errorf("expected 2 messages, got %d", got) + for i, m := range response.State.Messages { + t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) + } + } +} + +func TestPromptAgent_InlineMessages_ConcurrentInvocations(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + // All invocations render the same inline message whose metadata map + // is shared with the registered prompt's config; tagging it in + // place is a concurrent map write under the race detector. + shared := ai.NewModelTextMessage("inline context message") + shared.Metadata = map[string]any{"origin": "config"} + + af := DefineAgent[testState](reg, "inlineConcurrentPrompt", FromInline( + ai.WithModelName("test/echo"), + ai.WithMessages(shared), + )) + + var wg sync.WaitGroup + errs := make(chan error, 8) + for range 8 { + wg.Add(1) + go func() { + defer wg.Done() + if _, err := af.RunText(ctx, "hello"); err != nil { + errs <- err + } + }() + } + wg.Wait() + close(errs) + for err := range errs { + t.Errorf("RunText failed: %v", err) + } +} + +func TestAgent_SendNilInput_Rejected(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "nilInputFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + if input.Message != nil { + sess.AddMessages(ai.NewModelTextMessage("echo: " + input.Message.Text())) + } + return nil, nil + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + if err := conn.Send(nil); err == nil { + t.Error("expected Send(nil) to fail") + } + + // The connection must remain usable after the rejected input. + if err := conn.SendText("hello"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.TurnEnd != nil { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + if got := len(response.State.Messages); got != 2 { + t.Errorf("expected 2 messages, got %d", got) + } +} + +// TestDetachIntake_SkipsNilInputs covers nil inputs that bypass the typed +// Send API (e.g. a JSON null decoded by a transport): the intake must drop +// them rather than crash its reader goroutine or end the stream early. +func TestDetachIntake_SkipsNilInputs(t *testing.T) { + t.Run("read path", func(t *testing.T) { + src := make(chan *AgentInput, 4) + src <- nil + src <- &AgentInput{Message: ai.NewUserTextMessage("one")} + src <- nil + close(src) + + intake := startDetachIntake(src) + defer intake.stopAndWait() + + var got []string + for in := range intake.out() { + got = append(got, in.Message.Text()) + intake.releaseForward() + } + if !slices.Equal(got, []string{"one"}) { + t.Errorf("expected [one], got %v", got) + } + }) + + t.Run("detach drain path", func(t *testing.T) { + src := make(chan *AgentInput, 4) + src <- &AgentInput{Detach: true, Message: ai.NewUserTextMessage("final")} + src <- nil + src <- &AgentInput{Message: ai.NewUserTextMessage("two")} + close(src) + + intake := startDetachIntake(src) + defer intake.stopAndWait() + go func() { <-intake.detachSignal() }() + + var got []string + for in := range intake.out() { + got = append(got, in.Message.Text()) + intake.releaseForward() + } + if !slices.Equal(got, []string{"final", "two"}) { + t.Errorf("expected [final two], got %v", got) + } + }) +} From 670ccc2783689793ca465dd206496dc59baf4159 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 11 Jun 2026 08:59:26 -0700 Subject: [PATCH 093/141] fix(go/core): release bidi connection contexts and unwedge abandoned streams Every StreamBidi invocation derived a cancelCtx whose cancel was never invoked, so completed connections accumulated as children on long-lived parent contexts, and a connection abandoned mid-stream parked its action goroutine on the full stream buffer forever with no escape hatch (the Receive doc pointed at Close, which only half-closes the input side). Core BidiConnection: - Cancel the connection context once the action returns, releasing the parent registration. The defer ordering plus done re-checks in Send/Receive/Output keep completion from surfacing as a spurious cancellation error. - Output now drains unconsumed stream chunks, so break-then-Output lets a parked action run to completion instead of wedging both sides. - New Cancel method as the explicit terminate; Close and Receive docs now describe what they actually do. Agent runtime: - The chunk router's forward watches the action context, so a disconnected client cannot leave handleFnDone blocked in close() behind a chunk nobody will read (fn returning does not imply the router is idle; the old comment claimed otherwise). - AgentConnection.Output delegates its terminal wait to the core Output, gaining the ctx escape it lacked. --- go/ai/exp/agent.go | 63 +++++++------ go/ai/exp/agent_test.go | 115 +++++++++++++++++++++++ go/core/action.go | 119 ++++++++++++++++++------ go/core/flow_test.go | 199 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 441 insertions(+), 55 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 935a0a0a83..2e18fea8f4 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -622,7 +622,7 @@ func newAgentRuntime[Stream, State any]( name: name, cfg: cfg, session: session, - router: startChunkRouter(session, outCh), + router: startChunkRouter(ctx, session, outCh), intake: startDetachIntake(inCh), fnDone: make(chan fnDoneResult[State], 1), } @@ -771,14 +771,16 @@ func (rt *agentRuntime[Stream, State]) drainAndWait(cancelWork context.CancelFun // an error, the invocation resolves gracefully as a failed output instead // (see failedOutput). // -// When fn returns with an error, the Responder's ctx-aware send may have -// dropped a chunk while the router was still pinned on a downstream send -// to a slow/gone consumer. router.close blocks on the router's forward -// goroutine exiting, which can't happen while it's stuck on that send; -// stopAndWait closes stopWriting first so the router breaks out and -// enters drain mode. The natural-completion path leaves the router idle -// (every Send was accepted before fn returned), so close alone is -// sufficient there and avoids trashing a last in-flight chunk. +// router.close blocks on the forward goroutine exiting, and fn returning +// does not imply the router is idle: fn's last accepted chunk may still be +// in the router's hands, parked on the send to a full out buffer. On the +// error path stopAndWait closes stopWriting first, deliberately dropping +// the failed turn's in-flight chunks so close cannot wedge behind a +// slow/gone consumer. On the success path those chunks are wanted, so +// close relies on the parked send being released instead: a consuming +// client drains it, a disconnected client's ctx cancellation trips +// forward's ctx arm, and a client that stopped receiving unparks it when +// its Output call drains the stream. func (rt *agentRuntime[Stream, State]) handleFnDone( ctx context.Context, cancelWork context.CancelFunc, @@ -1131,6 +1133,7 @@ func resumeSessionFrom[State any](s *Session[State], snap *SessionSnapshot[State // send. type chunkRouter[Stream, State any] struct { + ctx context.Context // action context; ends on client disconnect (or completion) in chan *AgentStreamChunk[Stream] out chan<- *AgentStreamChunk[Stream] session *Session[State] @@ -1144,10 +1147,12 @@ type chunkRouter[Stream, State any] struct { } func startChunkRouter[Stream, State any]( + ctx context.Context, session *Session[State], out chan<- *AgentStreamChunk[Stream], ) *chunkRouter[Stream, State] { r := &chunkRouter[Stream, State]{ + ctx: ctx, in: make(chan *AgentStreamChunk[Stream]), out: out, session: session, @@ -1162,13 +1167,14 @@ func startChunkRouter[Stream, State any]( func (r *chunkRouter[Stream, State]) run() { defer close(r.done) if !r.forward() { - // r.in closed before detach; nothing left to do. + // r.in closed while writes were still allowed; nothing left to do. return } close(r.writerStopped) - // Detached: keep applying side effects so the user fn's - // SendArtifact/SendModelChunk calls behave the same way they did - // pre-detach. Only the wire forward to outCh is suppressed. + // Writes stopped (detach, shutdown, or client disconnect): keep + // applying side effects so the user fn's SendArtifact/SendModelChunk + // calls behave the same way they did before. Only the wire forward to + // outCh is suppressed. for chunk := range r.in { r.applySideEffects(chunk) } @@ -1188,8 +1194,10 @@ func (r *chunkRouter[Stream, State]) applySideEffects(chunk *AgentStreamChunk[St } } -// forward delivers chunks to outCh and applies side effects until detach -// or r.in closes. Returns true if it stopped because of detach. +// forward delivers chunks to outCh and applies side effects until told to +// stop writing, the action context ends, or r.in closes. Returns true if +// the router must keep draining side effects (writes stopped), false if +// r.in closed. func (r *chunkRouter[Stream, State]) forward() bool { for { select { @@ -1202,6 +1210,12 @@ func (r *chunkRouter[Stream, State]) forward() bool { case r.out <- chunk: case <-r.stopWriting: return true + case <-r.ctx.Done(): + // The client is gone (disconnect cancels the action + // context), so nothing will drain out again and a blocked + // forward would wedge close. Drop the chunk and switch to + // side-effects-only mode. + return true } case <-r.stopWriting: return true @@ -1790,8 +1804,9 @@ func (c *AgentConnection[Stream, State]) Close() error { // Receive returns an iterator for receiving stream chunks. Breaking out // of the iterator does not cancel the connection; multi-turn callers // routinely break on [TurnEnd], send the next input, then call Receive -// again to consume the next batch. Use ctx cancellation or [Close] to -// terminate the connection. +// again to consume the next batch. Call [AgentConnection.Output] to +// finish the invocation, or cancel the ctx passed to StreamBidi to +// abort it. func (c *AgentConnection[Stream, State]) Receive() iter.Seq2[*AgentStreamChunk[Stream], error] { return c.conn.Receive() } @@ -1821,16 +1836,10 @@ func (c *AgentConnection[Stream, State]) Receive() iter.Seq2[*AgentStreamChunk[S // them. Finish Receive first, then call Output. func (c *AgentConnection[Stream, State]) Output() (*AgentOutput[State], error) { _ = c.conn.Close() - - drainDone := make(chan struct{}) - go func() { - defer close(drainDone) - for range c.conn.Receive() { - } - }() - - <-c.conn.Done() - <-drainDone + // The core Output drains unconsumed chunks (so the agent is never + // wedged publishing to a stream nobody reads) and unblocks on ctx + // cancellation while preferring the finalized result when both are + // ready. return c.conn.Output() } diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 2ea1b162d4..945a376db0 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -5629,3 +5629,118 @@ func TestDetachIntake_SkipsNilInputs(t *testing.T) { } }) } + +// TestAgent_ClientCancelMidStream reproduces an invocation hang: fn +// returns nil (the closed input stream ended sess.Run cleanly) while its +// last accepted chunk is still in the router's hands, parked on the full +// stream buffer, and the client then cancels instead of draining. The +// fn-done success path skips stopAndWait, so the router's forward must +// observe the cancelled action context itself (nothing will drain the +// stream again) for the invocation to resolve, and Output must unblock +// rather than waiting on completion unconditionally. +func TestAgent_ClientCancelMidStream(t *testing.T) { + for i := range 10 { + t.Run(fmt.Sprintf("iteration%d", i), func(t *testing.T) { + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "cancelFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + resp.SendStatus(testStatus{Phase: "step0"}) + resp.SendStatus(testStatus{Phase: "step1"}) + return nil, nil + }) + }, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + if err := conn.SendText("hello"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + // Close the input side so sess.Run ends cleanly and fn returns + // nil once its sends are accepted. + conn.Close() + + // Consume a single chunk: that frees a buffer slot so the fn's + // remaining sends are accepted and fn returns, leaving the + // router parked mid-forward on the full stream buffer. + for range conn.Receive() { + break + } + // Give the agent time to reach that parked state, then cancel + // instead of draining. + time.Sleep(50 * time.Millisecond) + cancel() + + outputDone := make(chan struct{}) + go func() { + defer close(outputDone) + conn.Output() + }() + select { + case <-outputDone: + case <-time.After(10 * time.Second): + t.Fatal("Output did not return after client cancellation") + } + + // The invocation itself must also resolve: a wedged router + // would leave the action goroutine (and its trace span) open + // forever. + select { + case <-conn.Done(): + case <-time.After(10 * time.Second): + t.Fatal("invocation did not complete after client cancellation") + } + }) + } +} + +// TestAgent_OutputUnblocksOnCancel covers the caller's escape hatch when +// the agent fn does not observe cancellation: Output must return once the +// connection's context is cancelled instead of blocking on completion +// that may never come. +func TestAgent_OutputUnblocksOnCancel(t *testing.T) { + reg := newTestRegistry(t) + + block := make(chan struct{}) + t.Cleanup(func() { close(block) }) // let the stubborn fn unwind + + af := DefineCustomAgent(reg, "stubbornFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + <-block // ignores ctx + return nil, nil + }, + ) + + ctx, cancel := context.WithCancel(context.Background()) + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + cancel() + + type result struct { + out *AgentOutput[testState] + err error + } + resultCh := make(chan result, 1) + go func() { + out, err := conn.Output() + resultCh <- result{out, err} + }() + + select { + case res := <-resultCh: + if res.err == nil { + t.Errorf("expected an error from Output after cancellation, got output %+v", res.out) + } + case <-time.After(10 * time.Second): + t.Fatal("Output did not return after cancellation; no context escape") + } +} diff --git a/go/core/action.go b/go/core/action.go index 8c3162083b..34c880a549 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -434,6 +434,14 @@ func (a *Action[In, Out, StreamOut, StreamIn]) StreamBidi(ctx context.Context, i } go func() { + // Cancel the connection's context once the action has returned, so a + // completed connection releases its registration on the parent + // context (a long-lived parent would otherwise accumulate one child + // per invocation). The LIFO defer order guarantees that whenever + // ctx.Done is observed because of completion, streamCh and doneCh + // are already closed, letting Send/Receive/Output re-check them and + // report completion rather than a spurious cancellation. + defer conn.cancel() defer close(conn.doneCh) defer close(conn.streamCh) output, err := tracing.RunInNewSpan(conn.ctx, spanMetadata, in, @@ -568,13 +576,23 @@ func (c *BidiConnection[StreamIn, StreamOut, Out]) Send(input StreamIn) (err err case c.inputCh <- input: return nil case <-c.ctx.Done(): - return c.ctx.Err() + // The context is also cancelled on completion; report the + // completion error when the action has already returned. + select { + case <-c.doneCh: + return NewError(FAILED_PRECONDITION, "%v", ErrActionCompleted) + default: + return c.ctx.Err() + } case <-c.doneCh: return NewError(FAILED_PRECONDITION, "%v", ErrActionCompleted) } } -// Close signals that no more inputs will be sent. +// Close signals that no more inputs will be sent. It does not terminate +// the connection: the action keeps running until it returns on its own, +// typically after observing the closed input stream. To abandon a +// connection and the work behind it, use [BidiConnection.Cancel]. func (c *BidiConnection[StreamIn, StreamOut, Out]) Close() error { c.mu.Lock() defer c.mu.Unlock() @@ -586,13 +604,24 @@ func (c *BidiConnection[StreamIn, StreamOut, Out]) Close() error { return nil } +// Cancel terminates the connection: it cancels the context the action +// runs under and releases callers blocked on Send, Receive, or Output. +// The action's shutdown is asynchronous; a well-behaved action observes +// the cancellation and returns promptly. Cancel is safe to call multiple +// times and after completion, where it has no effect. +func (c *BidiConnection[StreamIn, StreamOut, Out]) Cancel() { + c.cancel() +} + // Receive returns an iterator for receiving streamed response chunks. // The iterator yields chunks until the action finishes, the context is // cancelled, or the caller breaks out of the loop. Breaking out does NOT // cancel the connection: bidi callers routinely break to switch to -// sending, then call Receive again to consume the next batch. Use ctx -// cancellation or [BidiConnection.Close] to terminate the connection -// (matching gRPC and similar bidi streaming conventions). +// sending, then call Receive again to consume the next batch (matching +// gRPC and similar bidi streaming conventions). To finish with the +// connection, call [BidiConnection.Output]; to abandon it early, use +// [BidiConnection.Cancel] or cancel the context. A connection that is +// merely abandoned leaves the action running. func (c *BidiConnection[StreamIn, StreamOut, Out]) Receive() iter.Seq2[StreamOut, error] { return func(yield func(StreamOut, error) bool) { for { @@ -605,9 +634,25 @@ func (c *BidiConnection[StreamIn, StreamOut, Out]) Receive() iter.Seq2[StreamOut return } case <-c.ctx.Done(): - var zero StreamOut - yield(zero, c.ctx.Err()) - return + // The context is also cancelled on completion, and chunks + // may remain buffered either way. Deliver them, and end + // cleanly if the stream closed: cancellation is reported + // only while the action is still running. + for { + select { + case chunk, ok := <-c.streamCh: + if !ok { + return + } + if !yield(chunk, nil) { + return + } + default: + var zero StreamOut + yield(zero, c.ctx.Err()) + return + } + } } } } @@ -615,29 +660,47 @@ func (c *BidiConnection[StreamIn, StreamOut, Out]) Receive() iter.Seq2[StreamOut // Output returns the final output after the action completes. // Blocks until done or context cancelled. If the action has finished, its -// actual output is returned even when the context was cancelled concurrently. +// actual output is returned even when the context was cancelled concurrently +// (e.g., session flows backgrounded on client disconnect). +// +// Calling Output declares that the caller is done receiving: stream chunks +// not consumed via Receive are drained and discarded, so an action blocked +// publishing to the full stream buffer can run to completion instead of +// wedging both sides. Do not call Output concurrently with a goroutine +// iterating Receive; both consume from the same stream and chunks would be +// split between them. Finish Receive first, then call Output. func (c *BidiConnection[StreamIn, StreamOut, Out]) Output() (Out, error) { - // Fast path: if the action has already finished, return its output - // rather than racing with ctx.Done. This matters for callers that - // observe a completed action just after cancelling ctx (e.g., session - // flows backgrounded on client disconnect). - select { - case <-c.doneCh: - c.mu.Lock() - defer c.mu.Unlock() - return c.output, c.err - default: + for { + select { + case <-c.doneCh: + return c.result() + case _, ok := <-c.streamCh: + if !ok { + // The stream closed: the action has returned and doneCh + // closes immediately after. + <-c.doneCh + return c.result() + } + case <-c.ctx.Done(): + // The context is also cancelled on completion; prefer the + // action's actual output when it has already finished. + select { + case <-c.doneCh: + return c.result() + default: + var zero Out + return zero, c.ctx.Err() + } + } } +} - select { - case <-c.doneCh: - c.mu.Lock() - defer c.mu.Unlock() - return c.output, c.err - case <-c.ctx.Done(): - var zero Out - return zero, c.ctx.Err() - } +// result returns the stored output and error. Callers must have observed +// doneCh closed first. +func (c *BidiConnection[StreamIn, StreamOut, Out]) result() (Out, error) { + c.mu.Lock() + defer c.mu.Unlock() + return c.output, c.err } // Done returns a channel that is closed when the connection completes. diff --git a/go/core/flow_test.go b/go/core/flow_test.go index f240d62bd1..5dcfdd84b4 100644 --- a/go/core/flow_test.go +++ b/go/core/flow_test.go @@ -23,6 +23,7 @@ import ( "slices" "strings" "testing" + "time" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/internal/registry" @@ -442,6 +443,204 @@ func TestBidiConnectionContextCancellation(t *testing.T) { } } +func TestBidiConnectionCompletionReleasesContext(t *testing.T) { + ctx := context.Background() + + // Capture the context the action runs under: completion must cancel it + // so the connection releases its registration on the parent context + // (a long-lived parent would otherwise accumulate one child per + // invocation). + ctxCh := make(chan context.Context, 1) + action := NewBidiAction( + "capture", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + ctxCh <- ctx + return "done", nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + <-conn.Done() + fnCtx := <-ctxCh + + // The cancel fires just after doneCh closes; poll briefly. + deadline := time.Now().Add(5 * time.Second) + for fnCtx.Err() == nil { + if time.Now().After(deadline) { + t.Fatal("connection context was not cancelled after completion") + } + time.Sleep(time.Millisecond) + } +} + +func TestBidiConnectionNoSpuriousErrorAfterCompletion(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "clean", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + outCh <- "chunk" + return "done", nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + <-conn.Done() + + // Completion cancels the connection context, so the stream-closed and + // ctx-done select arms are both ready; Receive and Output must prefer + // the stream's own terminal state over a spurious cancellation error. + // Iterate to defeat select randomness. + for range 30 { + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive yielded error on completed connection: %v", err) + } + if chunk != "chunk" { + t.Fatalf("unexpected chunk %q", chunk) + } + } + output, err := conn.Output() + if err != nil { + t.Fatalf("Output errored on completed connection: %v", err) + } + if output != "done" { + t.Fatalf("expected output 'done', got %q", output) + } + } +} + +func TestBidiConnectionOutputUnblocksAbandonedStream(t *testing.T) { + ctx := context.Background() + + // The action is not ctx-aware and emits more chunks than the caller + // consumes, so it parks on the full stream buffer when the caller + // breaks out of Receive. Output must drain the stream so the action + // can run to completion instead of wedging both sides. + action := NewBidiAction( + "chatty", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for range 3 { + outCh <- "chunk" + } + return "done", nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + for range conn.Receive() { + break // abandon the stream after one chunk + } + conn.Close() + + type result struct { + output string + err error + } + resultCh := make(chan result, 1) + go func() { + output, err := conn.Output() + resultCh <- result{output, err} + }() + + select { + case res := <-resultCh: + if res.err != nil { + t.Fatalf("Output errored: %v", res.err) + } + if res.output != "done" { + t.Fatalf("expected output 'done', got %q", res.output) + } + case <-time.After(10 * time.Second): + t.Fatal("Output did not return; action wedged on abandoned stream") + } +} + +func TestBidiConnectionCancel(t *testing.T) { + ctx := context.Background() + + started := make(chan struct{}) + action := NewBidiAction( + "blocking", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + close(started) + <-ctx.Done() + return "", ctx.Err() + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + <-started + + conn.Cancel() + + if _, err := conn.Output(); err == nil { + t.Error("expected error after Cancel") + } + conn.Cancel() // idempotent +} + +func TestBidiConnectionReceiveResumesAfterBreak(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "echo", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for input := range inCh { + outCh <- "echo: " + input + } + return "done", nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + // Breaking out of Receive must not terminate the connection: the + // send-receive-break cycle is the canonical multi-turn pattern. + var chunks []string + for _, input := range []string{"one", "two"} { + if err := conn.Send(input); err != nil { + t.Fatalf("Send(%q) failed: %v", input, err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + chunks = append(chunks, chunk) + break + } + } + conn.Close() + + output, err := conn.Output() + if err != nil { + t.Fatal(err) + } + if output != "done" { + t.Errorf("expected output 'done', got %q", output) + } + want := []string{"echo: one", "echo: two"} + if !slices.Equal(chunks, want) { + t.Errorf("expected chunks %v, got %v", want, chunks) + } +} + func TestBidiFlowRegistration(t *testing.T) { r := registry.New() From d20cfe41cef937dfbaa2775c04117239fc67000e Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 11 Jun 2026 09:58:40 -0700 Subject: [PATCH 094/141] fix(go/exp): isolate session state and apply stream side effects synchronously In-process callers shared memory with session state at every entry boundary: loadSession installed client-managed state by shallow struct copy, and input messages and streamed artifacts entered the session as raw pointers. AddArtifacts' in-place replace wrote into the caller's array, two invocations given the same state object cross-corrupted, a caller mutating a sent message raced snapshot marshaling, and a mutated artifact changed state without a version bump, defeating the snapshot dedupe. Deep-copy on the way in (loadSession and resumeSessionFrom, the input in Run, the artifact in applySideEffects), clone before the getSnapshot transform as the StateTransform contract already promised, and document the element aliasing in Messages()/Artifacts(); exit boundaries already cloned. Responder sends applied their session side effects asynchronously in the router goroutine, with the TurnEnd rendezvous as the only sync point, so the turn-end snapshot (taken before that rendezvous) and Result() near-deterministically missed artifacts sent in the same turn; with WithSnapshotOn(turnEnd) the artifact was permanently absent from the resume snapshot. Side effects now run in Responder.send before it returns, so any read or snapshot after a Send observes the chunk; the router only forwards to the wire. Update* callbacks run under the session lock and therefore must not send on a Responder (documented). Also fix a pre-existing detach bug this surfaced: the framework cancels the action context right after run returns the detached output, so the pre-detach watcher could see both its wake conditions ready and randomly pick the cancel arm, killing the detached background work it was meant to outlive. The watcher now arbitrates through the same sync.Once as markDetached, making the cancel a no-op once detach has landed. --- go/ai/exp/agent.go | 158 ++++++++++++++++------- go/ai/exp/agent_test.go | 269 ++++++++++++++++++++++++++++++++++++++++ go/ai/exp/option.go | 4 +- go/ai/exp/session.go | 29 ++++- 4 files changed, 405 insertions(+), 55 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 2e18fea8f4..2c10a059e2 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -157,6 +157,12 @@ type TurnResult struct { // than failing the action. func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Context, input *AgentInput) (*TurnResult, error)) error { for input := range s.InputCh { + // Deep-copy at the framework boundary: an in-process caller + // retains the pointers it sent (message, resume parts) and may + // mutate them after Send returns, so everything past this point + // (trace marshaling, session state, snapshot writes) must work + // on private memory rather than race the caller. + input = jsonClone(input) spanMeta := &tracing.SpanMetadata{ Name: fmt.Sprintf("agent/turn/%d", s.TurnIndex), Type: "flowStep", @@ -165,7 +171,9 @@ func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont _, err := tracing.RunInNewSpan(ctx, spanMeta, input, func(ctx context.Context, input *AgentInput) (any, error) { if input.Message != nil { - s.AddMessages(input.Message) + // The session owns its history: store a copy so fn's + // view of the input stays independent of session state. + s.AddMessages(jsonClone(input.Message)) } tr, err := fn(ctx, input) if err != nil { @@ -227,7 +235,8 @@ func (s *SessionRunner[State]) recordLastGood() { // Result returns an [AgentResult] populated from the current session state: // the last message in the conversation history and all artifacts. The // returned value is independent of the session; callers may mutate it -// without affecting session state. +// without affecting session state. An artifact sent through the +// [Responder] is visible here as soon as the Send call returns. // // It is a convenience for custom agents that don't need to construct the // result manually. @@ -380,17 +389,26 @@ func (s *SessionRunner[State]) recoverySnapshotID(ctx context.Context) string { // --- Responder --- // Responder is the output channel for an agent. Artifacts sent through -// it are automatically added to the session before being forwarded to the -// client. +// it are added to the session synchronously: by the time a Send method +// returns, the chunk's session-level side effects have been applied, so +// a state read ([SessionRunner.Result], [Session.Artifacts]) or a +// turn-end snapshot that follows the call observes them. Only the wire +// forward to the client is asynchronous. // // All Send methods are ctx-aware: if the agent's work context is // cancelled (typically client disconnect, abort during detach, or fn -// completion), Send returns promptly with the chunk dropped. Send itself -// remains fire-and-forget and returns no error; the user fn is expected -// to observe cancellation through its own ctx check and stop producing. +// completion), Send returns promptly with the chunk dropped from the +// wire; the session-level side effects still apply. Send itself remains +// fire-and-forget and returns no error; the user fn is expected to +// observe cancellation through its own ctx check and stop producing. type Responder[Stream any] struct { in chan<- *AgentStreamChunk[Stream] ctx context.Context + // effects applies the chunk's in-process side effects (session + // artifact add, turn-chunk accumulation) synchronously in send, in + // the sender's goroutine, so reads and snapshots that follow a Send + // cannot miss the chunk. + effects func(*AgentStreamChunk[Stream]) } // SendModelChunk sends a generation chunk (token-level streaming). @@ -405,19 +423,30 @@ func (r Responder[Stream]) SendStatus(status Stream) { // SendArtifact sends an artifact to the stream and adds it to the session. // If an artifact with the same name already exists in the session, it is -// replaced. The session-level side effect happens whether or not detach -// has landed; only the wire forward to the client is suppressed -// post-detach, when there is no longer a client to receive it. +// replaced. The artifact is in the session by the time SendArtifact +// returns, and the session stores a deep copy captured at the call, so +// later mutations of the caller's artifact do not affect session state. +// The session-level side effect happens whether or not detach has landed; +// only the wire forward to the client is suppressed post-detach, when +// there is no longer a client to receive it. func (r Responder[Stream]) SendArtifact(artifact *Artifact) { r.send(&AgentStreamChunk[Stream]{Artifact: artifact}) } -// send delivers chunk to the router, returning promptly if r.ctx is -// cancelled. Dropping on cancel decouples fn liveness from the runtime's -// shutdown choreography: a Send issued after workCtx cancellation -// completes immediately rather than blocking on a router that has not -// yet been put into drain mode by a terminal path. +// send applies chunk's in-process side effects, then delivers it to the +// router for the wire forward, returning promptly if r.ctx is cancelled. +// Applying side effects synchronously (in the sender's goroutine, before +// the channel send) orders them before everything the caller does after +// Send, so a state read or a turn-end snapshot cannot miss a chunk whose +// Send already returned. Dropping the wire forward on cancel decouples +// fn liveness from the runtime's shutdown choreography: a Send issued +// after workCtx cancellation completes immediately rather than blocking +// on a router that has not yet been put into drain mode by a terminal +// path. func (r Responder[Stream]) send(chunk *AgentStreamChunk[Stream]) { + if r.effects != nil { + r.effects(chunk) + } select { case r.in <- chunk: case <-r.ctx.Done(): @@ -647,6 +676,11 @@ func newAgentRuntime[Stream, State any]( // snapshot (if applicable), and forwards the resulting [TurnEnd] chunk // through the router so clients see it on the output stream. // +// The snapshot sees everything the turn produced: the side effects of +// the turn's Send calls (e.g. artifacts) are applied synchronously in +// [Responder] before each Send returns, and fn returned before this +// runs, so there is no in-flight router work to wait out. +// // The snapshot is skipped when the turn failed (the live state holds the // turn's partial mutations) and when detach has landed (maybeSnapshot // observes the suspension under snapMu; the pending row already captures @@ -685,7 +719,17 @@ func (rt *agentRuntime[Stream, State]) run( go func() { select { case <-clientCtx.Done(): - cancelWork() + // Arbitrate atomically against markDetached: whichever claims + // the Once first wins. clientCtx ends not only on a true + // disconnect but also when the framework releases the action + // context right after run returns the detached output; by then + // both select arms are ready and this arm may be picked, so + // claiming the Once (rather than cancelling outright) is what + // keeps an already-landed detach's background work alive. + detachOnce.Do(func() { + close(detached) + cancelWork() + }) case <-detached: } }() @@ -754,10 +798,10 @@ func (rt *agentRuntime[Stream, State]) checkDetachCapabilities() error { // to surface its error. func (rt *agentRuntime[Stream, State]) drainAndWait(cancelWork context.CancelFunc) fnDoneResult[State] { cancelWork() - // Switch the router to side-effects-only mode before waiting on fn. - // Without this, a fn mid-SendStatus blocks on the router's r.in - // receive while the router blocks on r.out send (consumer is gone), - // so fn never observes ctx and we deadlock waiting on fnDone. + // Switch the router to discard mode before waiting on fn. Without + // this, a fn mid-SendStatus blocks on the router's r.in receive while + // the router blocks on r.out send (consumer is gone), so fn never + // observes ctx and we deadlock waiting on fnDone. rt.router.stopAndWait() rt.intake.stopAndWait() res := <-rt.fnDone @@ -871,9 +915,9 @@ func (rt *agentRuntime[Stream, State]) failedOutput(ctx context.Context, cause e // status-subscriber and finalizer goroutines that own the rest of the // invocation. Per-turn snapshots are suspended for the remainder so the // queued inputs roll into a single finalize rewrite; the chunk router -// stops writing to outCh but keeps applying in-process side effects -// (e.g. artifacts added via Responder.SendArtifact) so user code does -// not have to branch on detach. +// stops writing to outCh and discards further chunks, whose in-process +// side effects (e.g. artifacts added via Responder.SendArtifact) still +// apply at Send time, so user code does not have to branch on detach. func (rt *agentRuntime[Stream, State]) handleDetach( clientCtx, workCtx context.Context, cancelWork context.CancelFunc, @@ -1048,7 +1092,12 @@ func loadSession[State any]( return nil, nil, core.NewError(core.FAILED_PRECONDITION, "state provided but agent has a session store configured (server-managed state); use snapshot ID instead") } - s.state = *init.State + // Deep-copy at the entry boundary: an in-process caller retains + // its state object ([WithState] documents resending it), so the + // session must own private memory. Without this, AddArtifacts' + // in-place replace writes into the caller's array and the + // caller's later mutations race snapshot marshaling. + s.state = *jsonClone(init.State) return s, nil, nil case init.SnapshotID != "": @@ -1115,7 +1164,12 @@ func resumeSessionFrom[State any](s *Session[State], snap *SessionSnapshot[State "snapshot %q was aborted", snap.SnapshotID) } if snap.State != nil { - s.state = *snap.State + // Stores may return rows sharing memory with their internal + // copies (the [SnapshotReader] contract does not require fresh + // memory), so the session takes a private copy; otherwise two + // invocations resumed from the same snapshot would cross-corrupt + // through the shared backing arrays. + s.state = *jsonClone(snap.State) } return s, snap, nil } @@ -1123,10 +1177,12 @@ func resumeSessionFrom[State any](s *Session[State], snap *SessionSnapshot[State // --- chunkRouter --- // // chunkRouter owns the intermediate stream channel that all chunks flow -// through on their way to outCh. Every chunk gets the same in-process -// side effects (adding artifacts to the session, accumulating turn -// chunks for span output) regardless of whether detach has landed; the -// wire forward to outCh is the only thing detach suppresses, since the +// through on their way to outCh. A chunk's in-process side effects +// (adding artifacts to the session, accumulating turn chunks for span +// output) are applied synchronously by Responder.send before the chunk +// enters the router, so every chunk gets them in its sender's goroutine +// regardless of whether detach has landed; the router owns only the wire +// forward to outCh, which is the one thing detach suppresses, since the // bidi framework closes outCh shortly after bidiFn returns. The router // commits to not writing before we return so that close is safe, and // keeps draining its input so the user fn never blocks on a responder @@ -1172,20 +1228,24 @@ func (r *chunkRouter[Stream, State]) run() { } close(r.writerStopped) // Writes stopped (detach, shutdown, or client disconnect): keep - // applying side effects so the user fn's SendArtifact/SendModelChunk - // calls behave the same way they did before. Only the wire forward to - // outCh is suppressed. - for chunk := range r.in { - r.applySideEffects(chunk) + // draining so a producer mid-send never blocks. The chunks' side + // effects already happened at Send time; only the wire forward to + // outCh is suppressed, so the chunks are simply discarded. + for range r.in { } } // applySideEffects records the chunk's effect on session state and turn -// span output. Invoked from both forward (pre-detach) and the post-detach -// drain so a Send call is observably the same in either mode. +// span output. Invoked synchronously from Responder.send, in the +// sender's goroutine, so the effects are ordered before everything the +// sender does after Send: a state read, a turn-end snapshot, or +// [SessionRunner.Result] immediately after SendArtifact observes the +// artifact. The artifact is deep-copied on its way into the session so +// the sender's retained pointer (which also rides the wire chunk) cannot +// alias live session state. func (r *chunkRouter[Stream, State]) applySideEffects(chunk *AgentStreamChunk[Stream]) { if chunk.Artifact != nil { - r.session.AddArtifacts(chunk.Artifact) + r.session.AddArtifacts(jsonClone(chunk.Artifact)) } if chunk.TurnEnd == nil { r.turnMu.Lock() @@ -1194,10 +1254,9 @@ func (r *chunkRouter[Stream, State]) applySideEffects(chunk *AgentStreamChunk[St } } -// forward delivers chunks to outCh and applies side effects until told to -// stop writing, the action context ends, or r.in closes. Returns true if -// the router must keep draining side effects (writes stopped), false if -// r.in closed. +// forward delivers chunks to outCh until told to stop writing, the +// action context ends, or r.in closes. Returns true if the router must +// keep draining (writes stopped), false if r.in closed. func (r *chunkRouter[Stream, State]) forward() bool { for { select { @@ -1205,7 +1264,6 @@ func (r *chunkRouter[Stream, State]) forward() bool { if !ok { return false } - r.applySideEffects(chunk) select { case r.out <- chunk: case <-r.stopWriting: @@ -1223,16 +1281,20 @@ func (r *chunkRouter[Stream, State]) forward() bool { } } -// responder returns a [Responder] that sends chunks into the router. The -// returned Responder's Send methods drop chunks (returning promptly) -// when ctx is cancelled. +// responder returns a [Responder] that applies chunk side effects +// synchronously and sends chunks into the router for the wire forward. +// The returned Responder's Send methods drop the forward (returning +// promptly) when ctx is cancelled. func (r *chunkRouter[Stream, State]) responder(ctx context.Context) Responder[Stream] { - return Responder[Stream]{in: r.in, ctx: ctx} + return Responder[Stream]{in: r.in, ctx: ctx, effects: r.applySideEffects} } // sendChunk delivers chunk to the router for producers other than the -// user agent function (e.g. the runtime's emitTurnEnd). Returns -// promptly if ctx is cancelled, dropping the chunk. +// user agent function (e.g. the runtime's emitTurnEnd). It skips the +// in-process side effects (the only runtime-produced chunk is TurnEnd, +// which has none: no artifact, and TurnEnd is excluded from turn-chunk +// accumulation) and returns promptly if ctx is cancelled, dropping the +// chunk. func (r *chunkRouter[Stream, State]) sendChunk(ctx context.Context, chunk *AgentStreamChunk[Stream]) { select { case r.in <- chunk: diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 945a376db0..14b8020394 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -414,6 +414,205 @@ func TestAgent_Artifacts(t *testing.T) { } } +// TestAgent_ClientManagedState_CallerStateIsolated verifies that the +// framework deep-copies client-managed state at the entry boundary: the +// invocation must not write into the caller's state object (e.g. via +// AddArtifacts' in-place replace), and two invocations given the same +// state object must not share memory. +func TestAgent_ClientManagedState_CallerStateIsolated(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "stateIsolationFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + // Replace the artifact the caller's state carried (the + // in-place replace path) and extend history. + resp.SendArtifact(&Artifact{ + Name: "code.go", + Parts: []*ai.Part{ai.NewTextPart("v2")}, + }) + sess.AddMessages(ai.NewModelTextMessage("reply")) + return nil, nil + }) + }, + ) + + callerArtifact := &Artifact{ + Name: "code.go", + Parts: []*ai.Part{ai.NewTextPart("v1")}, + } + prev := &SessionState[testState]{ + Artifacts: []*Artifact{callerArtifact}, + Messages: []*ai.Message{ai.NewUserTextMessage("previous")}, + } + + out, err := af.RunText(ctx, "turn 1", WithState(prev)) + if err != nil { + t.Fatalf("RunText failed: %v", err) + } + + // The caller's state object must be untouched: same artifact pointer + // and content, no appended messages. + if prev.Artifacts[0] != callerArtifact { + t.Error("invocation replaced the artifact inside the caller's state object") + } + if got := callerArtifact.Parts[0].Text; got != "v1" { + t.Errorf("caller's artifact content changed to %q, want %q", got, "v1") + } + if got := len(prev.Messages); got != 1 { + t.Errorf("caller's message slice grew to %d entries, want 1", got) + } + + // The output reflects the replace: one artifact with the new content. + if got := len(out.State.Artifacts); got != 1 { + t.Fatalf("expected 1 artifact in output state, got %d", got) + } + if got := out.State.Artifacts[0].Parts[0].Text; got != "v2" { + t.Errorf("output artifact content = %q, want %q", got, "v2") + } +} + +// TestAgent_InputMessageCloned verifies the session stores a private copy +// of the input message: a caller mutating the message it sent after the +// turn must not change conversation history. +func TestAgent_InputMessageCloned(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "inputCloneFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + return nil, nil + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + sent := ai.NewUserTextMessage("original") + if err := conn.SendMessage(sent); err != nil { + t.Fatalf("SendMessage failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.TurnEnd != nil { + break + } + } + + // The turn is over, so the message is in session history. Mutating + // the caller's copy must not reach it. + sent.Content[0].Text = "mutated" + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + if got := len(out.State.Messages); got != 1 { + t.Fatalf("expected 1 message, got %d", got) + } + if got := out.State.Messages[0].Content[0].Text; got != "original" { + t.Errorf("session history reflects caller's mutation: got %q, want %q", got, "original") + } +} + +// TestAgent_SendArtifact_SynchronousAndCloned verifies SendArtifact's two +// guarantees: the artifact is visible to session reads by the time the +// call returns (Result right after SendArtifact must include it), and the +// session holds a private copy, so the sender's retained pointer cannot +// mutate session state. +func TestAgent_SendArtifact_SynchronousAndCloned(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + var ( + resultArtifacts int + sessionContent string + ) + af := DefineCustomAgent(reg, "syncArtifactFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + a := &Artifact{Name: "out.txt", Parts: []*ai.Part{ai.NewTextPart("original")}} + resp.SendArtifact(a) + // Visible as soon as SendArtifact returns. + resultArtifacts = len(sess.Result().Artifacts) + // The session holds its own copy. + a.Parts[0] = ai.NewTextPart("mutated") + if arts := sess.Artifacts(); len(arts) == 1 { + sessionContent = arts[0].Parts[0].Text + } + return nil, nil + }) + }, + ) + + if _, err := af.RunText(ctx, "go"); err != nil { + t.Fatalf("RunText failed: %v", err) + } + if resultArtifacts != 1 { + t.Errorf("Result() right after SendArtifact saw %d artifacts, want 1", resultArtifacts) + } + if sessionContent != "original" { + t.Errorf("session artifact content = %q, want %q (sender's mutation must not reach session state)", sessionContent, "original") + } +} + +// TestAgent_TurnEndSnapshot_IncludesSameTurnArtifact verifies that a +// turn-end snapshot captures artifacts sent during that turn: the Send +// side effect applies before the call returns, so the snapshot taken at +// turn end cannot miss it. With snapshots restricted to turn end, the +// invocation output reuses the turn-end row, which therefore must hold +// the artifact for a later resume. +func TestAgent_TurnEndSnapshot_IncludesSameTurnArtifact(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + af := DefineCustomAgent(reg, "turnEndArtifactFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + resp.SendArtifact(&Artifact{ + Name: "report.md", + Parts: []*ai.Part{ai.NewTextPart("# Report")}, + }) + return nil, nil + }) + }, + WithSessionStore[testState](store), + WithSnapshotOn[testState](SnapshotEventTurnEnd), + ) + + out, err := af.RunText(ctx, "produce the report") + if err != nil { + t.Fatalf("RunText failed: %v", err) + } + if out.SnapshotID == "" { + t.Fatal("expected a snapshot ID on the output") + } + snap, err := store.GetSnapshot(ctx, out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap == nil { + t.Fatalf("snapshot %q not found", out.SnapshotID) + } + if snap.Event != SnapshotEventTurnEnd { + t.Errorf("snapshot event = %q, want %q", snap.Event, SnapshotEventTurnEnd) + } + if snap.State == nil || len(snap.State.Artifacts) != 1 { + t.Fatalf("turn-end snapshot missing the artifact sent during the turn: %+v", snap.State) + } + if got := snap.State.Artifacts[0].Name; got != "report.md" { + t.Errorf("snapshot artifact name = %q, want %q", got, "report.md") + } +} + func TestAgent_SnapshotCallback(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) @@ -4186,6 +4385,76 @@ func TestPromptAgent_ForwardsFinishReason(t *testing.T) { } } +// TestAgent_Detach_BackgroundWorkSurvivesActionReturn verifies that the +// detached background work's context stays alive after the invocation +// returns the detached output and the framework releases the action +// context. The regression this guards: the pre-detach watcher that +// mirrors the client context could observe both its wake conditions +// (client context released, detach landed) at once and randomly pick the +// cancel arm, killing the background work. The race is scheduler-driven, +// so the test stacks iterations with a settle window; under +// single-threaded scheduling (GOMAXPROCS=1) the regression trips within +// a few iterations. +func TestAgent_Detach_BackgroundWorkSurvivesActionReturn(t *testing.T) { + ctx := context.Background() + for i := 0; i < 20; i++ { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + release := make(chan struct{}) + fnSaw := make(chan string, 1) + + af := DefineCustomAgent(reg, "detachSurviveFlow", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + select { + case <-release: + fnSaw <- "release" + case <-ctx.Done(): + fnSaw <- "ctx" + } + return nil, nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("iteration %d: StreamBidi: %v", i, err) + } + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() + if err := conn.SendText("go"); err != nil { + t.Fatalf("iteration %d: SendText: %v", i, err) + } + if err := conn.Detach(); err != nil { + t.Fatalf("iteration %d: Detach: %v", i, err) + } + out, err := conn.Output() + if err != nil { + t.Fatalf("iteration %d: Output: %v", i, err) + } + + // The action has returned, so the framework's release of the + // action context is imminent or done. Give a wrongly-cancelling + // watcher time to land before letting the turn proceed. + time.Sleep(20 * time.Millisecond) + close(release) + if saw := <-fnSaw; saw != "release" { + t.Fatalf("iteration %d: detached background work saw its context cancelled", i) + } + // Wait out the finalizer so the iteration's goroutines wind down. + waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusSucceeded + }) + } +} + // TestAgent_Detach_FinishReasons covers the three detach outcomes: the output // returned to the detaching client always reports "detached", while the // persisted snapshot records how the background work actually ended diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index 1fe4958e61..399c3f5735 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -166,7 +166,9 @@ func (o *invocationOptions[State]) applyInvocation(opts *invocationOptions[State // ([SessionState.SessionID]): the framework mints one on the // conversation's first invocation and echoes it on the output state, so // resending the state keeps the identity without tracking a separate -// field. Mutually exclusive with [WithSessionID] and [WithSnapshotID]. +// field. The framework deep-copies the state when the invocation starts, +// so the caller keeps ownership of the object it passed and may reuse it +// freely. Mutually exclusive with [WithSessionID] and [WithSnapshotID]. func WithState[State any](state *SessionState[State]) InvocationOption[State] { return &invocationOptions[State]{state: state} } diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 51d7390ba9..f426f02815 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -267,7 +267,12 @@ func registerSnapshotActions[State any]( if resp.UpdatedAt.IsZero() { resp.UpdatedAt = resp.CreatedAt } - resp.State = applyTransform(ctx, transform, snap.State) + // Clone before transforming: the [StateTransform] contract + // promises a fresh deep copy the transform may mutate in + // place, and the store's row may share memory with its + // internal copy, which neither the transform nor the + // SessionID re-stamp below may write into. + resp.State = applyTransform(ctx, transform, jsonClone(snap.State)) if resp.State != nil { // SessionID is framework identity, not user data: re-stamp // it from the row after the transform so outbound state @@ -337,7 +342,10 @@ func (s *Session[State]) State() *SessionState[State] { return &copied } -// Messages returns the current conversation history. +// Messages returns the current conversation history. The returned slice +// is a fresh copy, but its elements point at the live messages held by +// the session: treat them as read-only, or deep-copy before mutating. +// [Session.State] returns a fully independent copy. func (s *Session[State]) Messages() []*ai.Message { s.mu.RLock() defer s.mu.RUnlock() @@ -363,7 +371,9 @@ func (s *Session[State]) SetMessages(messages []*ai.Message) { } // UpdateMessages atomically reads the current messages, applies the given -// function, and writes the result back. +// function, and writes the result back. fn runs while the session's +// internal lock is held: it must not call other Session methods or send +// on a [Responder], or it will deadlock. func (s *Session[State]) UpdateMessages(fn func([]*ai.Message) []*ai.Message) { s.mu.Lock() defer s.mu.Unlock() @@ -379,7 +389,9 @@ func (s *Session[State]) Custom() State { } // UpdateCustom atomically reads the current custom state, applies the given -// function, and writes the result back. +// function, and writes the result back. fn runs while the session's +// internal lock is held: it must not call other Session methods or send +// on a [Responder], or it will deadlock. func (s *Session[State]) UpdateCustom(fn func(State) State) { s.mu.Lock() defer s.mu.Unlock() @@ -387,7 +399,10 @@ func (s *Session[State]) UpdateCustom(fn func(State) State) { s.version++ } -// Artifacts returns the current artifacts. +// Artifacts returns the current artifacts. The returned slice is a fresh +// copy, but its elements point at the live artifacts held by the +// session: treat them as read-only, or deep-copy before mutating. +// [Session.State] returns a fully independent copy. func (s *Session[State]) Artifacts() []*Artifact { s.mu.RLock() defer s.mu.RUnlock() @@ -420,7 +435,9 @@ func (s *Session[State]) AddArtifacts(artifacts ...*Artifact) { } // UpdateArtifacts atomically reads the current artifacts, applies the given -// function, and writes the result back. +// function, and writes the result back. fn runs while the session's +// internal lock is held: it must not call other Session methods or send +// on a [Responder], or it will deadlock. func (s *Session[State]) UpdateArtifacts(fn func([]*Artifact) []*Artifact) { s.mu.Lock() defer s.mu.Unlock() From 011c38fa71e02be6aa3ff11fab7b44a5c7eb9d43 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 12 Jun 2026 10:32:10 -0700 Subject: [PATCH 095/141] feat(go): rework bidi actions around a typed BidiConnection and JSON transports Adds core.NewBidiAction/DefineBidiAction with a typed BidiConnection (Send/Close/Receive/Output/Cancel/Done), JSON transports via api.BidiAction and api.BidiJSONConnection, and wiring through the reflection servers (V2 streamInput sessions, V1 one-shot init) and the HTTP handler (init field). - One execution engine: the unary one-shot path and live sessions both run through BidiConnection.run with shared panic recovery and framework-owned channel close protection. - JSON chunks are normalized and validated against a schema compiled once per session; an invalid chunk fails the session through the connection's cancel cause, matching the JS runtime and the unary input validation path. - Session output is validated against OutputSchema; init is decoded, validated against InitSchema, and recorded as the genkit:init span attribute. - Reflection V2 keeps an ownership-checked session registry keyed by request id so reconnects with reused ids cannot tear down or leak newer sessions; the telemetry callback lifecycle tolerates spans that start after the handler responded. - Restores genkit.NewFlow/NewStreamingFlow (released in go/v1.8.0). --- genkit-tools/genkit-schema.json | 4 + go/ai/embedder.go | 4 +- go/ai/evaluator.go | 4 +- go/ai/generate.go | 6 +- go/ai/prompt.go | 4 +- go/ai/resource.go | 6 +- go/ai/retriever.go | 4 +- go/core/action.go | 515 ++++------ go/core/action_test.go | 8 +- go/core/api/action.go | 61 +- go/core/background_action.go | 14 +- go/core/bidi.go | 696 ++++++++++++++ go/core/bidi_test.go | 1154 +++++++++++++++++++++++ go/core/flow.go | 60 +- go/core/flow_test.go | 306 +----- go/core/schemas.config | 6 +- go/core/tracing/tracing.go | 11 + go/core/tracing/tracing_test.go | 22 + go/genkit/gen.go | 11 +- go/genkit/genkit.go | 61 +- go/genkit/reflection.go | 24 +- go/genkit/reflection_test.go | 86 +- go/genkit/reflection_v2.go | 437 ++++++++- go/genkit/reflection_v2_test.go | 494 ++++++++++ go/genkit/servers.go | 27 +- go/genkit/servers_test.go | 125 +++ go/genkit/x/genkit.go | 2 +- go/internal/base/json.go | 8 + go/internal/base/json_type_converter.go | 15 +- go/internal/base/validation.go | 61 +- 30 files changed, 3400 insertions(+), 836 deletions(-) create mode 100644 go/core/bidi.go create mode 100644 go/core/bidi_test.go diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 68262b75f5..905bca657e 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -1658,6 +1658,9 @@ "input": { "description": "An input with the type that this action expects." }, + "init": { + "description": "Initialization parameters to establish long running session states." + }, "context": { "description": "Additional runtime context data (ex. auth context data)." }, @@ -1961,6 +1964,7 @@ }, "input": {}, "output": {}, + "init": {}, "isRoot": { "type": "boolean" }, diff --git a/go/ai/embedder.go b/go/ai/embedder.go index d93d5b52fa..5801da9592 100644 --- a/go/ai/embedder.go +++ b/go/ai/embedder.go @@ -85,7 +85,7 @@ type EmbedderOptions struct { // embedder is an action with functions specific to converting documents to multidimensional vectors such as Embed(). type embedder struct { - core.Action[*EmbedRequest, *EmbedResponse, struct{}, struct{}] + core.Action[*EmbedRequest, *EmbedResponse, struct{}] } // NewEmbedder creates a new [Embedder]. @@ -143,7 +143,7 @@ func DefineEmbedder(r api.Registry, name string, opts *EmbedderOptions, fn Embed // It will try to resolve the embedder dynamically if the embedder is not found. // It returns nil if the embedder was not resolved. func LookupEmbedder(r api.Registry, name string) Embedder { - action := core.ResolveActionFor[*EmbedRequest, *EmbedResponse, struct{}, struct{}](r, api.ActionTypeEmbedder, name) + action := core.ResolveActionFor[*EmbedRequest, *EmbedResponse, struct{}](r, api.ActionTypeEmbedder, name) if action == nil { return nil } diff --git a/go/ai/evaluator.go b/go/ai/evaluator.go index dd79a511ba..1ab7335932 100644 --- a/go/ai/evaluator.go +++ b/go/ai/evaluator.go @@ -72,7 +72,7 @@ func (e EvaluatorRef) Config() any { // evaluator is an action with functions specific to evaluating a dataset. type evaluator struct { - core.Action[*EvaluatorRequest, *EvaluatorResponse, struct{}, struct{}] + core.Action[*EvaluatorRequest, *EvaluatorResponse, struct{}] } // Example is a single example that requires evaluation @@ -291,7 +291,7 @@ func DefineBatchEvaluator(r api.Registry, name string, opts *EvaluatorOptions, f // LookupEvaluator looks up an [Evaluator] registered by [DefineEvaluator]. // It returns nil if the evaluator was not defined. func LookupEvaluator(r api.Registry, name string) Evaluator { - action := core.ResolveActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}, struct{}](r, api.ActionTypeEvaluator, name) + action := core.ResolveActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, api.ActionTypeEvaluator, name) if action == nil { return nil } diff --git a/go/ai/generate.go b/go/ai/generate.go index 2bd86f4232..3dcd0d9053 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -78,11 +78,11 @@ type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResp // model is an action with functions specific to model generation such as Generate(). type model struct { - core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk, struct{}] + core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk] } // generateAction is the type for a utility model generation action that takes in a GenerateActionOptions instead of a ModelRequest. -type generateAction = core.Action[*GenerateActionOptions, *ModelResponse, *ModelResponseChunk, struct{}] +type generateAction = core.Action[*GenerateActionOptions, *ModelResponse, *ModelResponseChunk] // result is a generic struct for parallel operation results with index, value, and error. type result[T any] struct { @@ -198,7 +198,7 @@ func DefineModel(r api.Registry, name string, opts *ModelOptions, fn ModelFunc) // It will try to resolve the model dynamically if the model is not found. // It returns nil if the model was not resolved. func LookupModel(r api.Registry, name string) Model { - action := core.ResolveActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk, struct{}](r, api.ActionTypeModel, name) + action := core.ResolveActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk](r, api.ActionTypeModel, name) if action == nil { return nil } diff --git a/go/ai/prompt.go b/go/ai/prompt.go index c1928db77d..a73cb4a7e7 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -52,7 +52,7 @@ type Prompt interface { // prompt is a prompt template that can be executed to generate a model response. type prompt struct { - core.Action[any, *GenerateActionOptions, struct{}, struct{}] + core.Action[any, *GenerateActionOptions, struct{}] promptOptions registry api.Registry } @@ -141,7 +141,7 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { // LookupPrompt looks up a [Prompt] registered by [DefinePrompt]. // It returns nil if the prompt was not defined. func LookupPrompt(r api.Registry, name string) Prompt { - action := core.ResolveActionFor[any, *GenerateActionOptions, struct{}, struct{}](r, api.ActionTypeExecutablePrompt, name) + action := core.ResolveActionFor[any, *GenerateActionOptions, struct{}](r, api.ActionTypeExecutablePrompt, name) if action == nil { return nil } diff --git a/go/ai/resource.go b/go/ai/resource.go index 18e1823019..101ac68aa4 100644 --- a/go/ai/resource.go +++ b/go/ai/resource.go @@ -109,7 +109,7 @@ type ResourceFunc = func(context.Context, *ResourceInput) (*ResourceOutput, erro // It holds the underlying core action and allows looking up resources // by name without knowing their specific input/output api. type resource struct { - core.Action[*ResourceInput, *ResourceOutput, struct{}, struct{}] + core.Action[*ResourceInput, *ResourceOutput, struct{}] } // Resource represents an instance of a resource. @@ -227,7 +227,7 @@ func (r *resource) Execute(ctx context.Context, input *ResourceInput) (*Resource // FindMatchingResource finds a resource that matches the given URI. func FindMatchingResource(r api.Registry, uri string) (Resource, *ResourceInput, error) { for _, a := range r.ListActions() { - if action, ok := a.(*core.Action[*ResourceInput, *ResourceOutput, struct{}, struct{}]); ok { + if action, ok := a.(*core.Action[*ResourceInput, *ResourceOutput, struct{}]); ok { res := &resource{Action: *action} if res.Matches(uri) { variables, err := res.ExtractVariables(uri) @@ -244,7 +244,7 @@ func FindMatchingResource(r api.Registry, uri string) (Resource, *ResourceInput, // LookupResource looks up the resource in the registry by provided name and returns it. func LookupResource(r api.Registry, name string) Resource { - action := core.ResolveActionFor[*ResourceInput, *ResourceOutput, struct{}, struct{}](r, api.ActionTypeResource, name) + action := core.ResolveActionFor[*ResourceInput, *ResourceOutput, struct{}](r, api.ActionTypeResource, name) if action == nil { return nil } diff --git a/go/ai/retriever.go b/go/ai/retriever.go index 9ab97f17ce..392fbc41c2 100644 --- a/go/ai/retriever.go +++ b/go/ai/retriever.go @@ -40,7 +40,7 @@ type Retriever interface { // retriever is an action with functions specific to document retrieval such as Retrieve(). type retriever struct { - core.Action[*RetrieverRequest, *RetrieverResponse, struct{}, struct{}] + core.Action[*RetrieverRequest, *RetrieverResponse, struct{}] } // RetrieverArg is the interface for retriever arguments. It can either be the retriever action itself or a reference to be looked up. @@ -136,7 +136,7 @@ func DefineRetriever(r api.Registry, name string, opts *RetrieverOptions, fn Ret // It will try to resolve the retriever dynamically if the retriever is not found. // It returns nil if the retriever was not resolved. func LookupRetriever(r api.Registry, name string) Retriever { - action := core.ResolveActionFor[*RetrieverRequest, *RetrieverResponse, struct{}, struct{}](r, api.ActionTypeRetriever, name) + action := core.ResolveActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](r, api.ActionTypeRetriever, name) if action == nil { return nil } diff --git a/go/core/action.go b/go/core/action.go index 576b78b02a..f6ca53e8ed 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -19,9 +19,7 @@ package core import ( "context" "encoding/json" - "iter" "reflect" - "sync" "time" "github.com/firebase/genkit/go/core/api" @@ -34,22 +32,16 @@ import ( // Func is an alias for non-streaming functions with input of type In and output of type Out. type Func[In, Out any] = func(context.Context, In) (Out, error) -// StreamingFunc is an alias for streaming functions with input of type In, output of type Out, and outgoing stream chunk of type StreamOut. -type StreamingFunc[In, Out, StreamOut any] = func(context.Context, In, StreamCallback[StreamOut]) (Out, error) +// StreamingFunc is an alias for streaming functions with input of type In, output of type Out, and outgoing stream chunk of type Stream. +type StreamingFunc[In, Out, Stream any] = func(context.Context, In, StreamCallback[Stream]) (Out, error) // StreamCallback is a function that is called during streaming to return the next chunk of the outgoing stream. -type StreamCallback[StreamOut any] = func(context.Context, StreamOut) error - -// BidiFunc is the function signature for bidirectional streaming actions. -// It receives an initial input, reads incoming stream messages from inCh, -// and writes outgoing stream messages to outCh. It returns a final output when complete. -type BidiFunc[In, Out, StreamOut, StreamIn any] = func(ctx context.Context, in In, inCh <-chan StreamIn, outCh chan<- StreamOut) (Out, error) +type StreamCallback[Stream any] = func(context.Context, Stream) error // An Action is a named, observable operation that underlies all Genkit primitives. // It consists of a function that takes an input of type In and returns an output -// of type Out, optionally streaming values of type StreamOut incrementally by -// invoking a callback. For bidirectional actions, StreamIn is the type of -// incoming stream messages. +// of type Out, optionally streaming values of type Stream incrementally by +// invoking a callback. // // It optionally has other metadata, like a description and JSON Schemas for its input and // output which it validates against. @@ -57,13 +49,17 @@ type BidiFunc[In, Out, StreamOut, StreamIn any] = func(ctx context.Context, in I // Each time an Action is run, it results in a new trace span. // // For internal use only. -type Action[In, Out, StreamOut, StreamIn any] struct { - fn StreamingFunc[In, Out, StreamOut] // Function that is called during runtime. May not actually support streaming. - bidiFn BidiFunc[In, Out, StreamOut, StreamIn] // Non-nil for bidi actions only. - desc *api.ActionDesc // Descriptor of the action. - registry api.Registry // Registry for schema resolution. Set when registered. +type Action[In, Out, Stream any] struct { + fn StreamingFunc[In, Out, Stream] // Function that is called during runtime. May not actually support streaming. + desc *api.ActionDesc // Descriptor of the action. + registry api.Registry // Registry for schema resolution. Set when registered. } +// ActionDef is the previous name for [Action]. +// +// Deprecated: use [Action]. +type ActionDef[In, Out, Stream any] = Action[In, Out, Stream] + type noStream = func(context.Context, struct{}) error // NewAction creates a new non-streaming [Action] without registering it. @@ -74,8 +70,8 @@ func NewAction[In, Out any]( metadata map[string]any, inputSchema map[string]any, fn Func[In, Out], -) *Action[In, Out, struct{}, struct{}] { - return newAction[In, Out, struct{}, struct{}](name, atype, metadata, inputSchema, +) *Action[In, Out, struct{}] { + return newStreamingAction(name, atype, metadata, inputSchema, func(ctx context.Context, in In, cb noStream) (Out, error) { return fn(ctx, in) }) @@ -83,75 +79,14 @@ func NewAction[In, Out any]( // NewStreamingAction creates a new streaming [Action] without registering it. // If inputSchema is nil, it is inferred from the function's input api. -func NewStreamingAction[In, Out, StreamOut any]( +func NewStreamingAction[In, Out, Stream any]( name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, - fn StreamingFunc[In, Out, StreamOut], -) *Action[In, Out, StreamOut, struct{}] { - return newAction[In, Out, StreamOut, struct{}](name, atype, metadata, inputSchema, fn) -} - -// ActionOptions configures a bidi action. Nil schema fields are inferred from type parameters. -type ActionOptions struct { - Metadata map[string]any // Arbitrary key-value data attached to the action descriptor. - InputSchema map[string]any // JSON schema for the action's input. Inferred from In if nil. - OutputSchema map[string]any // JSON schema for the action's output. Inferred from Out if nil. - StreamOutSchema map[string]any // JSON schema for outgoing streamed chunks. Inferred from StreamOut if nil. Not used for non-streaming actions. - StreamInSchema map[string]any // JSON schema for incoming stream messages. Inferred from StreamIn if nil. Not used for non-bidi actions. -} - -// NewBidiAction creates a new bidirectional streaming [Action] without registering it. -func NewBidiAction[In, Out, StreamOut, StreamIn any]( - name string, - atype api.ActionType, - opts *ActionOptions, - fn BidiFunc[In, Out, StreamOut, StreamIn], -) *Action[In, Out, StreamOut, StreamIn] { - if opts == nil { - opts = &ActionOptions{} - } - - metadata := opts.Metadata - if metadata == nil { - metadata = map[string]any{} - } - metadata["bidi"] = true - - a := newAction[In, Out, StreamOut, StreamIn](name, atype, metadata, opts.InputSchema, wrapBidiAsStreaming(fn)) - a.bidiFn = fn - - if opts.OutputSchema != nil { - a.desc.OutputSchema = opts.OutputSchema - } - if opts.StreamOutSchema != nil { - a.desc.StreamSchema = opts.StreamOutSchema - } - - if opts.StreamInSchema != nil { - a.desc.InitSchema = opts.StreamInSchema - } else { - var inStream StreamIn - if reflect.ValueOf(inStream).Kind() != reflect.Invalid { - a.desc.InitSchema = InferSchemaMap(inStream) - } - } - - return a -} - -// DefineBidiAction creates and registers a bidirectional streaming [Action]. -func DefineBidiAction[In, Out, StreamOut, StreamIn any]( - r api.Registry, - name string, - atype api.ActionType, - opts *ActionOptions, - fn BidiFunc[In, Out, StreamOut, StreamIn], -) *Action[In, Out, StreamOut, StreamIn] { - a := NewBidiAction(name, atype, opts, fn) - a.Register(r) - return a + fn StreamingFunc[In, Out, Stream], +) *Action[In, Out, Stream] { + return newStreamingAction(name, atype, metadata, inputSchema, fn) } // DefineAction creates a new non-streaming Action and registers it. @@ -163,8 +98,8 @@ func DefineAction[In, Out any]( metadata map[string]any, inputSchema map[string]any, fn Func[In, Out], -) *Action[In, Out, struct{}, struct{}] { - return defineAction[In, Out, struct{}, struct{}](r, name, atype, metadata, inputSchema, +) *Action[In, Out, struct{}] { + return defineStreamingAction(r, name, atype, metadata, inputSchema, func(ctx context.Context, in In, cb noStream) (Out, error) { return fn(ctx, in) }) @@ -172,58 +107,65 @@ func DefineAction[In, Out any]( // DefineStreamingAction creates a new streaming action and registers it. // If inputSchema is nil, it is inferred from the function's input api. -func DefineStreamingAction[In, Out, StreamOut any]( +func DefineStreamingAction[In, Out, Stream any]( r api.Registry, name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, - fn StreamingFunc[In, Out, StreamOut], -) *Action[In, Out, StreamOut, struct{}] { - return defineAction[In, Out, StreamOut, struct{}](r, name, atype, metadata, inputSchema, fn) + fn StreamingFunc[In, Out, Stream], +) *Action[In, Out, Stream] { + return defineStreamingAction(r, name, atype, metadata, inputSchema, fn) } -// defineAction creates an action and registers it with the given Registry. -func defineAction[In, Out, StreamOut, StreamIn any]( +// defineStreamingAction creates a streaming action and registers it. +func defineStreamingAction[In, Out, Stream any]( r api.Registry, name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, - fn StreamingFunc[In, Out, StreamOut], -) *Action[In, Out, StreamOut, StreamIn] { - a := newAction[In, Out, StreamOut, StreamIn](name, atype, metadata, inputSchema, fn) + fn StreamingFunc[In, Out, Stream], +) *Action[In, Out, Stream] { + a := newStreamingAction(name, atype, metadata, inputSchema, fn) a.Register(r) return a } -// newAction creates a new Action with the given name and arguments. -// If registry is nil, tracing state is left nil to be set later. -// If inputSchema is nil, it is inferred from In. -func newAction[In, Out, StreamOut, StreamIn any]( +// newStreamingAction constructs an action with the given implementation. +// It is the common helper for NewAction, NewStreamingAction, and +// DefineStreamingAction. +func newStreamingAction[In, Out, Stream any]( + name string, + atype api.ActionType, + metadata map[string]any, + inputSchema map[string]any, + fn StreamingFunc[In, Out, Stream], +) *Action[In, Out, Stream] { + a := newAction[In, Out, Stream](name, atype, metadata, inputSchema) + a.fn = fn + return a +} + +// newAction populates an Action's descriptor with inferred schemas and metadata. +// The caller is expected to assign a.fn. +func newAction[In, Out, Stream any]( name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, - fn StreamingFunc[In, Out, StreamOut], -) *Action[In, Out, StreamOut, StreamIn] { +) *Action[In, Out, Stream] { if inputSchema == nil { - var i In - if reflect.ValueOf(i).Kind() != reflect.Invalid { - inputSchema = InferSchemaMap(i) - } + inputSchema = inferSchema[In]() } - var o Out - var outputSchema map[string]any - if reflect.ValueOf(o).Kind() != reflect.Invalid { - outputSchema = InferSchemaMap(o) - } + outputSchema := inferSchema[Out]() - var s StreamOut - var outStreamSchema map[string]any - if reflect.ValueOf(s).Kind() != reflect.Invalid { - outStreamSchema = InferSchemaMap(s) + // Stream is struct{} for non-streaming actions; inferring a schema from + // the sentinel would make every action advertise a bogus streamSchema. + var streamSchema map[string]any + if !isUnitType[Stream]() { + streamSchema = inferSchema[Stream]() } var description string @@ -231,10 +173,7 @@ func newAction[In, Out, StreamOut, StreamIn any]( description = desc } - return &Action[In, Out, StreamOut, StreamIn]{ - fn: func(ctx context.Context, input In, cb StreamCallback[StreamOut]) (Out, error) { - return fn(ctx, input, cb) - }, + return &Action[In, Out, Stream]{ desc: &api.ActionDesc{ Type: atype, Key: api.KeyFromName(atype, name), @@ -242,26 +181,47 @@ func newAction[In, Out, StreamOut, StreamIn any]( Description: description, InputSchema: inputSchema, OutputSchema: outputSchema, - StreamSchema: outStreamSchema, + StreamSchema: streamSchema, Metadata: metadata, }, } } +// isUnitType reports whether T is exactly struct{}, the sentinel type +// parameter meaning "no value" (no stream chunks, no init). Named empty +// struct types do not match and are treated as real types. +func isUnitType[T any]() bool { + return reflect.TypeFor[T]() == reflect.TypeFor[struct{}]() +} + +// inferSchema returns the JSON schema inferred from T's zero value, or nil +// for interface types, whose zero value carries no type information to infer +// from. +func inferSchema[T any]() map[string]any { + var v T + if reflect.ValueOf(v).Kind() == reflect.Invalid { + return nil + } + return InferSchemaMap(v) +} + // Name returns the Action's Name. -func (a *Action[In, Out, StreamOut, StreamIn]) Name() string { return a.desc.Name } +func (a *Action[In, Out, Stream]) Name() string { return a.desc.Name } // Run executes the Action's function in a new trace span. -func (a *Action[In, Out, StreamOut, StreamIn]) Run(ctx context.Context, input In, cb StreamCallback[StreamOut]) (output Out, err error) { - r, err := a.runWithTelemetry(ctx, input, cb) +func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb StreamCallback[Stream]) (output Out, err error) { + r, err := a.runWithTelemetry(ctx, input, cb, a.fn, nil) if err != nil { return base.Zero[Out](), err } return r.Result, nil } -// runWithTelemetry executes the Action's function in a new trace span and returns telemetry info. -func (a *Action[In, Out, StreamOut, StreamIn]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[StreamOut]) (output api.ActionRunResult[Out], err error) { +// runWithTelemetry executes fn in a new trace span and returns telemetry +// info. fn is a parameter (rather than always a.fn) so that BidiAction can +// inject a per-call one-shot adapter; spanInit, when non-nil, is recorded as +// the span's genkit:init attribute. +func (a *Action[In, Out, Stream]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream], fn StreamingFunc[In, Out, Stream], spanInit any) (output api.ActionRunResult[Out], err error) { logger.FromContext(ctx).Debug("Action.Run", "name", a.Name()) defer func() { logger.FromContext(ctx).Debug("Action.Run", @@ -269,37 +229,16 @@ func (a *Action[In, Out, StreamOut, StreamIn]) runWithTelemetry(ctx context.Cont "err", err) }() - // Create span metadata and inject flow name if we're in a flow context - spanMetadata := &tracing.SpanMetadata{ - Name: a.desc.Name, - Type: "action", - Subtype: string(a.desc.Type), // The actual action type becomes the subtype - Metadata: make(map[string]string), - // IsRoot will be automatically determined in tracing.go based on parent span presence - } - - // Auto-inject flow name if we're in a flow context - if flowName := FlowNameFromContext(ctx); flowName != "" { - spanMetadata.Metadata["flow:name"] = flowName - } - var traceID string var spanID string - o, err := tracing.RunInNewSpan(ctx, spanMetadata, input, + o, err := tracing.RunInNewSpan(ctx, a.spanMetadata(ctx, spanInit), input, func(ctx context.Context, input In) (output Out, err error) { traceInfo := tracing.SpanTraceInfo(ctx) traceID = traceInfo.TraceID spanID = traceInfo.SpanID start := time.Now() - defer func() { - latency := time.Since(start) - if err != nil { - metrics.WriteActionFailure(ctx, a.desc.Name, latency, err) - } else { - metrics.WriteActionSuccess(ctx, a.desc.Name, latency) - } - }() + defer func() { recordActionMetrics(ctx, a.desc.Name, start, err) }() var inputSchema map[string]any inputSchema, err = ResolveSchema(a.registry, a.desc.InputSchema) @@ -308,24 +247,20 @@ func (a *Action[In, Out, StreamOut, StreamIn]) runWithTelemetry(ctx context.Cont } var outputSchema map[string]any - outputSchema, err = ResolveSchema(a.registry, a.desc.OutputSchema) + outputSchema, err = a.resolveOutputSchema() if err != nil { - return base.Zero[Out](), NewError(INVALID_ARGUMENT, "invalid output schema for action %q: %v", a.desc.Key, err) + return base.Zero[Out](), err } if err = base.ValidateValue(input, inputSchema); err != nil { return base.Zero[Out](), NewError(INVALID_ARGUMENT, "invalid input to action %q: %v", a.desc.Key, err) } - output, err = a.fn(ctx, input, cb) + output, err = fn(ctx, input, cb) if err != nil { return output, err } - if err = base.ValidateValue(output, outputSchema); err != nil { - err = NewError(INTERNAL, "invalid output from action %q: %v", a.desc.Key, err) - } - - return output, err + return output, a.validateOutput(output, outputSchema) }, ) @@ -336,8 +271,55 @@ func (a *Action[In, Out, StreamOut, StreamIn]) runWithTelemetry(ctx context.Cont }, err } +// spanMetadata builds the trace span metadata for one run of this action, +// injecting the flow name when ctx carries one. spanInit, when non-nil, is +// recorded as the span's genkit:init attribute. IsRoot is determined later by +// the tracing package from parent span presence. +func (a *Action[In, Out, Stream]) spanMetadata(ctx context.Context, spanInit any) *tracing.SpanMetadata { + sm := &tracing.SpanMetadata{ + Name: a.desc.Name, + Type: "action", + Subtype: string(a.desc.Type), // The actual action type becomes the subtype. + Metadata: make(map[string]string), + Init: spanInit, + } + if flowName := FlowNameFromContext(ctx); flowName != "" { + sm.Metadata["flow:name"] = flowName + } + return sm +} + +// recordActionMetrics writes the success/failure metric for one action run. +func recordActionMetrics(ctx context.Context, name string, start time.Time, err error) { + latency := time.Since(start) + if err != nil { + metrics.WriteActionFailure(ctx, name, latency, err) + } else { + metrics.WriteActionSuccess(ctx, name, latency) + } +} + +// resolveOutputSchema resolves the action's OutputSchema $refs through the +// registry. +func (a *Action[In, Out, Stream]) resolveOutputSchema() (map[string]any, error) { + schema, err := ResolveSchema(a.registry, a.desc.OutputSchema) + if err != nil { + return nil, NewError(INVALID_ARGUMENT, "invalid output schema for action %q: %v", a.desc.Key, err) + } + return schema, nil +} + +// validateOutput checks a final output value against the resolved output +// schema. +func (a *Action[In, Out, Stream]) validateOutput(out Out, schema map[string]any) error { + if err := base.ValidateValue(out, schema); err != nil { + return NewError(INTERNAL, "invalid output from action %q: %v", a.desc.Key, err) + } + return nil +} + // RunJSON runs the action with a JSON input, and returns a JSON result. -func (a *Action[In, Out, StreamOut, StreamIn]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { +func (a *Action[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { r, err := a.RunJSONWithTelemetry(ctx, input, cb) if err != nil { return nil, err @@ -346,15 +328,21 @@ func (a *Action[In, Out, StreamOut, StreamIn]) RunJSON(ctx context.Context, inpu } // RunJSONWithTelemetry runs the action with a JSON input, and returns a JSON result along with telemetry info. -func (a *Action[In, Out, StreamOut, StreamIn]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { +func (a *Action[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { + return a.runJSONWithTelemetry(ctx, input, cb, a.fn, nil) +} + +// runJSONWithTelemetry is the shared JSON execution path. fn and spanInit +// follow the same contract as runWithTelemetry. +func (a *Action[In, Out, Stream]) runJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage], fn StreamingFunc[In, Out, Stream], spanInit any) (*api.ActionRunResult[json.RawMessage], error) { i, err := base.UnmarshalAndNormalize[In](input, a.desc.InputSchema) if err != nil { return nil, NewError(INVALID_ARGUMENT, err.Error()) } - var scb StreamCallback[StreamOut] + var scb StreamCallback[Stream] if cb != nil { - scb = func(ctx context.Context, s StreamOut) error { + scb = func(ctx context.Context, s Stream) error { bytes, err := json.Marshal(s) if err != nil { return err @@ -363,7 +351,7 @@ func (a *Action[In, Out, StreamOut, StreamIn]) RunJSONWithTelemetry(ctx context. } } - r, err := a.runWithTelemetry(ctx, i, scb) + r, err := a.runWithTelemetry(ctx, i, scb, fn, spanInit) if err != nil { return &api.ActionRunResult[json.RawMessage]{ TraceId: r.TraceId, @@ -386,88 +374,37 @@ func (a *Action[In, Out, StreamOut, StreamIn]) RunJSONWithTelemetry(ctx context. // Desc returns a descriptor of the action with resolved schema references. // Schema references that cannot be resolved (e.g., the action is not yet registered, // or the referenced schema has not been defined) are returned as-is. -func (a *Action[In, Out, StreamOut, StreamIn]) Desc() api.ActionDesc { +func (a *Action[In, Out, Stream]) Desc() api.ActionDesc { desc := *a.desc - if a.registry != nil { - if resolved, err := ResolveSchema(a.registry, desc.InputSchema); err == nil { - desc.InputSchema = resolved - } - if resolved, err := ResolveSchema(a.registry, desc.OutputSchema); err == nil { - desc.OutputSchema = resolved + if a.registry == nil { + return desc + } + for _, p := range []*map[string]any{&desc.InputSchema, &desc.OutputSchema, &desc.StreamSchema, &desc.InitSchema} { + if resolved, err := ResolveSchema(a.registry, *p); err == nil { + *p = resolved } } return desc } // Register registers the action with the given registry. -func (a *Action[In, Out, StreamOut, StreamIn]) Register(r api.Registry) { +func (a *Action[In, Out, Stream]) Register(r api.Registry) { a.registry = r r.RegisterAction(a.desc.Key, a) } -// StreamBidi starts a bidirectional streaming connection. -// Returns an error if the action is not a bidi action. -// A trace span is created that remains open for the lifetime of the connection. -func (a *Action[In, Out, StreamOut, StreamIn]) StreamBidi(ctx context.Context, in In) (*BidiConnection[StreamIn, Out, StreamOut], error) { - if a.bidiFn == nil { - return nil, NewError(FAILED_PRECONDITION, "StreamBidi called on non-bidi action %q", a.desc.Name) - } - - ctx, cancel := context.WithCancel(ctx) - conn := &BidiConnection[StreamIn, Out, StreamOut]{ - inputCh: make(chan StreamIn, 1), - streamCh: make(chan StreamOut, 1), - doneCh: make(chan struct{}), - ctx: ctx, - cancel: cancel, - } - - spanMetadata := &tracing.SpanMetadata{ - Name: a.desc.Name, - Type: "action", - Subtype: string(a.desc.Type), - Metadata: make(map[string]string), - } - if flowName := FlowNameFromContext(ctx); flowName != "" { - spanMetadata.Metadata["flow:name"] = flowName - } - - go func() { - defer close(conn.doneCh) - defer close(conn.streamCh) - output, err := tracing.RunInNewSpan(conn.ctx, spanMetadata, in, - func(ctx context.Context, in In) (Out, error) { - start := time.Now() - output, err := a.bidiFn(ctx, in, conn.inputCh, conn.streamCh) - latency := time.Since(start) - if err != nil { - metrics.WriteActionFailure(ctx, a.desc.Name, latency, err) - } else { - metrics.WriteActionSuccess(ctx, a.desc.Name, latency) - } - return output, err - }, - ) - conn.mu.Lock() - conn.output = output - conn.err = err - conn.mu.Unlock() - }() - - return conn, nil -} - // ResolveActionFor returns the action for the given key in the global registry, // or nil if there is none. -// It panics if the action is of the wrong api. -func ResolveActionFor[In, Out, StreamOut, StreamIn any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, StreamOut, StreamIn] { +// It panics if the action is of the wrong type. That includes bidi actions, +// which are a distinct type; resolve those via [ResolveBidiActionFor]. +func ResolveActionFor[In, Out, Stream any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, Stream] { provider, id := api.ParseName(name) key := api.NewKey(atype, provider, id) a := r.ResolveAction(key) if a == nil { return nil } - return a.(*Action[In, Out, StreamOut, StreamIn]) + return a.(*Action[In, Out, Stream]) } // LookupActionFor returns the action for the given key in the global registry, @@ -475,138 +412,14 @@ func ResolveActionFor[In, Out, StreamOut, StreamIn any](r api.Registry, atype ap // It panics if the action is of the wrong api. // // Deprecated: Use ResolveActionFor. -func LookupActionFor[In, Out, StreamOut, StreamIn any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, StreamOut, StreamIn] { +func LookupActionFor[In, Out, Stream any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, Stream] { provider, id := api.ParseName(name) key := api.NewKey(atype, provider, id) a := r.LookupAction(key) if a == nil { return nil } - return a.(*Action[In, Out, StreamOut, StreamIn]) -} - -// wrapBidiAsStreaming wraps a BidiFunc into a StreamingFunc for use with Run/RunJSON. -// The input is passed as the initial input to the bidi func, and the input stream -// channel is closed immediately (no streaming inputs). Outgoing stream chunks are -// forwarded to the callback. -func wrapBidiAsStreaming[In, Out, StreamOut, StreamIn any](fn BidiFunc[In, Out, StreamOut, StreamIn]) StreamingFunc[In, Out, StreamOut] { - return func(ctx context.Context, input In, cb StreamCallback[StreamOut]) (Out, error) { - inCh := make(chan StreamIn, 1) - outCh := make(chan StreamOut, 1) - doneCh := make(chan struct{}) - - var output Out - var fnErr error - - go func() { - defer close(doneCh) - defer close(outCh) - output, fnErr = fn(ctx, input, inCh, outCh) - }() - - // No streaming inputs when used as a non-bidi streaming action. - close(inCh) - - // Forward streamed chunks to the callback. - if cb != nil { - for chunk := range outCh { - if err := cb(ctx, chunk); err != nil { - return base.Zero[Out](), err - } - } - } else { - // Drain the channel even without a callback. - for range outCh { - } - } - - <-doneCh - return output, fnErr - } -} - -// BidiConnection represents an active bidirectional streaming session. -type BidiConnection[StreamIn, Out, StreamOut any] struct { - inputCh chan StreamIn - streamCh chan StreamOut - doneCh chan struct{} - output Out - err error - ctx context.Context - cancel context.CancelFunc - mu sync.Mutex - closed bool -} - -// Send sends an input message to the bidi action. -// Returns an error if the connection is closed or the context is cancelled. -func (c *BidiConnection[StreamIn, Out, StreamOut]) Send(input StreamIn) (err error) { - defer func() { - if r := recover(); r != nil { - err = NewError(FAILED_PRECONDITION, "connection is closed") - } - }() - - select { - case c.inputCh <- input: - return nil - case <-c.ctx.Done(): - return c.ctx.Err() - case <-c.doneCh: - return NewError(FAILED_PRECONDITION, "action has completed") - } -} - -// Close signals that no more inputs will be sent. -func (c *BidiConnection[StreamIn, Out, StreamOut]) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return nil - } - c.closed = true - close(c.inputCh) - return nil -} - -// Receive returns an iterator for receiving streamed response chunks. -// The iterator completes when the action finishes. -func (c *BidiConnection[StreamIn, Out, StreamOut]) Receive() iter.Seq2[StreamOut, error] { - return func(yield func(StreamOut, error) bool) { - for { - select { - case chunk, ok := <-c.streamCh: - if !ok { - return - } - if !yield(chunk, nil) { - c.cancel() - return - } - case <-c.ctx.Done(): - var zero StreamOut - yield(zero, c.ctx.Err()) - return - } - } - } + return a.(*Action[In, Out, Stream]) } -// Output returns the final output after the action completes. -// Blocks until done or context cancelled. -func (c *BidiConnection[StreamIn, Out, StreamOut]) Output() (Out, error) { - select { - case <-c.doneCh: - c.mu.Lock() - defer c.mu.Unlock() - return c.output, c.err - case <-c.ctx.Done(): - var zero Out - return zero, c.ctx.Err() - } -} - -// Done returns a channel that is closed when the connection completes. -func (c *BidiConnection[StreamIn, Out, StreamOut]) Done() <-chan struct{} { - return c.doneCh -} +var _ api.Action = (*Action[struct{}, struct{}, struct{}])(nil) diff --git a/go/core/action_test.go b/go/core/action_test.go index 8f9fc13605..9508a4d431 100644 --- a/go/core/action_test.go +++ b/go/core/action_test.go @@ -367,7 +367,7 @@ func TestResolveActionFor(t *testing.T) { } DefineAction(r, "test/resolvable", api.ActionTypeCustom, nil, nil, fn) - found := ResolveActionFor[int, int, struct{}, struct{}](r, api.ActionTypeCustom, "test/resolvable") + found := ResolveActionFor[int, int, struct{}](r, api.ActionTypeCustom, "test/resolvable") if found == nil { t.Fatal("ResolveActionFor returned nil") @@ -380,7 +380,7 @@ func TestResolveActionFor(t *testing.T) { t.Run("returns nil for non-existent action", func(t *testing.T) { r := registry.New() - found := ResolveActionFor[int, int, struct{}, struct{}](r, api.ActionTypeCustom, "test/nonexistent") + found := ResolveActionFor[int, int, struct{}](r, api.ActionTypeCustom, "test/nonexistent") if found != nil { t.Errorf("ResolveActionFor returned %v, want nil", found) @@ -396,7 +396,7 @@ func TestLookupActionFor(t *testing.T) { } DefineAction(r, "test/lookupable", api.ActionTypeCustom, nil, nil, fn) - found := LookupActionFor[string, string, struct{}, struct{}](r, api.ActionTypeCustom, "test/lookupable") + found := LookupActionFor[string, string, struct{}](r, api.ActionTypeCustom, "test/lookupable") if found == nil { t.Fatal("LookupActionFor returned nil") @@ -406,7 +406,7 @@ func TestLookupActionFor(t *testing.T) { t.Run("returns nil for non-existent action", func(t *testing.T) { r := registry.New() - found := LookupActionFor[string, string, struct{}, struct{}](r, api.ActionTypeCustom, "test/missing") + found := LookupActionFor[string, string, struct{}](r, api.ActionTypeCustom, "test/missing") if found != nil { t.Errorf("LookupActionFor returned %v, want nil", found) diff --git a/go/core/api/action.go b/go/core/api/action.go index 91ee2a646d..5818108a30 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -19,6 +19,7 @@ package api import ( "context" "encoding/json" + "iter" ) type ActionRunResult[T any] struct { @@ -40,6 +41,58 @@ type Action interface { Desc() ActionDesc } +// BidiSessionOptions configures a bidirectional session started through the +// JSON interfaces. A nil value is equivalent to zero options. The struct may +// gain fields over time; construct it by field name. +// +// Experimental: bidirectional streaming is experimental and subject to change. +type BidiSessionOptions struct { + // Init is the JSON-encoded initial configuration for the session, + // decoded into the action's Init type and validated against its + // InitSchema. Empty or JSON-null means no init (the zero Init value). + Init json.RawMessage +} + +// BidiAction is implemented by actions that support bidirectional streaming. +// Non-bidi actions do not implement this interface; callers may detect bidi +// support with a type assertion. The descriptor's "bidi" metadata carries the +// same signal for tooling that only sees serialized descriptors. +// +// Experimental: bidirectional streaming is experimental and subject to change. +type BidiAction interface { + Action + // RunBidiJSON runs the bidi action as a single one-shot call: input is + // delivered as the only chunk on the input stream, outgoing chunks are + // forwarded to cb, and opts carries the session init. + RunBidiJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error, opts *BidiSessionOptions) (*ActionRunResult[json.RawMessage], error) + // StreamBidiJSON starts a bidirectional streaming session using + // JSON-encoded messages. + StreamBidiJSON(ctx context.Context, opts *BidiSessionOptions) (BidiJSONConnection, error) +} + +// BidiJSONConnection is a JSON-encoded view of an active bidirectional +// streaming session. It mirrors the typed BidiConnection API but works in +// terms of json.RawMessage payloads, allowing generic transports (e.g. the +// reflection API) to wire bidi actions without knowing their concrete types. +// +// Experimental: bidirectional streaming is experimental and subject to change. +type BidiJSONConnection interface { + // Send encodes chunk as the action's In type and sends it to the action. + // A chunk that fails to decode or validate against the action's input + // schema fails the session: the error is returned and also becomes the + // session's terminal error, reported by Output. Transports that need + // per-chunk tolerance must validate before calling Send. + Send(chunk json.RawMessage) error + // Close signals that no more inputs will be sent. + Close() error + // Receive yields outgoing stream chunks encoded as JSON. The iterator + // completes when the action finishes. + Receive() iter.Seq2[json.RawMessage, error] + // Output returns the final output encoded as JSON, blocking until the + // action completes or the context is cancelled. + Output() (json.RawMessage, error) +} + // Registerable allows a primitive to be registered with a registry. type Registerable interface { Register(r Registry) @@ -72,9 +125,9 @@ type ActionDesc struct { Key string `json:"key"` // Key of the action. Name string `json:"name"` // Name of the action. Description string `json:"description"` // Description of the action. - InputSchema map[string]any `json:"inputschema"` // JSON schema to validate against the action's input. - OutputSchema map[string]any `json:"outputschema"` // JSON schema to validate against the action's output. - StreamSchema map[string]any `json:"streamschema,omitempty"` // JSON schema to validate against the action's outgoing streamed chunks. - InitSchema map[string]any `json:"initschema,omitempty"` // JSON schema to validate against the action's incoming stream messages (bidi only). + InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. + OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. + StreamSchema map[string]any `json:"streamSchema,omitempty"` // JSON schema to validate against the action's outgoing streamed chunks. + InitSchema map[string]any `json:"initSchema,omitempty"` // JSON schema to validate against the action's initial configuration (bidi only). Metadata map[string]any `json:"metadata"` // Metadata for the action. } diff --git a/go/core/background_action.go b/go/core/background_action.go index 5d5c9fbe78..e3777f01ff 100644 --- a/go/core/background_action.go +++ b/go/core/background_action.go @@ -45,10 +45,10 @@ type Operation[Out any] struct { // // For internal use only. type BackgroundActionDef[In, Out any] struct { - *Action[In, *Operation[Out], struct{}, struct{}] + *Action[In, *Operation[Out], struct{}] - check *Action[*Operation[Out], *Operation[Out], struct{}, struct{}] // Sub-action that checks the status of a background operation. - cancel *Action[*Operation[Out], *Operation[Out], struct{}, struct{}] // Sub-action that cancels a background operation. + check *Action[*Operation[Out], *Operation[Out], struct{}] // Sub-action that checks the status of a background operation. + cancel *Action[*Operation[Out], *Operation[Out], struct{}] // Sub-action that cancels a background operation. } // Start starts a background operation. @@ -140,7 +140,7 @@ func NewBackgroundAction[In, Out any]( return updatedOp, nil }) - var cancelAction *Action[*Operation[Out], *Operation[Out], struct{}, struct{}] + var cancelAction *Action[*Operation[Out], *Operation[Out], struct{}] if cancelFn != nil { cancelAction = NewAction(name, api.ActionTypeCancelOperation, metadata, nil, func(ctx context.Context, op *Operation[Out]) (*Operation[Out], error) { @@ -165,17 +165,17 @@ func LookupBackgroundAction[In, Out any](r api.Registry, key string) *Background atype, provider, id := api.ParseKey(key) name := api.NewName(provider, id) - startAction := ResolveActionFor[In, *Operation[Out], struct{}, struct{}](r, atype, name) + startAction := ResolveActionFor[In, *Operation[Out], struct{}](r, atype, name) if startAction == nil { return nil } - checkAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}, struct{}](r, api.ActionTypeCheckOperation, name) + checkAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}](r, api.ActionTypeCheckOperation, name) if checkAction == nil { return nil } - cancelAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}, struct{}](r, api.ActionTypeCancelOperation, name) + cancelAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}](r, api.ActionTypeCancelOperation, name) return &BackgroundActionDef[In, Out]{ Action: startAction, diff --git a/go/core/bidi.go b/go/core/bidi.go new file mode 100644 index 0000000000..cad8fc3cc9 --- /dev/null +++ b/go/core/bidi.go @@ -0,0 +1,696 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "context" + "encoding/json" + "errors" + "iter" + "maps" + "sync" + "time" + + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/core/tracing" + "github.com/firebase/genkit/go/internal/base" +) + +// BidiFunc is the function signature for bidirectional streaming actions. +// It receives an initial configuration of type Init, reads incoming stream +// messages of type In from inCh, and writes outgoing stream messages of type +// Stream to outCh. It returns a final output of type Out when complete. +// +// The function must honor ctx cancellation: the framework signals shutdown +// (consumer error, invalid inbound chunk on the JSON transport, session +// cancellation) by cancelling ctx, and a function that ignores it blocks its +// session indefinitely. The framework owns closing outCh; the function must +// never close it. Writes to outCh apply backpressure: they block until the +// consumer reads earlier chunks. A panic in the function is recovered and +// reported as an INTERNAL error rather than crashing the process, since the +// function runs in a framework-owned goroutine. +// +// Experimental: bidirectional streaming is experimental and subject to change. +type BidiFunc[In, Out, Stream, Init any] = func(ctx context.Context, init Init, inCh <-chan In, outCh chan<- Stream) (Out, error) + +// A BidiAction is a named, observable bidirectional streaming operation. It +// receives an initial configuration of type Init when a session starts, then +// consumes a stream of In messages while producing a stream of Stream chunks, +// and finishes with a final output of type Out. +// +// BidiAction embeds [Action], so it can also be invoked through the regular +// unary surface (Run, RunJSON): the input is delivered as a single chunk on +// the input stream with the zero Init value. Use [BidiAction.RunWithInit] or +// [BidiAction.RunBidiJSON] for one-shot calls that supply init. +// +// For internal use only. +// +// Experimental: bidirectional streaming is experimental and subject to change. +type BidiAction[In, Out, Stream, Init any] struct { + *Action[In, Out, Stream] + bidiFn BidiFunc[In, Out, Stream, Init] +} + +// BidiActionOptions configures a bidi action. Nil schema fields are inferred +// from the corresponding type parameters. +// +// Experimental: bidirectional streaming is experimental and subject to change. +type BidiActionOptions struct { + Metadata map[string]any // Arbitrary key-value data attached to the action descriptor. + InputSchema map[string]any // JSON schema for messages streamed into the action. Inferred from In if nil. + OutputSchema map[string]any // JSON schema for the action's final output. Inferred from Out if nil. + StreamSchema map[string]any // JSON schema for outgoing streamed chunks. Inferred from Stream if nil. + InitSchema map[string]any // JSON schema for the session's initial configuration. Inferred from Init if nil. +} + +// NewBidiAction creates a new bidirectional streaming [BidiAction] without registering it. +// +// Experimental: bidirectional streaming is experimental and subject to change. +func NewBidiAction[In, Out, Stream, Init any]( + name string, + atype api.ActionType, + opts *BidiActionOptions, + fn BidiFunc[In, Out, Stream, Init], +) *BidiAction[In, Out, Stream, Init] { + if opts == nil { + opts = &BidiActionOptions{} + } + + metadata := make(map[string]any, len(opts.Metadata)+1) + maps.Copy(metadata, opts.Metadata) + metadata["bidi"] = true + + b := &BidiAction[In, Out, Stream, Init]{ + Action: newAction[In, Out, Stream](name, atype, metadata, opts.InputSchema), + bidiFn: fn, + } + // The embedded action's fn backs the promoted unary surface (Run, + // RunJSON): a one-shot session with the zero Init value. + b.Action.fn = b.oneShotFn(base.Zero[Init]()) + + if opts.OutputSchema != nil { + b.desc.OutputSchema = opts.OutputSchema + } + if opts.StreamSchema != nil { + b.desc.StreamSchema = opts.StreamSchema + } + + if opts.InitSchema != nil { + b.desc.InitSchema = opts.InitSchema + } else if !isUnitType[Init]() { + b.desc.InitSchema = inferSchema[Init]() + } + + return b +} + +// DefineBidiAction creates and registers a bidirectional streaming [BidiAction]. +// +// Experimental: bidirectional streaming is experimental and subject to change. +func DefineBidiAction[In, Out, Stream, Init any]( + r api.Registry, + name string, + atype api.ActionType, + opts *BidiActionOptions, + fn BidiFunc[In, Out, Stream, Init], +) *BidiAction[In, Out, Stream, Init] { + b := NewBidiAction(name, atype, opts, fn) + b.Register(r) + return b +} + +// Register registers the bidi action with the given registry. It overrides +// the embedded Action's Register so that the registry holds the BidiAction +// itself; registry lookups must satisfy api.BidiAction. +func (b *BidiAction[In, Out, Stream, Init]) Register(r api.Registry) { + b.Action.registry = r + r.RegisterAction(b.desc.Key, b) +} + +// oneShotFn adapts the bidi function into a single streaming call with the +// given init: the call's input becomes the only chunk on the input stream and +// outgoing chunks are forwarded to cb. The call's span and metrics come from +// the unary path (runWithTelemetry), so unlike startBidi the connection runs +// the function bare. Init is validated inside the call so that validation +// failures are recorded on the action's trace span, like input validation +// failures. +func (b *BidiAction[In, Out, Stream, Init]) oneShotFn(init Init) StreamingFunc[In, Out, Stream] { + return func(ctx context.Context, input In, cb StreamCallback[Stream]) (Out, error) { + if err := b.validateInit(init); err != nil { + return base.Zero[Out](), err + } + + conn := newBidiConnection[In, Out, Stream](ctx) + // Released on every exit path, including a panicking callback below: + // an unwinding panic must not strand the function goroutine blocked + // on a stream write with no consumer. A cause recorded earlier (cb + // error) wins; the first cancel is sticky. + defer conn.cancel(nil) + go conn.run(b.desc.Name, func(ctx context.Context) (Out, error) { + return callBidiFn(ctx, b.desc.Name, b.bidiFn, init, conn.inputCh, conn.streamCh) + }) + + // inputCh is buffered, so delivering the single input cannot block. + conn.inputCh <- input + close(conn.inputCh) + + // Drain the stream until the function returns so its goroutine is + // never blocked on a stream write, even after a callback failure + // cancels the function's context. + var cbErr error + for chunk := range conn.streamCh { + if cb == nil || cbErr != nil { + continue + } + if err := cb(ctx, chunk); err != nil { + cbErr = err + conn.cancel(err) + } + } + <-conn.doneCh + if cbErr != nil { + return base.Zero[Out](), cbErr + } + conn.mu.Lock() + defer conn.mu.Unlock() + return conn.output, conn.err + } +} + +// spanInitValue returns the value to record as the span's genkit:init +// attribute, or nil when Init is the no-init sentinel. +func (b *BidiAction[In, Out, Stream, Init]) spanInitValue(init Init) any { + if isUnitType[Init]() { + return nil + } + return init +} + +// RunWithInit executes the bidi action as a single one-shot call with the +// given initial configuration: input is delivered as the only chunk on the +// input stream and outgoing chunks are forwarded to cb. Returns an error if +// init fails validation against the action's InitSchema. +// +// Experimental: bidirectional streaming is experimental and subject to change. +func (b *BidiAction[In, Out, Stream, Init]) RunWithInit(ctx context.Context, init Init, input In, cb StreamCallback[Stream]) (Out, error) { + r, err := b.Action.runWithTelemetry(ctx, input, cb, b.oneShotFn(init), b.spanInitValue(init)) + if err != nil { + return base.Zero[Out](), err + } + return r.Result, nil +} + +// RunBidiJSON runs the bidi action as a single one-shot call: input is +// delivered as the only chunk on the input stream, outgoing chunks are +// forwarded to cb, and opts carries the session init. Returns an error if +// init fails to decode or validate. +// +// Experimental: bidirectional streaming is experimental and subject to change. +func (b *BidiAction[In, Out, Stream, Init]) RunBidiJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage], opts *api.BidiSessionOptions) (*api.ActionRunResult[json.RawMessage], error) { + init, hasInit, err := b.decodeInit(opts) + if err != nil { + return nil, err + } + var spanInit any + if hasInit { + spanInit = init + } + return b.Action.runJSONWithTelemetry(ctx, input, cb, b.oneShotFn(init), spanInit) +} + +// StreamBidi starts a bidirectional streaming connection with the given +// initial configuration. For actions whose Init type is struct{} (no init), +// pass struct{}{}. Returns an error if init fails validation against the +// action's InitSchema. +// A trace span is created that remains open for the lifetime of the connection. +// +// Experimental: bidirectional streaming is experimental and subject to change. +func (b *BidiAction[In, Out, Stream, Init]) StreamBidi(ctx context.Context, init Init) (*BidiConnection[In, Out, Stream], error) { + if err := b.validateInit(init); err != nil { + return nil, err + } + return b.startBidi(ctx, init, b.spanInitValue(init)), nil +} + +// StreamBidiJSON starts a bidirectional streaming session using JSON-encoded +// messages. Returns an error if the init carried by opts fails to decode or +// validate. +// +// Experimental: bidirectional streaming is experimental and subject to change. +func (b *BidiAction[In, Out, Stream, Init]) StreamBidiJSON(ctx context.Context, opts *api.BidiSessionOptions) (api.BidiJSONConnection, error) { + init, hasInit, err := b.decodeInit(opts) + if err != nil { + return nil, err + } + if err := b.validateInit(init); err != nil { + return nil, err + } + inputSchema, err := ResolveSchema(b.registry, b.desc.InputSchema) + if err != nil { + return nil, NewError(INVALID_ARGUMENT, "invalid input schema for action %q: %v", b.desc.Key, err) + } + // Compiled once per session: Send validates every inbound chunk, and + // recompiling the schema per chunk would dominate the streaming hot path. + compiledInput, err := base.CompileSchema(inputSchema) + if err != nil { + return nil, NewError(INVALID_ARGUMENT, "invalid input schema for action %q: %v", b.desc.Key, err) + } + // Like RunBidiJSON, record init on the span only when the client actually + // supplied one; the zero value from an absent init is not meaningful. + var spanInit any + if hasInit { + spanInit = init + } + conn := b.startBidi(ctx, init, spanInit) + return &bidiJSONConn[In, Out, Stream]{ + conn: conn, + key: b.desc.Key, + inputSchema: inputSchema, + compiledInput: compiledInput, + }, nil +} + +// decodeInit decodes the JSON init payload from opts into the action's Init +// type. Returns hasInit=false when opts is nil or the payload is empty or +// JSON null, so transports can pass the request's init field through +// unconditionally. +func (b *BidiAction[In, Out, Stream, Init]) decodeInit(opts *api.BidiSessionOptions) (Init, bool, error) { + var init Init + if opts == nil || !base.HasJSONValue(opts.Init) { + return init, false, nil + } + if err := json.Unmarshal(opts.Init, &init); err != nil { + return init, false, NewError(INVALID_ARGUMENT, "invalid init for action %q: %v", b.desc.Key, err) + } + return init, true, nil +} + +// validateInit checks an init value against the action's InitSchema (if any), +// resolving schema $refs through the registry first. Validation runs whenever +// InitSchema is present, even for the zero init value, so a required field +// surfaces as INVALID_ARGUMENT rather than silently defaulting. +func (b *BidiAction[In, Out, Stream, Init]) validateInit(init Init) error { + if b.desc.InitSchema == nil { + return nil + } + schema, err := ResolveSchema(b.registry, b.desc.InitSchema) + if err != nil { + return NewError(INVALID_ARGUMENT, "invalid init schema for action %q: %v", b.desc.Key, err) + } + if err := base.ValidateValue(init, schema); err != nil { + return NewError(INVALID_ARGUMENT, "invalid init for action %q: %v", b.desc.Key, err) + } + return nil +} + +// startBidi launches the bidi function in a goroutine with the given initial +// configuration and returns a live connection for sending/receiving chunks. +// The session gets its own span (open for the connection's lifetime) and +// metrics, unlike the one-shot path, which gets both from runWithTelemetry. +// spanInit, when non-nil, is recorded as the span's genkit:init attribute. +func (b *BidiAction[In, Out, Stream, Init]) startBidi(ctx context.Context, init Init, spanInit any) *BidiConnection[In, Out, Stream] { + conn := newBidiConnection[In, Out, Stream](ctx) + + // Init is recorded as its own span attribute (genkit:init), not as the + // span input: the input slot describes per-call input, which a bidi + // session receives incrementally over the connection. + spanMetadata := b.spanMetadata(ctx, spanInit) + + go conn.run(b.desc.Name, func(ctx context.Context) (Out, error) { + return tracing.RunInNewSpan(ctx, spanMetadata, nil, + func(ctx context.Context, _ any) (out Out, err error) { + start := time.Now() + defer func() { recordActionMetrics(ctx, b.desc.Name, start, err) }() + out, err = callBidiFn(ctx, b.desc.Name, b.bidiFn, init, conn.inputCh, conn.streamCh) + if err != nil { + return out, err + } + // Mirror the unary path: the final output is validated + // against the action's OutputSchema. + outputSchema, err := b.resolveOutputSchema() + if err != nil { + return out, err + } + return out, b.validateOutput(out, outputSchema) + }, + ) + }) + + return conn +} + +// ResolveBidiActionFor returns the bidi action for the given name in the +// registry, or nil if there is none. +// It panics if the action is of the wrong type; plain actions resolve via +// [ResolveActionFor]. +// +// Experimental: bidirectional streaming is experimental and subject to change. +func ResolveBidiActionFor[In, Out, Stream, Init any](r api.Registry, atype api.ActionType, name string) *BidiAction[In, Out, Stream, Init] { + provider, id := api.ParseName(name) + key := api.NewKey(atype, provider, id) + a := r.ResolveAction(key) + if a == nil { + return nil + } + return a.(*BidiAction[In, Out, Stream, Init]) +} + +// callBidiFn invokes the bidi function, converting a panic into an INTERNAL +// error. The function runs in a framework-owned goroutine, so an unrecovered +// panic would crash the process rather than fail the session. +func callBidiFn[In, Out, Stream, Init any]( + ctx context.Context, + name string, + fn BidiFunc[In, Out, Stream, Init], + init Init, + inCh <-chan In, + outCh chan<- Stream, +) (out Out, err error) { + defer func() { + if r := recover(); r != nil { + err = NewError(INTERNAL, "panic in bidi action %q: %v", name, r) + } + }() + return fn(ctx, init, inCh, outCh) +} + +// BidiConnection represents an active bidirectional streaming session. +// +// The connection applies backpressure: the action blocks writing a chunk +// until the consumer reads earlier ones, so a session that streams more than +// one chunk requires the caller to drain [BidiConnection.Receive] before (or +// concurrently with) waiting on [BidiConnection.Output]. +// +// Experimental: bidirectional streaming is experimental and subject to change. +type BidiConnection[In, Out, Stream any] struct { + inputCh chan In + streamCh chan Stream + doneCh chan struct{} + output Out + err error + ctx context.Context + cancel context.CancelCauseFunc + mu sync.Mutex + closed bool +} + +// newBidiConnection creates an idle connection whose context derives from ctx. +// The caller must start exactly one [BidiConnection.run] goroutine to operate +// it. The context carries a cancel cause so that an abort reason (e.g. an +// invalid inbound chunk poisoning the session) survives to Send/Receive/Output +// instead of flattening to context.Canceled. +func newBidiConnection[In, Out, Stream any](ctx context.Context) *BidiConnection[In, Out, Stream] { + ctx, cancel := context.WithCancelCause(ctx) + return &BidiConnection[In, Out, Stream]{ + inputCh: make(chan In, 1), + streamCh: make(chan Stream, 1), + doneCh: make(chan struct{}), + ctx: ctx, + cancel: cancel, + } +} + +// ctxErr returns the reason the connection's context was cancelled, preferring +// the recorded cause over the bare context error. Only meaningful once the +// context is done. +func (c *BidiConnection[In, Out, Stream]) ctxErr() error { + if cause := context.Cause(c.ctx); cause != nil { + return cause + } + return c.ctx.Err() +} + +// run executes fn, which reads c.inputCh and writes c.streamCh, then records +// its result and settles the connection. It must be called exactly once, in +// its own goroutine. fn receives the connection's context and must honor its +// cancellation; convert panics inside fn with callBidiFn. +func (c *BidiConnection[In, Out, Stream]) run(name string, fn func(context.Context) (Out, error)) { + // Deferred calls run in reverse order: the stream channel closes first, + // then doneCh signals completion, then the connection context is + // released. Receive/Output rely on this ordering to prefer delivering + // results over reporting cancellation. + defer c.cancel(nil) + defer close(c.doneCh) + closingStream := false + defer func() { + if r := recover(); r != nil { + c.mu.Lock() + if c.err == nil { + if closingStream { + // The close below panicked: the action closed the output + // channel itself, which the framework owns. + c.err = NewError(INTERNAL, "bidi action %q closed its output channel; the framework owns closing it", name) + } else { + // A panic escaped fn's own wrapping (span, schema + // resolution, metrics); report it as what it is rather + // than misattributing it to the channel close. + c.err = NewError(INTERNAL, "panic in bidi session %q: %v", name, r) + } + } + c.mu.Unlock() + } + }() + defer func() { + // closingStream brackets the close so the recover above can tell a + // double-close panic apart from one unwinding out of fn. + closingStream = true + close(c.streamCh) + closingStream = false + }() + output, err := fn(c.ctx) + // An abort recorded a cause (invalid inbound chunk, failed stream + // marshal, callback error): that cause is the session's terminal error. + // It overrides a nil error from a function that never observed the + // cancellation and the bare Canceled it unwound with, but not a distinct + // error the function chose to report. + if cause := context.Cause(c.ctx); cause != nil && !errors.Is(cause, context.Canceled) { + if err == nil || errors.Is(err, context.Canceled) { + err = cause + } + } + c.mu.Lock() + c.output = output + c.err = err + c.mu.Unlock() +} + +// Send sends an input message to the bidi action. It blocks until the action +// reads the message (backpressure), the connection is cancelled, or the +// action completes. Returns an error if the connection is closed or the +// context is cancelled. Typed inputs are not re-validated against the +// action's InputSchema; the JSON transport path is. +func (c *BidiConnection[In, Out, Stream]) Send(input In) (err error) { + // Close may close inputCh concurrently with this send; sending on a + // closed channel panics, and the recover converts that into the same + // "connection is closed" error a pre-checked Send would return. + defer func() { + if r := recover(); r != nil { + err = NewError(FAILED_PRECONDITION, "connection is closed") + } + }() + + // A completed or aborted connection must fail deterministically: in the + // blocking select below all arms can be ready at once (inputCh keeps a + // free buffer slot once the action exits) and the runtime picks one at + // random, which would let a post-completion Send "succeed" into a + // channel nothing reads. Like Output, completion is preferred over + // cancellation. + select { + case <-c.doneCh: + return NewError(FAILED_PRECONDITION, "action has completed") + default: + } + select { + case <-c.ctx.Done(): + return c.ctxErr() + default: + } + + select { + case c.inputCh <- input: + return nil + case <-c.ctx.Done(): + return c.ctxErr() + case <-c.doneCh: + return NewError(FAILED_PRECONDITION, "action has completed") + } +} + +// Close signals that no more inputs will be sent. +func (c *BidiConnection[In, Out, Stream]) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil + } + c.closed = true + close(c.inputCh) + return nil +} + +// Receive returns an iterator for receiving streamed response chunks. +// The iterator completes when the action finishes. +// +// Breaking out of the loop stops consumption but does not abort the session: +// the action keeps running and later chunks remain subject to backpressure +// until Receive is iterated again or the session ends. Use +// [BidiConnection.Cancel] to abort the session. Chunks are delivered to a +// single consumer; concurrent Receive iterations split the stream between +// them. +func (c *BidiConnection[In, Out, Stream]) Receive() iter.Seq2[Stream, error] { + return func(yield func(Stream, error) bool) { + for { + select { + case chunk, ok := <-c.streamCh: + if !ok { + return + } + if !yield(chunk, nil) { + return + } + case <-c.ctx.Done(): + // Completion closes the stream channel before releasing the + // connection context, so prefer delivering chunks (and the + // clean end of stream) over reporting cancellation. + for { + select { + case chunk, ok := <-c.streamCh: + if !ok { + return + } + if !yield(chunk, nil) { + return + } + default: + var zero Stream + yield(zero, c.ctxErr()) + return + } + } + } + } + } +} + +// Output returns the final output after the action completes. +// Blocks until done or context cancelled. If the action streams more than +// one chunk, [BidiConnection.Receive] must be drained for the action to +// finish; see the BidiConnection doc. +func (c *BidiConnection[In, Out, Stream]) Output() (Out, error) { + select { + case <-c.doneCh: + case <-c.ctx.Done(): + // Completion closes doneCh before releasing the connection context, + // but both may be ready when this select runs; prefer the result. + select { + case <-c.doneCh: + default: + var zero Out + return zero, c.ctxErr() + } + } + c.mu.Lock() + defer c.mu.Unlock() + return c.output, c.err +} + +// Cancel aborts the session by cancelling the connection's context: the +// action's context is cancelled, blocked Sends unblock, and Output reports +// the cancellation error unless the action already completed. Safe to call +// multiple times and after completion. +func (c *BidiConnection[In, Out, Stream]) Cancel() { + c.cancel(nil) +} + +// Done returns a channel that is closed when the connection completes. +func (c *BidiConnection[In, Out, Stream]) Done() <-chan struct{} { + return c.doneCh +} + +// bidiJSONConn adapts a typed BidiConnection to the JSON-encoded +// api.BidiJSONConnection interface. +type bidiJSONConn[In, Out, Stream any] struct { + conn *BidiConnection[In, Out, Stream] + key string // action key, for error messages + inputSchema map[string]any // resolved InputSchema, used for chunk normalization + compiledInput *base.CompiledSchema // inputSchema compiled once; every inbound chunk validates against it +} + +func (b *bidiJSONConn[In, Out, Stream]) Send(chunk json.RawMessage) error { + // Mirrors the unary RunJSON path: normalize and validate every inbound + // chunk against the action's input schema, since JSON transports carry + // untrusted payloads. An explicit JSON null is validated like any other + // payload. + in, err := base.UnmarshalAndNormalizeWith[In](chunk, b.inputSchema, b.compiledInput) + if err == nil && len(chunk) == 0 { + // UnmarshalAndNormalizeWith skips validation entirely for empty + // input; validate the zero value it produced so an absent chunk + // payload cannot bypass a schema the zero value does not satisfy, + // mirroring the unary path, which validates the decoded value. + err = b.compiledInput.ValidateValue(in) + } + if err != nil { + // An invalid chunk fails the session (matching the JS runtime and + // the one-shot path, where invalid input fails the call): the error + // poisons the connection as its cancel cause so Output reports it, + // and is also returned for the transport to log or relay. + err = NewError(INVALID_ARGUMENT, "invalid stream chunk for action %q: %v", b.key, err) + b.conn.cancel(err) + return err + } + return b.conn.Send(in) +} + +func (b *bidiJSONConn[In, Out, Stream]) Close() error { + return b.conn.Close() +} + +func (b *bidiJSONConn[In, Out, Stream]) Receive() iter.Seq2[json.RawMessage, error] { + return func(yield func(json.RawMessage, error) bool) { + for chunk, err := range b.conn.Receive() { + if err != nil { + yield(nil, err) + return + } + bytes, mErr := json.Marshal(chunk) + if mErr != nil { + // Later chunks of the same type would fail the same way, + // leaving the session running with no consumer; abort it + // with the marshal error as the cause so Output reports it. + b.conn.cancel(mErr) + yield(nil, mErr) + return + } + if !yield(bytes, nil) { + return + } + } + } +} + +func (b *bidiJSONConn[In, Out, Stream]) Output() (json.RawMessage, error) { + out, err := b.conn.Output() + if err != nil { + return nil, err + } + return json.Marshal(out) +} + +var ( + _ api.Action = (*BidiAction[struct{}, struct{}, struct{}, struct{}])(nil) + _ api.BidiAction = (*BidiAction[struct{}, struct{}, struct{}, struct{}])(nil) +) diff --git a/go/core/bidi_test.go b/go/core/bidi_test.go new file mode 100644 index 0000000000..91fea75ebf --- /dev/null +++ b/go/core/bidi_test.go @@ -0,0 +1,1154 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math" + "strings" + "sync" + "testing" + "time" + + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/internal/registry" +) + +func TestBidiActionEcho(t *testing.T) { + ctx := context.Background() + + // In=string (stream chunks), Out=string, Stream=string, Init=struct{} (no init data). + action := NewBidiAction( + "echo", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + var count int + for input := range inCh { + count++ + outCh <- fmt.Sprintf("echo: %s", input) + } + return fmt.Sprintf("processed %d messages", count), nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + // With unbuffered channels, we must send and receive concurrently. + go func() { + conn.Send("hello") + conn.Send("world") + conn.Close() + }() + + var chunks []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatal(err) + } + chunks = append(chunks, chunk) + } + + if len(chunks) != 2 { + t.Fatalf("expected 2 chunks, got %d: %v", len(chunks), chunks) + } + if chunks[0] != "echo: hello" { + t.Errorf("expected 'echo: hello', got %q", chunks[0]) + } + if chunks[1] != "echo: world" { + t.Errorf("expected 'echo: world', got %q", chunks[1]) + } + + output, err := conn.Output() + if err != nil { + t.Fatal(err) + } + if output != "processed 2 messages" { + t.Errorf("expected 'processed 2 messages', got %q", output) + } +} + +func TestBidiActionWithConfig(t *testing.T) { + ctx := context.Background() + + type Config struct { + Prefix string + } + + // In=string (stream chunks), Out=string, Stream=string, Init=Config. + action := NewBidiAction( + "prefixed", api.ActionTypeCustom, nil, + func(ctx context.Context, cfg Config, inCh <-chan string, outCh chan<- string) (string, error) { + for input := range inCh { + outCh <- fmt.Sprintf("%s: %s", cfg.Prefix, input) + } + return "done", nil + }, + ) + + conn, err := action.StreamBidi(ctx, Config{Prefix: "INFO"}) + if err != nil { + t.Fatal(err) + } + + go func() { + conn.Send("test message") + conn.Close() + }() + + var chunks []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatal(err) + } + chunks = append(chunks, chunk) + } + + if len(chunks) != 1 || chunks[0] != "INFO: test message" { + t.Errorf("unexpected chunks: %v", chunks) + } +} + +// TestRunWithInit verifies the typed one-shot path: input is delivered as a +// single chunk and init configures the session. +func TestRunWithInit(t *testing.T) { + ctx := context.Background() + + type Config struct{ Prefix string } + + action := NewBidiAction( + "prefixed-oneshot", api.ActionTypeCustom, nil, + func(ctx context.Context, cfg Config, inCh <-chan string, outCh chan<- string) (string, error) { + var out string + for in := range inCh { + out = cfg.Prefix + in + } + return out, nil + }, + ) + + got, err := action.RunWithInit(ctx, Config{Prefix: ">> "}, "hello", nil) + if err != nil { + t.Fatalf("RunWithInit: %v", err) + } + if got != ">> hello" { + t.Errorf("output = %q, want %q", got, ">> hello") + } +} + +// TestBidiActionInterfaceDetection verifies that registry lookups return +// values whose api.BidiAction conformance matches the action kind: bidi +// actions satisfy it (pinning the BidiAction.Register override, which must +// register the bidi type rather than the embedded Action) and plain actions +// do not (the basis for transports' fail-loud init handling). +func TestBidiActionInterfaceDetection(t *testing.T) { + r := registry.New() + + DefineAction(r, "plain", api.ActionTypeCustom, nil, nil, + func(ctx context.Context, in string) (string, error) { + return "out:" + in, nil + }) + DefineBidiAction(r, "bidi", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return "done", nil + }) + + plain := r.LookupAction("/custom/plain") + if plain == nil { + t.Fatal("plain action not registered") + } + if _, ok := plain.(api.BidiAction); ok { + t.Error("plain action must not satisfy api.BidiAction") + } + + bidi := r.LookupAction("/custom/bidi") + if bidi == nil { + t.Fatal("bidi action not registered") + } + if _, ok := bidi.(api.BidiAction); !ok { + t.Error("bidi action must satisfy api.BidiAction") + } +} + +// TestRunBidiJSON verifies the JSON one-shot path used by transports: input +// is delivered as a single chunk and opts carries the session init. +func TestRunBidiJSON(t *testing.T) { + ctx := context.Background() + + type Config struct { + Prefix string `json:"prefix"` + } + + action := NewBidiAction( + "prefixed-json", api.ActionTypeCustom, nil, + func(ctx context.Context, cfg Config, inCh <-chan string, outCh chan<- string) (string, error) { + for in := range inCh { + outCh <- cfg.Prefix + in + } + return "done", nil + }, + ) + + var chunks []string + cb := func(_ context.Context, raw json.RawMessage) error { + var s string + if err := json.Unmarshal(raw, &s); err != nil { + return err + } + chunks = append(chunks, s) + return nil + } + + r, err := action.RunBidiJSON(ctx, json.RawMessage(`"hello"`), cb, + &api.BidiSessionOptions{Init: json.RawMessage(`{"prefix":">> "}`)}) + if err != nil { + t.Fatalf("RunBidiJSON: %v", err) + } + var got string + if err := json.Unmarshal(r.Result, &got); err != nil { + t.Fatalf("unmarshal output: %v", err) + } + if got != "done" { + t.Errorf("output = %q, want %q", got, "done") + } + if len(chunks) != 1 || chunks[0] != ">> hello" { + t.Errorf("chunks = %v, want [\">> hello\"]", chunks) + } + if r.TraceId == "" { + t.Error("TraceId is empty") + } +} + +// TestRunBidiJSONInvalidInit verifies that a malformed JSON init payload +// surfaces as INVALID_ARGUMENT. +func TestRunBidiJSONInvalidInit(t *testing.T) { + ctx := context.Background() + + type Config struct { + Prefix string `json:"prefix"` + } + action := NewBidiAction( + "bad-json-init", api.ActionTypeCustom, nil, + func(ctx context.Context, cfg Config, inCh <-chan string, outCh chan<- string) (string, error) { + return "", nil + }, + ) + + _, err := action.RunBidiJSON(ctx, json.RawMessage(`"in"`), nil, + &api.BidiSessionOptions{Init: json.RawMessage(`{not json`)}) + if err == nil { + t.Fatal("expected error for invalid JSON, got nil") + } + gerr, ok := err.(*GenkitError) + if !ok { + t.Fatalf("expected *GenkitError, got %T: %v", err, err) + } + if gerr.Status != INVALID_ARGUMENT { + t.Errorf("status = %v, want %v", gerr.Status, INVALID_ARGUMENT) + } +} + +// TestStreamBidiJSONNullInit verifies that nil options and a JSON-null init +// payload are both treated as no init (the zero Init value). +func TestStreamBidiJSONNullInit(t *testing.T) { + ctx := context.Background() + + type Config struct { + Prefix string `json:"prefix"` + } + + for _, opts := range []*api.BidiSessionOptions{nil, {Init: json.RawMessage(`null`)}} { + var sawInit Config + action := NewBidiAction( + "null-init", api.ActionTypeCustom, nil, + func(ctx context.Context, cfg Config, inCh <-chan string, outCh chan<- string) (string, error) { + sawInit = cfg + for range inCh { + } + return "done", nil + }, + ) + + conn, err := action.StreamBidiJSON(ctx, opts) + if err != nil { + t.Fatalf("StreamBidiJSON(%v): %v", opts, err) + } + conn.Close() + if _, err := conn.Output(); err != nil { + t.Fatalf("Output: %v", err) + } + if sawInit != (Config{}) { + t.Errorf("init = %v, want zero value", sawInit) + } + } +} + +// TestInitSchemaValidationRejectsBadInit verifies that init is validated +// against the action's InitSchema and a mismatch surfaces as INVALID_ARGUMENT. +func TestInitSchemaValidationRejectsBadInit(t *testing.T) { + ctx := context.Background() + + // Init schema requires "prefix" to be a string. + initSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "prefix": map[string]any{"type": "string"}, + }, + "required": []any{"prefix"}, + } + + action := NewBidiAction( + "validated-init", api.ActionTypeCustom, + &BidiActionOptions{InitSchema: initSchema}, + func(ctx context.Context, cfg map[string]any, inCh <-chan string, outCh chan<- string) (string, error) { + return "done", nil + }, + ) + + // Missing required "prefix" field. + _, err := action.StreamBidi(ctx, map[string]any{"other": 1}) + if err == nil { + t.Fatal("expected validation error, got nil") + } + gerr, ok := err.(*GenkitError) + if !ok { + t.Fatalf("expected *GenkitError, got %T: %v", err, err) + } + if gerr.Status != INVALID_ARGUMENT { + t.Errorf("status = %v, want %v", gerr.Status, INVALID_ARGUMENT) + } +} + +// TestInitSchemaValidationAcceptsGoodInit verifies the matching-init path of +// init schema validation. +func TestInitSchemaValidationAcceptsGoodInit(t *testing.T) { + ctx := context.Background() + + initSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "prefix": map[string]any{"type": "string"}, + }, + "required": []any{"prefix"}, + } + + action := NewBidiAction( + "validated-init-ok", api.ActionTypeCustom, + &BidiActionOptions{InitSchema: initSchema}, + func(ctx context.Context, cfg map[string]any, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return "done", nil + }, + ) + + conn, err := action.StreamBidi(ctx, map[string]any{"prefix": ">> "}) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + conn.Close() + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out != "done" { + t.Errorf("output = %q, want %q", out, "done") + } +} + +func TestBidiConnectionSendAfterClose(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "test", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return "", nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + conn.Close() + // Wait for completion so we know the state is settled. + <-conn.Done() + + if err := conn.Send("after close"); err == nil { + t.Error("expected error sending after close") + } +} + +func TestBidiConnectionContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + action := NewBidiAction( + "blocking", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + <-ctx.Done() + return "", ctx.Err() + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + cancel() + + _, err = conn.Output() + if err == nil { + t.Error("expected error after context cancellation") + } +} + +func TestBidiActionRegistration(t *testing.T) { + r := registry.New() + + action := DefineBidiAction( + r, "echoAction", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for input := range inCh { + outCh <- input + } + return "done", nil + }, + ) + + if action.Name() != "echoAction" { + t.Errorf("expected name 'echoAction', got %q", action.Name()) + } + + desc := action.Desc() + + // Verify bidi metadata is set. + if bidi, ok := desc.Metadata["bidi"].(bool); !ok || !bidi { + t.Error("expected metadata[\"bidi\"] = true") + } + + // Verify registered in registry. + if r.LookupAction(desc.Key) == nil { + t.Error("expected action to be registered") + } +} + +func TestBidiActionDone(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "quick", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return "finished", nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + conn.Close() + <-conn.Done() + + output, err := conn.Output() + if err != nil { + t.Fatal(err) + } + if output != "finished" { + t.Errorf("expected 'finished', got %q", output) + } +} + +// TestBidiRunCallbackErrorStopsAction verifies that when the streaming +// callback fails during a unary Run of a bidi action, the action's context is +// cancelled and its goroutine exits instead of leaking blocked on a stream +// write. +func TestBidiRunCallbackErrorStopsAction(t *testing.T) { + ctx := context.Background() + + fnExited := make(chan struct{}) + action := NewBidiAction( + "cb-error", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + defer close(fnExited) + for i := 0; ; i++ { + select { + case outCh <- fmt.Sprintf("chunk %d", i): + case <-ctx.Done(): + return "", ctx.Err() + } + } + }, + ) + + wantErr := errors.New("consumer failed") + _, err := action.Run(ctx, "in", func(context.Context, string) error { + return wantErr + }) + if !errors.Is(err, wantErr) { + t.Fatalf("Run err = %v, want %v", err, wantErr) + } + + select { + case <-fnExited: + case <-time.After(5 * time.Second): + t.Fatal("bidi function did not exit after callback error (goroutine leak)") + } +} + +// TestBidiActionPanicRecovered verifies that a panic in a bidi function is +// recovered and reported as an error rather than crashing the process, on +// both the connection and unary Run paths. +func TestBidiActionPanicRecovered(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "panicky", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + panic("boom") + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + if _, err := conn.Output(); err == nil || !strings.Contains(err.Error(), "panic in bidi action") { + t.Errorf("Output err = %v, want panic error", err) + } + + if _, err := action.Run(ctx, "in", nil); err == nil || !strings.Contains(err.Error(), "panic in bidi action") { + t.Errorf("Run err = %v, want panic error", err) + } +} + +// TestBidiActionClosingOutChIsError verifies that a bidi function closing its +// output channel (which the framework owns) surfaces as an error instead of +// crashing the process, on both the connection and unary Run paths. +func TestBidiActionClosingOutChIsError(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "closer", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + close(outCh) + return "done", nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + if _, err := conn.Output(); err == nil || !strings.Contains(err.Error(), "closed its output channel") { + t.Errorf("Output err = %v, want closed-output-channel error", err) + } + + if _, err := action.Run(ctx, "in", nil); err == nil || !strings.Contains(err.Error(), "closed its output channel") { + t.Errorf("Run err = %v, want closed-output-channel error", err) + } +} + +// TestBidiReceiveBreakDoesNotCancelSession verifies that breaking out of a +// Receive loop stops consumption without aborting the session: iteration can +// resume and the final output remains available. +func TestBidiReceiveBreakDoesNotCancelSession(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "resumable", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for i := range 3 { + select { + case outCh <- fmt.Sprintf("chunk %d", i): + case <-ctx.Done(): + return "", ctx.Err() + } + } + for range inCh { + } + return "done", nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + var got []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatal(err) + } + got = append(got, chunk) + break // Early break must not abort the session. + } + conn.Close() + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatal(err) + } + got = append(got, chunk) + } + if len(got) != 3 { + t.Errorf("got %d chunks total, want 3: %v", len(got), got) + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out != "done" { + t.Errorf("output = %q, want %q", out, "done") + } +} + +// TestBidiConnectionCancel verifies that Cancel aborts the session. +func TestBidiConnectionCancel(t *testing.T) { + action := NewBidiAction( + "cancellable", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + <-ctx.Done() + return "", ctx.Err() + }, + ) + + conn, err := action.StreamBidi(context.Background(), struct{}{}) + if err != nil { + t.Fatal(err) + } + conn.Cancel() + if _, err := conn.Output(); err == nil { + t.Error("expected error after Cancel") + } + conn.Cancel() // Safe to call again. +} + +// TestBidiOutputAfterCompletionNotCancelled verifies that after a normal +// completion, Output returns the result and Receive ends cleanly even though +// the connection context is released on completion. Looped to exercise the +// completion/cancellation race. +func TestBidiOutputAfterCompletionNotCancelled(t *testing.T) { + ctx := context.Background() + + for range 50 { + action := NewBidiAction( + "completes", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return "done", nil + }, + ) + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + conn.Close() + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out != "done" { + t.Errorf("output = %q, want %q", out, "done") + } + for _, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive after completion: %v", err) + } + } + } +} + +// TestBidiJSONConnSendValidatesChunks verifies that the JSON transport path +// validates every inbound chunk against the action's input schema and that an +// invalid chunk fails the session, matching the JS runtime and the one-shot +// path (where invalid input fails the call). +func TestBidiJSONConnSendValidatesChunks(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "typed-in", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + var n int + for { + select { + case _, ok := <-inCh: + if !ok { + return fmt.Sprintf("got %d", n), nil + } + n++ + case <-ctx.Done(): + return "", ctx.Err() + } + } + }, + ) + + t.Run("valid chunk delivered", func(t *testing.T) { + conn, err := action.StreamBidiJSON(ctx, nil) + if err != nil { + t.Fatal(err) + } + if err := conn.Send(json.RawMessage(`"ok"`)); err != nil { + t.Errorf("Send valid chunk: %v", err) + } + if err := conn.Close(); err != nil { + t.Fatal(err) + } + out, err := conn.Output() + if err != nil { + t.Fatal(err) + } + if string(out) != `"got 1"` { + t.Errorf("output = %s, want %q", out, `"got 1"`) + } + }) + + t.Run("invalid chunk fails the session", func(t *testing.T) { + conn, err := action.StreamBidiJSON(ctx, nil) + if err != nil { + t.Fatal(err) + } + serr := conn.Send(json.RawMessage(`123`)) + if serr == nil { + t.Fatal("expected validation error for non-string chunk") + } + if gerr, ok := serr.(*GenkitError); !ok || gerr.Status != INVALID_ARGUMENT { + t.Errorf("Send err = %v, want INVALID_ARGUMENT GenkitError", serr) + } + // The validation error is the session's terminal error. + if _, oerr := conn.Output(); oerr == nil || !strings.Contains(oerr.Error(), "invalid stream chunk") { + t.Errorf("Output err = %v, want invalid-chunk error", oerr) + } + }) + + t.Run("null chunk validated like any payload", func(t *testing.T) { + conn, err := action.StreamBidiJSON(ctx, nil) + if err != nil { + t.Fatal(err) + } + if err := conn.Send(json.RawMessage(`null`)); err == nil { + t.Error("expected validation error for null chunk") + } + }) +} + +// TestBidiOutputSchemaValidatedOnConnection verifies that a session's final +// output is validated against the action's OutputSchema, mirroring the unary +// path. +func TestBidiOutputSchemaValidatedOnConnection(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "bad-output", api.ActionTypeCustom, + &BidiActionOptions{OutputSchema: map[string]any{"type": "string"}}, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (int, error) { + for range inCh { + } + return 42, nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + conn.Close() + if _, err := conn.Output(); err == nil || !strings.Contains(err.Error(), "invalid output") { + t.Errorf("Output err = %v, want invalid-output error", err) + } +} + +// TestBidiJSONConnReceiveMarshalErrorAbortsSession verifies that a stream +// chunk that cannot be marshaled aborts the session instead of leaving the +// action running with no consumer. +func TestBidiJSONConnReceiveMarshalErrorAbortsSession(t *testing.T) { + ctx := context.Background() + + fnExited := make(chan struct{}) + action := NewBidiAction( + "nan-stream", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- float64) (string, error) { + defer close(fnExited) + select { + case outCh <- math.NaN(): // json.Marshal fails on NaN. + case <-ctx.Done(): + return "", ctx.Err() + } + <-ctx.Done() + return "", ctx.Err() + }, + ) + + conn, err := action.StreamBidiJSON(ctx, nil) + if err != nil { + t.Fatal(err) + } + var gotErr error + for _, rerr := range conn.Receive() { + if rerr != nil { + gotErr = rerr + break + } + } + if gotErr == nil { + t.Fatal("expected marshal error from Receive") + } + select { + case <-fnExited: + case <-time.After(5 * time.Second): + t.Fatal("session not aborted after marshal error (goroutine leak)") + } + // The marshal error is the session's terminal error, not a bare + // cancellation. + if _, oerr := conn.Output(); oerr == nil || !strings.Contains(oerr.Error(), "unsupported value") { + t.Errorf("Output err = %v, want marshal error", oerr) + } +} + +// TestBidiInvalidChunkFailsCtxObliviousSession verifies that the poison cause +// is the session's terminal error even when the action never observes the +// cancellation and returns a nil error after its input closes. +func TestBidiInvalidChunkFailsCtxObliviousSession(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "oblivious", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return "done", nil + }, + ) + + conn, err := action.StreamBidiJSON(ctx, nil) + if err != nil { + t.Fatal(err) + } + if err := conn.Send(json.RawMessage(`123`)); err == nil { + t.Fatal("expected validation error for non-string chunk") + } + if err := conn.Close(); err != nil { + t.Fatal(err) + } + if _, oerr := conn.Output(); oerr == nil || !strings.Contains(oerr.Error(), "invalid stream chunk") { + t.Errorf("Output err = %v, want invalid-chunk error overriding the nil result", oerr) + } +} + +// TestBidiSendAfterCompletionFails verifies that Send fails deterministically +// once the action has completed, even when the input channel was never closed +// and still has buffer space. +func TestBidiSendAfterCompletionFails(t *testing.T) { + action := NewBidiAction( + "one-read", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + <-inCh + return "done", nil + }, + ) + + conn, err := action.StreamBidi(context.Background(), struct{}{}) + if err != nil { + t.Fatal(err) + } + if err := conn.Send("only"); err != nil { + t.Fatal(err) + } + <-conn.Done() + + // Without the completion pre-check, the blocking select races the free + // buffer slot against the closed done/ctx channels and ~1/3 of sends + // would report success for a message nothing will ever read. + for i := range 25 { + if err := conn.Send("late"); err == nil { + t.Fatalf("Send %d after completion returned nil, want error", i) + } + } +} + +// TestBidiSessionWrapperPanicNotMislabeled verifies that a panic escaping the +// session wrapper (outside the user function) is reported as a panic, not +// misattributed to the action closing its output channel. +func TestBidiSessionWrapperPanicNotMislabeled(t *testing.T) { + // An unregistered action with a $ref output schema: schema resolution + // dereferences the nil registry after the function returns, panicking + // inside the session wrapper. + action := NewBidiAction( + "ref-output", api.ActionTypeCustom, + &BidiActionOptions{OutputSchema: SchemaRef("missing")}, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return "done", nil + }, + ) + + conn, err := action.StreamBidi(context.Background(), struct{}{}) + if err != nil { + t.Fatal(err) + } + conn.Close() + _, oerr := conn.Output() + if oerr == nil { + t.Fatal("expected error from wrapper panic") + } + if strings.Contains(oerr.Error(), "closed its output channel") { + t.Errorf("Output err = %v; wrapper panic mislabeled as output-channel close", oerr) + } + if !strings.Contains(oerr.Error(), "panic in bidi session") { + t.Errorf("Output err = %v, want panic-in-bidi-session error", oerr) + } +} + +// TestBidiRunCallbackPanicReleasesAction verifies that a panicking stream +// callback on the unary path does not strand the action goroutine blocked on +// a stream write. +func TestBidiRunCallbackPanicReleasesAction(t *testing.T) { + ctx := context.Background() + + fnExited := make(chan struct{}) + action := NewBidiAction( + "cb-panic", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + defer close(fnExited) + for i := 0; ; i++ { + select { + case outCh <- fmt.Sprintf("chunk %d", i): + case <-ctx.Done(): + return "", ctx.Err() + } + } + }, + ) + + func() { + defer func() { + if recover() == nil { + t.Error("expected callback panic to propagate to the caller") + } + }() + _, _ = action.Run(ctx, "in", func(context.Context, string) error { + panic("callback boom") + }) + }() + + select { + case <-fnExited: + case <-time.After(5 * time.Second): + t.Fatal("bidi function did not exit after callback panic (goroutine leak)") + } +} + +// TestBidiJSONConnEmptyChunkValidated verifies that an absent chunk payload +// is validated like the unary path validates its decoded input, rather than +// silently delivering an unchecked zero value. +func TestBidiJSONConnEmptyChunkValidated(t *testing.T) { + ctx := context.Background() + + type msg struct { + Name string `json:"name,omitempty"` + } + schema := map[string]any{ + "type": "object", + "properties": map[string]any{"name": map[string]any{"type": "string"}}, + "required": []any{"name"}, + } + + action := NewBidiAction( + "required-in", api.ActionTypeCustom, + &BidiActionOptions{InputSchema: schema}, + func(ctx context.Context, _ struct{}, inCh <-chan msg, outCh chan<- string) (string, error) { + for { + select { + case _, ok := <-inCh: + if !ok { + return "done", nil + } + case <-ctx.Done(): + return "", ctx.Err() + } + } + }, + ) + + conn, err := action.StreamBidiJSON(ctx, nil) + if err != nil { + t.Fatal(err) + } + if err := conn.Send(nil); err == nil { + t.Error("expected validation error for empty chunk against a schema with required fields") + } +} + +// TestActionDescSchemaSentinels verifies that the struct{} sentinel type +// parameters do not leak inferred schemas into action descriptors: only +// streaming actions advertise a streamSchema and only bidi actions with a +// real Init type advertise an initSchema. +func TestActionDescSchemaSentinels(t *testing.T) { + plain := NewAction("plain-desc", api.ActionTypeCustom, nil, nil, + func(ctx context.Context, in string) (string, error) { return in, nil }) + if got := plain.Desc().StreamSchema; got != nil { + t.Errorf("non-streaming action StreamSchema = %v, want nil", got) + } + if got := plain.Desc().InitSchema; got != nil { + t.Errorf("non-streaming action InitSchema = %v, want nil", got) + } + + streaming := NewStreamingAction("streaming-desc", api.ActionTypeCustom, nil, nil, + func(ctx context.Context, in string, cb StreamCallback[string]) (string, error) { return in, nil }) + if got := streaming.Desc().StreamSchema; got == nil { + t.Error("streaming action StreamSchema = nil, want schema") + } + + noInit := NewBidiAction("bidi-noinit-desc", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + return "", nil + }) + if got := noInit.Desc().InitSchema; got != nil { + t.Errorf("bidi action without init InitSchema = %v, want nil", got) + } + if got := noInit.Desc().StreamSchema; got == nil { + t.Error("bidi action StreamSchema = nil, want schema") + } + + type Config struct{ Prefix string } + withInit := NewBidiAction("bidi-init-desc", api.ActionTypeCustom, nil, + func(ctx context.Context, cfg Config, inCh <-chan string, outCh chan<- string) (string, error) { + return "", nil + }) + if got := withInit.Desc().InitSchema; got == nil { + t.Error("bidi action with init InitSchema = nil, want schema") + } +} + +// TestBidiEchoStress exercises many concurrent sessions with many messages +// each. Run it with -race and GOMAXPROCS=1 to catch scheduling-dependent +// bugs in the connection's channel handling. +func TestBidiEchoStress(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "stress-echo", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan int, outCh chan<- int) (int, error) { + var sum int + for v := range inCh { + sum += v + select { + case outCh <- v * 2: + case <-ctx.Done(): + return 0, ctx.Err() + } + } + return sum, nil + }, + ) + + const sessions = 16 + const messages = 100 + + var wg sync.WaitGroup + for s := range sessions { + wg.Add(1) + go func() { + defer wg.Done() + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Error(err) + return + } + go func() { + for i := range messages { + if err := conn.Send(i); err != nil { + t.Error(err) + return + } + } + conn.Close() + }() + var count int + for chunk, err := range conn.Receive() { + if err != nil { + t.Error(err) + return + } + _ = chunk + count++ + } + if count != messages { + t.Errorf("session %d: got %d chunks, want %d", s, count, messages) + } + out, err := conn.Output() + if err != nil { + t.Error(err) + return + } + if want := messages * (messages - 1) / 2; out != want { + t.Errorf("session %d: output %d, want %d", s, out, want) + } + }() + } + wg.Wait() +} + +// TestResolveBidiActionFor verifies typed round-trip resolution of a bidi +// action from the registry. +func TestResolveBidiActionFor(t *testing.T) { + ctx := context.Background() + r := registry.New() + + type Config struct{ Prefix string } + + DefineBidiAction(r, "resolvable-bidi", api.ActionTypeCustom, nil, + func(ctx context.Context, cfg Config, inCh <-chan string, outCh chan<- string) (string, error) { + var out string + for in := range inCh { + out = cfg.Prefix + in + } + return out, nil + }) + + resolved := ResolveBidiActionFor[string, string, string, Config](r, api.ActionTypeCustom, "resolvable-bidi") + if resolved == nil { + t.Fatal("ResolveBidiActionFor returned nil") + } + got, err := resolved.RunWithInit(ctx, Config{Prefix: ">> "}, "hello", nil) + if err != nil { + t.Fatalf("RunWithInit: %v", err) + } + if got != ">> hello" { + t.Errorf("output = %q, want %q", got, ">> hello") + } + + if missing := ResolveBidiActionFor[string, string, string, Config](r, api.ActionTypeCustom, "nope"); missing != nil { + t.Errorf("expected nil for missing action, got %v", missing) + } +} diff --git a/go/core/flow.go b/go/core/flow.go index 3bf78031ed..d2d97e008a 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -26,17 +26,17 @@ import ( "github.com/firebase/genkit/go/internal/base" ) -// A Flow is a user-defined Action. A Flow[In, Out, StreamOut, StreamIn] represents a function from In to Out. -// The StreamOut parameter is for flows that support streaming: providing their results incrementally. The StreamIn parameter is for bidi flows. -type Flow[In, Out, StreamOut, StreamIn any] struct { - *Action[In, Out, StreamOut, StreamIn] +// A Flow is a user-defined Action. A Flow[In, Out, Stream] represents a function from In to Out. +// The Stream parameter is for flows that support streaming: providing their results incrementally. +type Flow[In, Out, Stream any] struct { + *Action[In, Out, Stream] } // StreamingFlowValue is either a streamed value or a final output of a flow. -type StreamingFlowValue[Out, StreamOut any] struct { +type StreamingFlowValue[Out, Stream any] struct { Done bool - Output Out // valid if Done is true - Stream StreamOut // valid if Done is false + Output Out // valid if Done is true + Stream Stream // valid if Done is false } // flowContextKey is a context key that indicates whether the current context is a flow context. @@ -48,8 +48,8 @@ type flowContext struct { } // NewFlow creates a Flow that runs fn without registering it. fn takes an input of type In and returns an output of type Out. -func NewFlow[In, Out any](name string, fn Func[In, Out]) *Flow[In, Out, struct{}, struct{}] { - return &Flow[In, Out, struct{}, struct{}]{NewAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In) (Out, error) { +func NewFlow[In, Out any](name string, fn Func[In, Out]) *Flow[In, Out, struct{}] { + return &Flow[In, Out, struct{}]{NewAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In) (Out, error) { fc := &flowContext{ flowName: name, } @@ -59,31 +59,21 @@ func NewFlow[In, Out any](name string, fn Func[In, Out]) *Flow[In, Out, struct{} } // NewStreamingFlow creates a streaming Flow that runs fn without registering it. -func NewStreamingFlow[In, Out, StreamOut any](name string, fn StreamingFunc[In, Out, StreamOut]) *Flow[In, Out, StreamOut, struct{}] { - return &Flow[In, Out, StreamOut, struct{}]{NewStreamingAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, StreamOut) error) (Out, error) { +func NewStreamingFlow[In, Out, Stream any](name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream] { + return &Flow[In, Out, Stream]{NewStreamingAction(name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) { fc := &flowContext{ flowName: name, } ctx = flowContextKey.NewContext(ctx, fc) if cb == nil { - cb = func(context.Context, StreamOut) error { return nil } + cb = func(context.Context, Stream) error { return nil } } return fn(ctx, input, cb) })} } -// NewBidiFlow creates a bidirectional streaming Flow without registering it. -// Flow context is injected so that [Run] works inside the bidi function. -func NewBidiFlow[In, Out, StreamOut, StreamIn any](name string, fn BidiFunc[In, Out, StreamOut, StreamIn]) *Flow[In, Out, StreamOut, StreamIn] { - wrapped := func(ctx context.Context, in In, inCh <-chan StreamIn, outCh chan<- StreamOut) (Out, error) { - ctx = flowContextKey.NewContext(ctx, &flowContext{flowName: name}) - return fn(ctx, in, inCh, outCh) - } - return &Flow[In, Out, StreamOut, StreamIn]{NewBidiAction(name, api.ActionTypeFlow, nil, wrapped)} -} - // DefineFlow creates a Flow that runs fn, and registers it as an action. fn takes an input of type In and returns an output of type Out. -func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flow[In, Out, struct{}, struct{}] { +func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flow[In, Out, struct{}] { f := NewFlow(name, fn) f.Register(r) return f @@ -92,26 +82,18 @@ func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flo // DefineStreamingFlow creates a streaming Flow that runs fn, and registers it as an action. // // fn takes an input of type In and returns an output of type Out, optionally -// streaming values of type StreamOut incrementally by invoking a callback. +// streaming values of type Stream incrementally by invoking a callback. // // If the function supports streaming and the callback is non-nil, it should // stream the results by invoking the callback periodically, ultimately returning // with a final return value that includes all the streamed data. // Otherwise, it should ignore the callback and just return a result. -func DefineStreamingFlow[In, Out, StreamOut any](r api.Registry, name string, fn StreamingFunc[In, Out, StreamOut]) *Flow[In, Out, StreamOut, struct{}] { +func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream] { f := NewStreamingFlow(name, fn) f.Register(r) return f } -// DefineBidiFlow creates a bidirectional streaming Flow that runs fn, and registers it as an action. -// Flow context is injected so that [Run] works inside the bidi function. -func DefineBidiFlow[In, Out, StreamOut, StreamIn any](r api.Registry, name string, fn BidiFunc[In, Out, StreamOut, StreamIn]) *Flow[In, Out, StreamOut, StreamIn] { - f := NewBidiFlow(name, fn) - f.Register(r) - return f -} - // Run runs the function f in the context of the current flow // and returns what f returns. // It returns an error if no flow is active. @@ -140,7 +122,7 @@ func Run[Out any](ctx context.Context, name string, fn func() (Out, error)) (Out } // Run runs the flow in the context of another flow. -func (f *Flow[In, Out, StreamOut, StreamIn]) Run(ctx context.Context, input In) (Out, error) { +func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In) (Out, error) { return f.Action.Run(ctx, input, nil) } @@ -156,17 +138,17 @@ func (f *Flow[In, Out, StreamOut, StreamIn]) Run(ctx context.Context, input In) // again. // // Otherwise the Stream field of the passed [StreamingFlowValue] holds a streamed result. -func (f *Flow[In, Out, StreamOut, StreamIn]) Stream(ctx context.Context, input In) func(func(*StreamingFlowValue[Out, StreamOut], error) bool) { - return func(yield func(*StreamingFlowValue[Out, StreamOut], error) bool) { +func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func(*StreamingFlowValue[Out, Stream], error) bool) { + return func(yield func(*StreamingFlowValue[Out, Stream], error) bool) { done := false - cb := func(ctx context.Context, s StreamOut) error { + cb := func(ctx context.Context, s Stream) error { if done { return errStop } if ctx.Err() != nil { return ctx.Err() } - if !yield(&StreamingFlowValue[Out, StreamOut]{Stream: s}, nil) { + if !yield(&StreamingFlowValue[Out, Stream]{Stream: s}, nil) { done = true return errStop } @@ -180,7 +162,7 @@ func (f *Flow[In, Out, StreamOut, StreamIn]) Stream(ctx context.Context, input I if err != nil { yield(nil, err) } else { - yield(&StreamingFlowValue[Out, StreamOut]{Done: true, Output: output}, nil) + yield(&StreamingFlowValue[Out, Stream]{Done: true, Output: output}, nil) } } } diff --git a/go/core/flow_test.go b/go/core/flow_test.go index 5810ac4625..e3c3e6b463 100644 --- a/go/core/flow_test.go +++ b/go/core/flow_test.go @@ -18,12 +18,9 @@ package core import ( "context" - "fmt" "slices" - "strings" "testing" - "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/internal/registry" ) @@ -72,7 +69,7 @@ func TestRunFlow(t *testing.T) { func TestFlowNameFromContext(t *testing.T) { r := registry.New() - flows := []*Flow[struct{}, string, struct{}, struct{}]{ + flows := []*Flow[struct{}, string, struct{}]{ DefineFlow(r, "DefineFlow", func(ctx context.Context, _ struct{}) (string, error) { return FlowNameFromContext(ctx), nil }), @@ -260,304 +257,3 @@ func TestFlowNameFromContextOutsideFlow(t *testing.T) { } }) } - -func TestBidiActionEcho(t *testing.T) { - ctx := context.Background() - - // In=struct{} (no initial data), Out=string, OutStream=string, InStream=string - action := NewBidiAction( - "echo", api.ActionTypeCustom, nil, - func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { - var count int - for input := range inCh { - count++ - outCh <- fmt.Sprintf("echo: %s", input) - } - return fmt.Sprintf("processed %d messages", count), nil - }, - ) - - conn, err := action.StreamBidi(ctx, struct{}{}) - if err != nil { - t.Fatal(err) - } - - // With unbuffered channels, we must send and receive concurrently. - go func() { - conn.Send("hello") - conn.Send("world") - conn.Close() - }() - - var chunks []string - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatal(err) - } - chunks = append(chunks, chunk) - } - - if len(chunks) != 2 { - t.Fatalf("expected 2 chunks, got %d: %v", len(chunks), chunks) - } - if chunks[0] != "echo: hello" { - t.Errorf("expected 'echo: hello', got %q", chunks[0]) - } - if chunks[1] != "echo: world" { - t.Errorf("expected 'echo: world', got %q", chunks[1]) - } - - output, err := conn.Output() - if err != nil { - t.Fatal(err) - } - if output != "processed 2 messages" { - t.Errorf("expected 'processed 2 messages', got %q", output) - } -} - -func TestBidiActionWithConfig(t *testing.T) { - ctx := context.Background() - - type Config struct { - Prefix string - } - - // In=Config (initial config), Out=string, OutStream=string, InStream=string - action := NewBidiAction( - "prefixed", api.ActionTypeCustom, nil, - func(ctx context.Context, cfg Config, inCh <-chan string, outCh chan<- string) (string, error) { - for input := range inCh { - outCh <- fmt.Sprintf("%s: %s", cfg.Prefix, input) - } - return "done", nil - }, - ) - - conn, err := action.StreamBidi(ctx, Config{Prefix: "INFO"}) - if err != nil { - t.Fatal(err) - } - - go func() { - conn.Send("test message") - conn.Close() - }() - - var chunks []string - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatal(err) - } - chunks = append(chunks, chunk) - } - - if len(chunks) != 1 || chunks[0] != "INFO: test message" { - t.Errorf("unexpected chunks: %v", chunks) - } -} - -func TestBidiConnectionSendAfterClose(t *testing.T) { - ctx := context.Background() - - action := NewBidiAction( - "test", api.ActionTypeCustom, nil, - func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { - for range inCh { - } - return "", nil - }, - ) - - conn, err := action.StreamBidi(ctx, struct{}{}) - if err != nil { - t.Fatal(err) - } - - conn.Close() - // Wait for completion so we know the state is settled. - <-conn.Done() - - if err := conn.Send("after close"); err == nil { - t.Error("expected error sending after close") - } -} - -func TestBidiConnectionContextCancellation(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - - action := NewBidiAction( - "blocking", api.ActionTypeCustom, nil, - func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { - <-ctx.Done() - return "", ctx.Err() - }, - ) - - conn, err := action.StreamBidi(ctx, struct{}{}) - if err != nil { - t.Fatal(err) - } - - cancel() - - _, err = conn.Output() - if err == nil { - t.Error("expected error after context cancellation") - } -} - -func TestBidiFlowRegistration(t *testing.T) { - r := registry.New() - - flow := DefineBidiFlow( - r, "echoFlow", - func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { - for input := range inCh { - outCh <- input - } - return "done", nil - }, - ) - - if flow.Name() != "echoFlow" { - t.Errorf("expected name 'echoFlow', got %q", flow.Name()) - } - - desc := flow.Desc() - if desc.Type != api.ActionTypeFlow { - t.Errorf("expected type %q, got %q", api.ActionTypeFlow, desc.Type) - } - - // Verify bidi metadata is set. - if bidi, ok := desc.Metadata["bidi"].(bool); !ok || !bidi { - t.Error("expected metadata[\"bidi\"] = true") - } - - // Verify registered in registry. - action := r.LookupAction(desc.Key) - if action == nil { - t.Error("expected action to be registered") - } -} - -func TestBidiFlowEcho(t *testing.T) { - r := registry.New() - ctx := context.Background() - - flow := DefineBidiFlow( - r, "echoFlow", - func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { - var count int - for input := range inCh { - count++ - outCh <- fmt.Sprintf("echo: %s", input) - } - return fmt.Sprintf("processed %d", count), nil - }, - ) - - conn, err := flow.StreamBidi(ctx, struct{}{}) - if err != nil { - t.Fatal(err) - } - - go func() { - conn.Send("a") - conn.Send("b") - conn.Close() - }() - - var chunks []string - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatal(err) - } - chunks = append(chunks, chunk) - } - - if len(chunks) != 2 { - t.Fatalf("expected 2 chunks, got %d", len(chunks)) - } - - output, err := conn.Output() - if err != nil { - t.Fatal(err) - } - if output != "processed 2" { - t.Errorf("expected 'processed 2', got %q", output) - } -} - -func TestBidiFlowCoreRunWorks(t *testing.T) { - r := registry.New() - ctx := context.Background() - - flow := DefineBidiFlow( - r, "withSteps", - func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { - for input := range inCh { - // core.Run should work inside a BidiFlow. - result, err := Run(ctx, "uppercase", func() (string, error) { - return strings.ToUpper(input), nil - }) - if err != nil { - return "", err - } - outCh <- result - } - return "done", nil - }, - ) - - conn, err := flow.StreamBidi(ctx, struct{}{}) - if err != nil { - t.Fatal(err) - } - - go func() { - conn.Send("hello") - conn.Close() - }() - - var chunks []string - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatal(err) - } - chunks = append(chunks, chunk) - } - - if len(chunks) != 1 || chunks[0] != "HELLO" { - t.Errorf("expected [HELLO], got %v", chunks) - } -} - -func TestBidiActionDone(t *testing.T) { - ctx := context.Background() - - action := NewBidiAction( - "quick", api.ActionTypeCustom, nil, - func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { - for range inCh { - } - return "finished", nil - }, - ) - - conn, err := action.StreamBidi(ctx, struct{}{}) - if err != nil { - t.Fatal(err) - } - - conn.Close() - <-conn.Done() - - output, err := conn.Output() - if err != nil { - t.Fatal(err) - } - if output != "finished" { - t.Errorf("expected 'finished', got %q", output) - } -} diff --git a/go/core/schemas.config b/go/core/schemas.config index fa0c5ef715..7382031047 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1198,6 +1198,7 @@ ReflectionRunActionParams is the payload for the "runAction" request sent by the CLI manager to execute an action on the runtime. . ReflectionRunActionParams.input type json.RawMessage +ReflectionRunActionParams.init type json.RawMessage ReflectionRunActionParams.context type json.RawMessage ReflectionRunActionParams.telemetryLabels type json.RawMessage ReflectionRunActionParams.stream doc @@ -1236,6 +1237,7 @@ ReflectionStreamChunkParams.requestId name RequestID ReflectionStreamChunkParams.requestId doc ID of the JSON-RPC request this chunk belongs to. . +ReflectionStreamChunkParams.chunk type json.RawMessage ReflectionStreamChunkParams.chunk doc The streamed data chunk. . @@ -1284,9 +1286,11 @@ ReflectionListActionsResponse omit ReflectionSendInputStreamChunkParams pkg genkit ReflectionSendInputStreamChunkParams doc ReflectionSendInputStreamChunkParams is the payload for the -"sendInputStreamChunk" notification (bidirectional streaming, not yet implemented). +"sendInputStreamChunk" notification used to deliver one inbound chunk +to a running bidirectional action. . ReflectionSendInputStreamChunkParams.requestId name RequestID +ReflectionSendInputStreamChunkParams.chunk type json.RawMessage ReflectionEndInputStreamParams pkg genkit ReflectionEndInputStreamParams doc diff --git a/go/core/tracing/tracing.go b/go/core/tracing/tracing.go index 727625f75a..88e1950a59 100644 --- a/go/core/tracing/tracing.go +++ b/go/core/tracing/tracing.go @@ -169,6 +169,11 @@ type SpanMetadata struct { TelemetryLabels map[string]string // Metadata are genkit-specific metadata with automatic "genkit:metadata:" prefix Metadata map[string]string + // Init is the initialization data supplied to the action, recorded as the + // "genkit:init" span attribute when non-nil. It is kept separate from the + // span input so tooling can distinguish per-call input from session + // initialization data. + Init any } // RunInNewSpan runs f on input in a new span with the provided metadata. @@ -200,6 +205,7 @@ func RunInNewSpan[I, O any]( sm := &spanMetadata{ Name: metadata.Name, Input: input, + Init: metadata.Init, IsRoot: isRoot, Type: metadata.Type, Subtype: metadata.Subtype, @@ -329,6 +335,7 @@ type spanMetadata struct { IsRoot bool IsFailureSource bool // whether this span is the source of a failure Input any + Init any // initialization data for the action, if any Output any Error string // error message if State is spanStateError Path string // annotated path with type information @@ -347,6 +354,10 @@ func (sm *spanMetadata) attributes() []attribute.KeyValue { attribute.String("genkit:path", sm.Path), } + if sm.Init != nil { + kvs = append(kvs, attribute.String("genkit:init", base.JSONString(sm.Init))) + } + if sm.Output != nil { kvs = append(kvs, attribute.String("genkit:output", base.JSONString(sm.Output))) } diff --git a/go/core/tracing/tracing_test.go b/go/core/tracing/tracing_test.go index 6d6f1dd0d3..573afca475 100644 --- a/go/core/tracing/tracing_test.go +++ b/go/core/tracing/tracing_test.go @@ -56,6 +56,28 @@ func TestSpanMetadata(t *testing.T) { } } +func TestSpanMetadataInit(t *testing.T) { + sm := &spanMetadata{ + Name: "name", + State: spanStateSuccess, + Path: "parent/name", + Input: "in", + Init: map[string]string{"prefix": "PFX:"}, + } + + got := sm.attributes() + want := []attribute.KeyValue{ + attribute.String("genkit:name", "name"), + attribute.String("genkit:state", "success"), + attribute.String("genkit:input", `"in"`), + attribute.String("genkit:path", "parent/name"), + attribute.String("genkit:init", `{"prefix":"PFX:"}`), + } + if !slices.Equal(got, want) { + t.Errorf("\ngot %v\nwant %v", got, want) + } +} + func TestSpanMetadataWithTypeAndSubtype(t *testing.T) { const ( testInput = "test input" diff --git a/go/genkit/gen.go b/go/genkit/gen.go index 5cd27b27cb..753a491847 100644 --- a/go/genkit/gen.go +++ b/go/genkit/gen.go @@ -79,6 +79,8 @@ type ReflectionRegisterParams struct { type ReflectionRunActionParams struct { // Additional runtime context data (ex. auth context data). Context json.RawMessage `json:"context,omitempty"` + // Initialization parameters to establish long running session states. + Init json.RawMessage `json:"init,omitempty"` // An input with the type that this action expects. Input json.RawMessage `json:"input,omitempty"` // Action key that consists of the action type and ID. @@ -110,17 +112,18 @@ type ReflectionRunActionStateParamsState struct { } // ReflectionSendInputStreamChunkParams is the payload for the -// "sendInputStreamChunk" notification (bidirectional streaming, not yet implemented). +// "sendInputStreamChunk" notification used to deliver one inbound chunk +// to a running bidirectional action. type ReflectionSendInputStreamChunkParams struct { - Chunk any `json:"chunk,omitempty"` - RequestID string `json:"requestId,omitempty"` + Chunk json.RawMessage `json:"chunk,omitempty"` + RequestID string `json:"requestId,omitempty"` } // ReflectionStreamChunkParams is the payload for the "streamChunk" // notification sent by the runtime during a streaming runAction request. type ReflectionStreamChunkParams struct { // The streamed data chunk. - Chunk any `json:"chunk,omitempty"` + Chunk json.RawMessage `json:"chunk,omitempty"` // ID of the JSON-RPC request this chunk belongs to. RequestID string `json:"requestId,omitempty"` } diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index e6e0e76c4b..8485d44025 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -325,7 +325,7 @@ func RegisterAction(g *Genkit, action api.Registerable) { // // handle error // } // fmt.Println(result) // Output: Hello, World! -func DefineFlow[In, Out any](g *Genkit, name string, fn core.Func[In, Out]) *core.Flow[In, Out, struct{}, struct{}] { +func DefineFlow[In, Out any](g *Genkit, name string, fn core.Func[In, Out]) *core.Flow[In, Out, struct{}] { return core.DefineFlow(g.reg, name, fn) } @@ -376,57 +376,20 @@ func DefineFlow[In, Out any](g *Genkit, name string, fn core.Func[In, Out]) *cor // fmt.Println("Stream Chunk:", result.Stream) // Outputs: 1, 2, 3, 4, 5 // } // } -func DefineStreamingFlow[In, Out, Stream any](g *Genkit, name string, fn core.StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream, struct{}] { +func DefineStreamingFlow[In, Out, Stream any](g *Genkit, name string, fn core.StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream] { return core.DefineStreamingFlow(g.reg, name, fn) } -// DefineBidiFlow defines a bidirectional streaming flow, registers it as a [core.Action] of type Flow, -// and returns a [core.Flow] capable of bidirectional streaming. -// -// The provided function `fn` receives an initial input of type `In`, reads -// incoming stream messages of type `StreamIn` from an input channel, and writes -// outgoing stream messages of type `StreamOut` to an output channel. It returns -// a final output of type `Out` when complete. -// -// Example: -// -// chatFlow := genkit.DefineBidiFlow(g, "chat", -// func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { -// var count int -// for input := range inCh { -// count++ -// outCh <- fmt.Sprintf("reply: %s", input) -// } -// return fmt.Sprintf("processed %d messages", count), nil -// }, -// ) -// -// // Start a bidi connection: -// conn, err := chatFlow.StreamBidi(ctx, struct{}{}) -// if err != nil { -// // handle error -// } -// -// // Send messages concurrently: -// go func() { -// conn.Send("hello") -// conn.Send("world") -// conn.Close() -// }() -// -// // Receive streamed responses: -// for chunk, err := range conn.Receive() { -// if err != nil { -// // handle error -// } -// fmt.Println(chunk) // Outputs: "reply: hello", "reply: world" -// } -// -// // Get the final output: -// output, err := conn.Output() -// fmt.Println(output) // Output: "processed 2 messages" -func DefineBidiFlow[In, Out, StreamOut, StreamIn any](g *Genkit, name string, fn core.BidiFunc[In, Out, StreamOut, StreamIn]) *core.Flow[In, Out, StreamOut, StreamIn] { - return core.DefineBidiFlow(g.reg, name, fn) +// NewFlow creates a [core.Flow] without registering it as an action. +// To register the flow later, call [RegisterAction]. +func NewFlow[In, Out any](name string, fn core.Func[In, Out]) *core.Flow[In, Out, struct{}] { + return core.NewFlow(name, fn) +} + +// NewStreamingFlow creates a streaming [core.Flow] without registering it as an action. +// To register the flow later, call [RegisterAction]. +func NewStreamingFlow[In, Out, Stream any](name string, fn core.StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream] { + return core.NewStreamingFlow(name, fn) } // Run executes the given function `fn` within the context of the current flow run, diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index 915f40b193..9ddff1fbc0 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -37,6 +37,7 @@ import ( "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal" + "github.com/firebase/genkit/go/internal/base" ) type streamingCallback[Stream any] = func(context.Context, Stream) error @@ -349,6 +350,7 @@ func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.Res var body struct { Key string `json:"key"` Input json.RawMessage `json:"input"` + Init json.RawMessage `json:"init"` Context json.RawMessage `json:"context"` TelemetryLabels json.RawMessage `json:"telemetryLabels"` } @@ -431,7 +433,7 @@ func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.Res // Attach telemetry callback to context so action can invoke it when span is created actionCtx = tracing.WithTelemetryCallback(actionCtx, telemetryCb) - resp, err := runAction(actionCtx, g, body.Key, body.Input, body.TelemetryLabels, cb, contextMap) + resp, err := runAction(actionCtx, g, body.Key, body.Input, body.Init, body.TelemetryLabels, cb, contextMap) // Clean up active action using the trace ID from response if resp != nil && resp.Telemetry.TraceID != "" { @@ -710,7 +712,7 @@ type errorResponse struct { Error core.ReflectionError `json:"error"` } -func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage, telemetryLabels json.RawMessage, cb streamingCallback[json.RawMessage], runtimeContext map[string]any) (*runActionResponse, error) { +func runAction(ctx context.Context, g *Genkit, key string, input, init json.RawMessage, telemetryLabels json.RawMessage, cb streamingCallback[json.RawMessage], runtimeContext map[string]any) (*runActionResponse, error) { action := g.reg.ResolveAction(key) if action == nil { return nil, core.NewError(core.NOT_FOUND, "action %q not found", key) @@ -729,7 +731,7 @@ func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage // Run the action and capture trace ID. We need to ensure there's a valid trace context. var traceID string output, err := func() (json.RawMessage, error) { - r, err := action.RunJSONWithTelemetry(ctx, input, cb) + r, err := runActionWithOptionalInit(ctx, action, input, init, cb) if r != nil { traceID = r.TraceId } @@ -750,6 +752,22 @@ func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage }, nil } +// runActionWithOptionalInit runs an action through its JSON surface, +// dispatching to the bidi one-shot path when init carries a value. Init on a +// non-bidi action is rejected with INVALID_ARGUMENT. Shared by the reflection +// servers and the HTTP action handler so the init-acceptance contract stays +// in one place. +func runActionWithOptionalInit(ctx context.Context, a api.Action, input, init json.RawMessage, cb streamingCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { + if base.HasJSONValue(init) { + bidi, ok := a.(api.BidiAction) + if !ok { + return nil, core.NewError(core.INVALID_ARGUMENT, "action %q does not accept init", a.Name()) + } + return bidi.RunBidiJSON(ctx, input, cb, &api.BidiSessionOptions{Init: init}) + } + return a.RunJSONWithTelemetry(ctx, input, cb) +} + // writeJSON writes a JSON-marshaled value to the response writer. func writeJSON(ctx context.Context, w http.ResponseWriter, value any) error { w.Header().Set("Content-Type", "application/json") diff --git a/go/genkit/reflection_test.go b/go/genkit/reflection_test.go index b6a88492cb..83aba3f9d8 100644 --- a/go/genkit/reflection_test.go +++ b/go/genkit/reflection_test.go @@ -409,9 +409,11 @@ func TestEarlyTraceIDTransmission(t *testing.T) { // Backwards compatability t.Run("trace ID in headers matches body", func(t *testing.T) { - // Reset channels for this subtest - actionStarted = make(chan struct{}) - actionCanProceed = make(chan struct{}) + // Fresh channels for this subtest. Reassigning the outer variables + // would race with the first subtest's action closure, which captures + // them by reference and may still be running. + actionStarted := make(chan struct{}) + actionCanProceed := make(chan struct{}) // Re-register action for this subtest core.DefineAction(g.reg, "test/slow2", api.ActionTypeCustom, nil, nil, @@ -649,3 +651,81 @@ func TestCancelActionEndpoint(t *testing.T) { } }) } + +// TestRunActionWithInit verifies that the v1 reflection API forwards the +// request's init payload to the action (matching the JS runtime) and that +// init on an action without init support fails loudly. +func TestRunActionWithInit(t *testing.T) { + g := Init(context.Background()) + + tc := tracing.NewTestOnlyTelemetryClient() + tracing.WriteTelemetryImmediate(tc) + + type initConfig struct { + Prefix string `json:"prefix"` + } + core.DefineBidiAction(g.reg, "test/bidi-prefix", api.ActionTypeCustom, nil, + func(ctx context.Context, cfg initConfig, inCh <-chan string, outCh chan<- string) (string, error) { + var out string + for chunk := range inCh { + out = cfg.Prefix + chunk + } + return out, nil + }) + core.DefineAction(g.reg, "test/no-init", api.ActionTypeCustom, nil, nil, inc) + + s := &reflectionServer{ + Server: &http.Server{}, + activeActions: newActiveActionsMap(), + } + ts := httptest.NewServer(serveMux(g, s)) + s.Addr = strings.TrimPrefix(ts.URL, "http://") + defer ts.Close() + + t.Run("init reaches bidi action", func(t *testing.T) { + body := `{"key": "/custom/test/bidi-prefix", "input": "hello", "init": {"prefix": ">> "}}` + res, err := http.Post(ts.URL+"/api/runAction", "application/json", strings.NewReader(body)) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + b, _ := io.ReadAll(res.Body) + t.Fatalf("status = %d, body: %s", res.StatusCode, b) + } + var resp runActionResponse + if err := json.NewDecoder(res.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + var out string + if err := json.Unmarshal(resp.Result, &out); err != nil { + t.Fatal(err) + } + if out != ">> hello" { + t.Errorf("result = %q, want %q", out, ">> hello") + } + }) + + t.Run("init on non-bidi action is rejected", func(t *testing.T) { + body := `{"key": "/custom/test/no-init", "input": 1, "init": {"x": 1}}` + res, err := http.Post(ts.URL+"/api/runAction", "application/json", strings.NewReader(body)) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + // The v1 handler reports action errors as an error JSON body + // (matching the TS runtime), not via the HTTP status. + var resp struct { + Error *core.ReflectionError `json:"error"` + } + if err := json.NewDecoder(res.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + if resp.Error == nil { + t.Fatal("expected error response for init on non-bidi action") + } + if !strings.Contains(resp.Error.Message, "does not accept init") { + t.Errorf("error message = %q, want mention of init rejection", resp.Error.Message) + } + }) +} diff --git a/go/genkit/reflection_v2.go b/go/genkit/reflection_v2.go index 1b20386c28..158dc47134 100644 --- a/go/genkit/reflection_v2.go +++ b/go/genkit/reflection_v2.go @@ -31,6 +31,7 @@ import ( "github.com/coder/websocket" "github.com/coder/websocket/wsjson" "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal" ) @@ -120,6 +121,13 @@ type reflectionServerV2 struct { pendingMu sync.Mutex pending map[string]chan pendingResponse requestSeq atomic.Uint64 + + // bidiSessions tracks in-flight bidi runAction calls so that + // sendInputStreamChunk/endInputStream notifications can be routed to them. + // Sessions are pre-registered in readLoop so that chunks arriving before + // the action has finished initializing are buffered rather than dropped. + bidiMu sync.Mutex + bidiSessions map[string]*bidiSession } // reflectionServerV2Options configures the V2 reflection client. @@ -148,6 +156,7 @@ func startReflectionServerV2(ctx context.Context, g *Genkit, opts reflectionServ activeActions: newActiveActionsMap(), runtimeID: runtimeID, pending: map[string]chan pendingResponse{}, + bidiSessions: map[string]*bidiSession{}, } // Initial connect so startup errors surface via errCh. Reconnects after @@ -169,7 +178,12 @@ func (s *reflectionServerV2) connect(ctx context.Context) error { if err != nil { return err } + // Handler goroutines (and bidi runs surviving a disconnect) read s.conn + // through send while the session goroutine reconnects; synchronize the + // swap with the same lock send holds. + s.writeMu.Lock() s.conn = conn + s.writeMu.Unlock() slog.Debug("reflection V2: connected", "url", s.opts.URL) return nil } @@ -252,6 +266,10 @@ func (s *reflectionServerV2) register(ctx context.Context) { // readLoop reads and dispatches JSON-RPC messages until the context is // cancelled or the connection is closed. func (s *reflectionServerV2) readLoop(ctx context.Context) { + // If the client disconnects mid-stream without sending endInputStream, + // end the input of any in-flight bidi sessions so action bodies awaiting + // input can terminate instead of hanging (matching the JS runtime). + defer s.closeBidiSessions() for { var msg jsonRPCMessage if err := wsjson.Read(ctx, s.conn, &msg); err != nil { @@ -264,7 +282,32 @@ func (s *reflectionServerV2) readLoop(ctx context.Context) { continue } if msg.Method != "" { - go s.handleRequest(ctx, &msg) + switch msg.Method { + case "runAction": + // Pre-register a bidi session before dispatching the handler + // so that later sendInputStreamChunk / endInputStream + // notifications (which run in their own goroutines) always + // find it. + if msg.ID != "" { + var peek struct { + StreamInput bool `json:"streamInput"` + } + if len(msg.Params) > 0 { + _ = json.Unmarshal(msg.Params, &peek) + } + if peek.StreamInput { + s.registerBidiSession(msg.ID, newBidiSession()) + } + } + go s.handleRequest(ctx, &msg) + case "sendInputStreamChunk", "endInputStream": + // Dispatch synchronously to preserve wire ordering when + // enqueueing onto the per-session event queue. Enqueueing + // never blocks, so this cannot stall the read loop. + s.handleRequest(ctx, &msg) + default: + go s.handleRequest(ctx, &msg) + } } else if msg.ID != "" { s.deliverResponse(&msg) } @@ -287,9 +330,10 @@ func (s *reflectionServerV2) handleRequest(ctx context.Context, req *jsonRPCMess s.handleCancelAction(req) case "configure": s.handleConfigure(req) - case "sendInputStreamChunk", "endInputStream": - // Bidirectional input streaming is not yet implemented. - slog.Debug("reflection V2: method not implemented", "method", req.Method) + case "sendInputStreamChunk": + s.handleSendInputStreamChunk(req) + case "endInputStream": + s.handleEndInputStream(req) default: if req.ID != "" { s.sendErrorResponse(req.ID, jsonRPCMethodNotFound, "method not found: "+req.Method, nil) @@ -340,44 +384,36 @@ func (s *reflectionServerV2) handleRunAction(ctx context.Context, req *jsonRPCMe if req.ID == "" { return } + // Owns cleanup for any bidi session pre-registered by readLoop, including + // early-return paths. The session is captured up front and unregistration + // is ownership-checked: a long-lived handler must not tear down a newer + // session registered under a reused request id (the manager's id counter + // restarts when the CLI restarts and reconnects). + session := s.lookupBidiSession(req.ID) + defer s.unregisterBidiSession(req.ID, session) + var params ReflectionRunActionParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { s.sendErrorResponse(req.ID, jsonRPCInvalidParams, "invalid params: "+err.Error(), nil) return } - slog.Debug("reflection V2: running action", "key", params.Key, "stream", params.Stream) + slog.Debug("reflection V2: running action", "key", params.Key, "stream", params.Stream, "streamInput", params.StreamInput) + + if params.StreamInput { + s.handleRunActionBidi(ctx, req, ¶ms, session) + return + } actionCtx, cancel := context.WithCancel(ctx) defer cancel() - var traceIDMu sync.Mutex - var traceID string - - telemetryCb := func(tid, _ string) { - traceIDMu.Lock() - traceID = tid - traceIDMu.Unlock() - - s.activeActions.Set(tid, &activeAction{ - cancel: cancel, - startTime: time.Now(), - traceID: tid, - }) - - s.sendNotification("runActionState", &ReflectionRunActionStateParams{ - RequestID: req.ID, - State: &ReflectionRunActionStateParamsState{TraceID: tid}, - }) - } + rt := s.newRunActionTelemetry(req.ID, cancel) var streamCb streamingCallback[json.RawMessage] if params.Stream { streamCb = func(_ context.Context, chunk json.RawMessage) error { - return s.sendNotification("streamChunk", &ReflectionStreamChunkParams{ - RequestID: req.ID, - Chunk: chunk, - }) + return s.sendStreamChunk(req.ID, chunk) } } @@ -389,15 +425,10 @@ func (s *reflectionServerV2) handleRunAction(ctx context.Context, req *jsonRPCMe } } - actionCtx = tracing.WithTelemetryCallback(actionCtx, telemetryCb) - resp, err := runAction(actionCtx, s.g, params.Key, params.Input, params.TelemetryLabels, streamCb, contextMap) + actionCtx = tracing.WithTelemetryCallback(actionCtx, rt.callback) + resp, err := runAction(actionCtx, s.g, params.Key, params.Input, params.Init, params.TelemetryLabels, streamCb, contextMap) - traceIDMu.Lock() - capturedTraceID := traceID - traceIDMu.Unlock() - if capturedTraceID != "" { - s.activeActions.Delete(capturedTraceID) - } + capturedTraceID := rt.finish() if err != nil { s.sendRunActionError(req.ID, err, capturedTraceID) @@ -410,6 +441,340 @@ func (s *reflectionServerV2) handleRunAction(ctx context.Context, req *jsonRPCMe }) } +// runActionTelemetry tracks the trace ID for one in-flight runAction request: +// the span-start callback records the id, registers the run as a cancellable +// active action, and notifies the client. finish marks the request complete +// and returns the captured trace ID; callbacks firing after finish are +// ignored, since a bidi session's span is created asynchronously and can +// start after the handler has already responded. +type runActionTelemetry struct { + s *reflectionServerV2 + reqID string + cancel context.CancelFunc + + mu sync.Mutex + traceID string + finished bool +} + +func (s *reflectionServerV2) newRunActionTelemetry(reqID string, cancel context.CancelFunc) *runActionTelemetry { + return &runActionTelemetry{s: s, reqID: reqID, cancel: cancel} +} + +// callback is the tracing telemetry callback for the run's root span. +func (rt *runActionTelemetry) callback(tid, _ string) { + rt.mu.Lock() + if rt.finished { + rt.mu.Unlock() + return + } + rt.traceID = tid + // Registered under the lock so finish either sees the id (and deletes + // the entry) or this callback sees finished (and never registers). + rt.s.activeActions.Set(tid, &activeAction{ + cancel: rt.cancel, + startTime: time.Now(), + traceID: tid, + }) + rt.mu.Unlock() + + rt.s.sendNotification("runActionState", &ReflectionRunActionStateParams{ + RequestID: rt.reqID, + State: &ReflectionRunActionStateParamsState{TraceID: tid}, + }) +} + +// finish marks the request complete, removes the active-action entry, and +// returns the trace ID captured for the run (empty if the span never started). +func (rt *runActionTelemetry) finish() string { + rt.mu.Lock() + rt.finished = true + tid := rt.traceID + rt.mu.Unlock() + if tid != "" { + rt.s.activeActions.Delete(tid) + } + return tid +} + +// sendStreamChunk forwards one output chunk to the client as a streamChunk +// notification. +func (s *reflectionServerV2) sendStreamChunk(requestID string, chunk json.RawMessage) error { + return s.sendNotification("streamChunk", &ReflectionStreamChunkParams{ + RequestID: requestID, + Chunk: chunk, + }) +} + +// handleRunActionBidi handles a runAction request with streamInput=true. +// It resolves the action as a bidi action, wires its input/output streams to +// the JSON-RPC connection, and waits for the final result. The session has +// already been pre-registered by readLoop. +func (s *reflectionServerV2) handleRunActionBidi(ctx context.Context, req *jsonRPCMessage, params *ReflectionRunActionParams, session *bidiSession) { + if session == nil { + // readLoop pre-registers a session for every streamInput run, so a + // missing one means it was already torn down (e.g. a reused request + // id); fail loud rather than dereference it below. + s.sendErrorResponse(req.ID, jsonRPCServerError, "bidi session unavailable for request "+req.ID, nil) + return + } + + action := s.g.reg.ResolveAction(params.Key) + if action == nil { + s.sendErrorResponse(req.ID, jsonRPCInvalidParams, fmt.Sprintf("action not found: %s", params.Key), nil) + return + } + bidi, ok := action.(api.BidiAction) + if !ok { + s.sendErrorResponse(req.ID, jsonRPCInvalidParams, fmt.Sprintf("action %s does not support bidirectional streaming", params.Key), nil) + return + } + + actionCtx, cancel := context.WithCancel(ctx) + defer cancel() + + rt := s.newRunActionTelemetry(req.ID, cancel) + actionCtx = tracing.WithTelemetryCallback(actionCtx, rt.callback) + + if params.Context != nil { + contextMap := core.ActionContext{} + if err := json.Unmarshal(params.Context, &contextMap); err != nil { + s.sendErrorResponse(req.ID, jsonRPCInvalidParams, "invalid context: "+err.Error(), nil) + return + } + actionCtx = core.WithActionContext(actionCtx, contextMap) + } + + conn, err := bidi.StreamBidiJSON(actionCtx, &api.BidiSessionOptions{Init: params.Init}) + if err != nil { + s.sendErrorResponse(req.ID, jsonRPCServerError, err.Error(), nil) + return + } + + // Start consuming outgoing stream chunks before replaying buffered + // inputs, so the action's outbound channel can drain while we feed + // inbound chunks. Chunks are forwarded to the client only when output + // streaming was requested (matching the JS runtime), but the stream is + // always drained since the connection applies backpressure to the action. + forwardDone := make(chan struct{}) + go func() { + defer close(forwardDone) + forward := params.Stream + for chunk, rerr := range conn.Receive() { + if rerr != nil { + return + } + if !forward { + continue + } + if err := s.sendStreamChunk(req.ID, chunk); err != nil { + slog.Debug("reflection V2: streamChunk send failed", "err", err) + forward = false + } + } + }() + + // Hand the real connection to the pre-registered session. Any chunks + // that arrived while we were resolving the action are replayed now. + go session.run(conn) + + output, runErr := conn.Output() + <-forwardDone + + capturedTraceID := rt.finish() + + if runErr != nil { + s.sendRunActionError(req.ID, runErr, capturedTraceID) + return + } + + s.sendResponse(req.ID, &reflectionRunActionResponse{ + Result: output, + Telemetry: telemetry{TraceID: capturedTraceID}, + }) +} + +// handleSendInputStreamChunk routes an inbound chunk to the bidi session +// identified by RequestID. If the session has not yet attached its underlying +// connection, the chunk is buffered and replayed when it does. +func (s *reflectionServerV2) handleSendInputStreamChunk(req *jsonRPCMessage) { + var params ReflectionSendInputStreamChunkParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + slog.Debug("reflection V2: invalid sendInputStreamChunk params", "err", err) + return + } + session := s.lookupBidiSession(params.RequestID) + if session == nil { + slog.Debug("reflection V2: sendInputStreamChunk for unknown session", "requestId", params.RequestID) + return + } + session.Send(params.Chunk) +} + +// handleEndInputStream closes the input stream of the bidi session identified +// by RequestID. If the session has not yet attached its connection, the end +// signal is buffered. +func (s *reflectionServerV2) handleEndInputStream(req *jsonRPCMessage) { + var params ReflectionEndInputStreamParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + slog.Debug("reflection V2: invalid endInputStream params", "err", err) + return + } + session := s.lookupBidiSession(params.RequestID) + if session == nil { + slog.Debug("reflection V2: endInputStream for unknown session", "requestId", params.RequestID) + return + } + session.Close() +} + +func (s *reflectionServerV2) registerBidiSession(id string, session *bidiSession) { + s.bidiMu.Lock() + old := s.bidiSessions[id] + s.bidiSessions[id] = session + s.bidiMu.Unlock() + // A reused request id orphans the previous session; stop its worker so + // it cannot linger blocked on an empty queue. + if old != nil { + old.stop() + } +} + +// closeBidiSessions enqueues an end-of-input marker for every in-flight bidi +// session. Used when the connection drops, since the client can no longer +// send endInputStream; ending the input lets actions finish gracefully. +func (s *reflectionServerV2) closeBidiSessions() { + s.bidiMu.Lock() + sessions := make([]*bidiSession, 0, len(s.bidiSessions)) + for _, session := range s.bidiSessions { + sessions = append(sessions, session) + } + s.bidiMu.Unlock() + for _, session := range sessions { + session.Close() + } +} + +// unregisterBidiSession removes session from the map, unless a newer session +// has been registered under the same id, and stops its worker goroutine. A +// nil session is a no-op. Safe to call multiple times for the same session. +func (s *reflectionServerV2) unregisterBidiSession(id string, session *bidiSession) { + if session == nil { + return + } + s.bidiMu.Lock() + if s.bidiSessions[id] == session { + delete(s.bidiSessions, id) + } + s.bidiMu.Unlock() + session.stop() +} + +func (s *reflectionServerV2) lookupBidiSession(id string) *bidiSession { + s.bidiMu.Lock() + defer s.bidiMu.Unlock() + return s.bidiSessions[id] +} + +// bidiSession is the runtime-side handle for an in-flight bidi runAction call. +// All input events (chunks plus a terminating close) are queued in arrival +// order from the read loop. The session is pre-registered before the action +// starts initializing, so events that arrive early simply accumulate; once +// the handler has created the underlying api.BidiJSONConnection it starts a +// run worker goroutine that drains the queue into the connection in order. +// +// The queue is unbounded so that enqueueing never blocks the WebSocket read +// loop, which must stay responsive to process cancelAction (the escape hatch +// for a stuck action). Memory is bounded by the client, which is the trusted +// dev tooling; the JS runtime makes the same trade. +type bidiSession struct { + mu sync.Mutex + cond *sync.Cond + events []bidiEvent + stopped bool +} + +// bidiEvent is one item delivered to the worker. close=true is a terminal +// marker; any chunks queued after it are dropped. +type bidiEvent struct { + chunk json.RawMessage + close bool +} + +func newBidiSession() *bidiSession { + s := &bidiSession{} + s.cond = sync.NewCond(&s.mu) + return s +} + +// run forwards queued events to the connection in order. Returns after a +// close event or when the session is stopped. +func (s *bidiSession) run(conn api.BidiJSONConnection) { + for { + ev, ok := s.next() + if !ok { + return + } + if ev.close { + _ = conn.Close() + return + } + if err := conn.Send(ev.chunk); err != nil { + slog.Debug("reflection V2: bidi Send failed", "err", err) + } + } +} + +// next blocks until an event is queued or the session is stopped. Returns +// ok=false when stopped; events still queued at that point are dropped, as +// the session is being torn down. +func (s *bidiSession) next() (bidiEvent, bool) { + s.mu.Lock() + defer s.mu.Unlock() + for len(s.events) == 0 && !s.stopped { + s.cond.Wait() + } + if s.stopped { + return bidiEvent{}, false + } + ev := s.events[0] + s.events[0] = bidiEvent{} // Release the chunk for GC. + s.events = s.events[1:] + if len(s.events) == 0 { + s.events = nil // Release the backing array. + } + return ev, true +} + +// Send enqueues a chunk for delivery to the action. It never blocks. +func (s *bidiSession) Send(chunk json.RawMessage) { + s.enqueue(bidiEvent{chunk: chunk}) +} + +// Close enqueues a terminal end-of-input marker. It never blocks. +func (s *bidiSession) Close() { + s.enqueue(bidiEvent{close: true}) +} + +func (s *bidiSession) enqueue(ev bidiEvent) { + s.mu.Lock() + defer s.mu.Unlock() + if s.stopped { + return + } + s.events = append(s.events, ev) + s.cond.Signal() +} + +// stop terminates the worker and drops any queued events. Safe to call +// multiple times. +func (s *bidiSession) stop() { + s.mu.Lock() + defer s.mu.Unlock() + s.stopped = true + s.cond.Broadcast() +} + // sendRunActionError maps a runAction error to a JSON-RPC error response // with a Status-shaped data field matching the JS implementation. func (s *reflectionServerV2) sendRunActionError(id string, err error, traceID string) { diff --git a/go/genkit/reflection_v2_test.go b/go/genkit/reflection_v2_test.go index d94c9f1164..468c89f57c 100644 --- a/go/genkit/reflection_v2_test.go +++ b/go/genkit/reflection_v2_test.go @@ -21,6 +21,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "slices" "strings" "sync" "testing" @@ -541,6 +542,370 @@ func TestReflectionServerV2_CancelAction(t *testing.T) { } } +func TestReflectionServerV2_BidiRunAction(t *testing.T) { + m := newFakeManager(t) + defer m.close() + + g := Init(context.Background()) + + type initConfig struct { + Prefix string `json:"prefix"` + } + + core.DefineBidiAction(g.reg, "test/bidi-echo", api.ActionTypeCustom, nil, + func(ctx context.Context, cfg initConfig, inCh <-chan string, outCh chan<- string) (string, error) { + var n int + for chunk := range inCh { + n++ + outCh <- cfg.Prefix + chunk + } + return "processed", nil + }) + + ctx, cancel := startRuntime(t, g, m) + defer cancel() + + conn := m.waitForConnection(t) + m.ackRegister(t, ctx, conn) + + // Start the bidi run. `init` carries the per-session configuration; + // streamInput signals that chunks will be sent via sendInputStreamChunk; + // stream requests that output chunks be forwarded back (chunks are not + // forwarded without it, matching the JS runtime). + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "runAction", + "params": map[string]any{ + "key": "/custom/test/bidi-echo", + "init": map[string]any{"prefix": "> "}, + "stream": true, + "streamInput": true, + }, + "id": "bidi-1", + }) + + // Send two chunks and end the input stream. + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "sendInputStreamChunk", + "params": map[string]any{"requestId": "bidi-1", "chunk": "hello"}, + }) + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "sendInputStreamChunk", + "params": map[string]any{"requestId": "bidi-1", "chunk": "world"}, + }) + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "endInputStream", + "params": map[string]any{"requestId": "bidi-1"}, + }) + + var chunks []string + var final map[string]any + deadline := time.After(5 * time.Second) +loop: + for { + select { + case <-deadline: + t.Fatalf("timed out; chunks=%v final=%v", chunks, final) + default: + } + msg := m.read(t, ctx, conn) + switch msg["method"] { + case "streamChunk": + params := msg["params"].(map[string]any) + if params["requestId"] != "bidi-1" { + t.Errorf("streamChunk requestId = %v, want bidi-1", params["requestId"]) + } + chunks = append(chunks, params["chunk"].(string)) + case "runActionState": + // early trace-id notification; ignore + default: + final = msg + break loop + } + } + + if got, want := chunks, []string{"> hello", "> world"}; !slices.Equal(got, want) { + t.Errorf("chunks = %v, want %v", got, want) + } + result, ok := final["result"].(map[string]any) + if !ok { + t.Fatalf("expected result object, got %v", final) + } + var out string + if err := json.Unmarshal([]byte(toJSON(t, result["result"])), &out); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + if out != "processed" { + t.Errorf("result = %q, want %q", out, "processed") + } +} + +func toJSON(t *testing.T, v any) string { + t.Helper() + b, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshal: %v", err) + } + return string(b) +} + +// TestReflectionServerV2_BidiRunActionDropsAfterEnd verifies that chunks +// arriving after endInputStream are not delivered to the action, even when +// they queue up before the underlying connection has been attached. +func TestReflectionServerV2_BidiRunActionDropsAfterEnd(t *testing.T) { + m := newFakeManager(t) + defer m.close() + + g := Init(context.Background()) + + // The action records every chunk it sees so the test can assert that + // chunks after endInputStream were dropped. + var seenMu sync.Mutex + var seen []string + + core.DefineBidiAction(g.reg, "test/bidi-record", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for chunk := range inCh { + seenMu.Lock() + seen = append(seen, chunk) + seenMu.Unlock() + } + return "done", nil + }) + + ctx, cancel := startRuntime(t, g, m) + defer cancel() + + conn := m.waitForConnection(t) + m.ackRegister(t, ctx, conn) + + // Pipeline runAction + chunks + end + extra chunk back-to-back. All + // arrive before the handler goroutine is likely to have started the + // session worker, so they queue inside the bidi session. The "extra" + // chunk is enqueued after the close marker and must be dropped by the + // worker. + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "runAction", + "params": map[string]any{ + "key": "/custom/test/bidi-record", + "streamInput": true, + }, + "id": "drop-1", + }) + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "sendInputStreamChunk", + "params": map[string]any{"requestId": "drop-1", "chunk": "a"}, + }) + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "sendInputStreamChunk", + "params": map[string]any{"requestId": "drop-1", "chunk": "b"}, + }) + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "endInputStream", + "params": map[string]any{"requestId": "drop-1"}, + }) + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "sendInputStreamChunk", + "params": map[string]any{"requestId": "drop-1", "chunk": "after-end"}, + }) + + // Drain notifications until the final response. + deadline := time.After(5 * time.Second) +loop: + for { + select { + case <-deadline: + t.Fatalf("timed out waiting for response") + default: + } + msg := m.read(t, ctx, conn) + if _, hasResult := msg["result"]; hasResult { + break loop + } + if _, hasErr := msg["error"]; hasErr { + t.Fatalf("unexpected error response: %v", msg) + } + } + + seenMu.Lock() + defer seenMu.Unlock() + if want := []string{"a", "b"}; !slices.Equal(seen, want) { + t.Errorf("action received %v, want %v (chunk after endInputStream should be dropped)", seen, want) + } +} + +// TestReflectionServerV2_BidiRunActionErrors verifies that an error returned +// from a bidi action surfaces as a JSON-RPC error response. +func TestReflectionServerV2_BidiRunActionErrors(t *testing.T) { + m := newFakeManager(t) + defer m.close() + + g := Init(context.Background()) + + core.DefineBidiAction(g.reg, "test/bidi-fail", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return "", core.NewError(core.INVALID_ARGUMENT, "boom") + }) + + ctx, cancel := startRuntime(t, g, m) + defer cancel() + + conn := m.waitForConnection(t) + m.ackRegister(t, ctx, conn) + + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "runAction", + "params": map[string]any{ + "key": "/custom/test/bidi-fail", + "streamInput": true, + }, + "id": "err-1", + }) + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "endInputStream", + "params": map[string]any{"requestId": "err-1"}, + }) + + deadline := time.After(5 * time.Second) + for { + select { + case <-deadline: + t.Fatalf("timed out waiting for error response") + default: + } + msg := m.read(t, ctx, conn) + if _, hasResult := msg["result"]; hasResult { + t.Fatalf("expected error, got result: %v", msg) + } + errObj, ok := msg["error"].(map[string]any) + if !ok { + continue // ignore notifications (e.g., runActionState) + } + if !strings.Contains(errObj["message"].(string), "boom") { + t.Errorf("error message = %q, want substring %q", errObj["message"], "boom") + } + return + } +} + +// TestReflectionServerV2_BidiInvalidChunkFailsRun verifies that a chunk +// failing input-schema validation fails the whole run with the validation +// error (matching the JS runtime) rather than being silently dropped. +func TestReflectionServerV2_BidiInvalidChunkFailsRun(t *testing.T) { + m := newFakeManager(t) + defer m.close() + + g := Init(context.Background()) + + core.DefineBidiAction(g.reg, "test/bidi-strict", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for { + select { + case _, ok := <-inCh: + if !ok { + return "done", nil + } + case <-ctx.Done(): + return "", ctx.Err() + } + } + }) + + ctx, cancel := startRuntime(t, g, m) + defer cancel() + + conn := m.waitForConnection(t) + m.ackRegister(t, ctx, conn) + + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "runAction", + "params": map[string]any{ + "key": "/custom/test/bidi-strict", + "streamInput": true, + }, + "id": "invalid-chunk-1", + }) + // A number where the action expects a string: fails validation and must + // fail the run. + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "sendInputStreamChunk", + "params": map[string]any{"requestId": "invalid-chunk-1", "chunk": 123}, + }) + + deadline := time.After(5 * time.Second) + for { + select { + case <-deadline: + t.Fatal("timed out waiting for error response") + default: + } + msg := m.read(t, ctx, conn) + if _, hasResult := msg["result"]; hasResult { + t.Fatalf("expected error, got result: %v", msg) + } + errObj, ok := msg["error"].(map[string]any) + if !ok { + continue // ignore notifications (e.g., runActionState) + } + if !strings.Contains(errObj["message"].(string), "invalid stream chunk") { + t.Errorf("error message = %q, want substring %q", errObj["message"], "invalid stream chunk") + } + return + } +} + +// TestReflectionServerV2_BidiSessionOwnership verifies the ownership rules +// that protect bidi sessions against request-id reuse: a reused id stops the +// orphaned session, a stale handler's unregister cannot tear down a newer +// session under the same id, and the owner's unregister removes and stops its +// own session. +func TestReflectionServerV2_BidiSessionOwnership(t *testing.T) { + s := &reflectionServerV2{bidiSessions: map[string]*bidiSession{}} + + s1 := newBidiSession() + s.registerBidiSession("1", s1) + + // A reused id orphans and stops the previous session. + s2 := newBidiSession() + s.registerBidiSession("1", s2) + if _, ok := s1.next(); ok { + t.Error("orphaned session worker not stopped on id reuse") + } + + // A stale handler unregistering with its old session must not touch the + // newer one. + s.unregisterBidiSession("1", s1) + if got := s.lookupBidiSession("1"); got != s2 { + t.Errorf("stale unregister removed the new session; lookup = %v, want s2", got) + } + + // The owner's unregister removes and stops its session. + s.unregisterBidiSession("1", s2) + if got := s.lookupBidiSession("1"); got != nil { + t.Error("session still registered after owner unregister") + } + if _, ok := s2.next(); ok { + t.Error("session worker not stopped after owner unregister") + } + + // A nil session is a no-op. + s.unregisterBidiSession("1", nil) +} + func TestReflectionServerV2_MethodNotFound(t *testing.T) { m := newFakeManager(t) defer m.close() @@ -567,3 +932,132 @@ func TestReflectionServerV2_MethodNotFound(t *testing.T) { t.Errorf("code = %v, want %d", code, jsonRPCMethodNotFound) } } + +// TestReflectionServerV2_BidiRunActionNoStream verifies that output chunks +// are not forwarded when the request does not ask for output streaming +// (matching the JS runtime), while the final result still arrives. +func TestReflectionServerV2_BidiRunActionNoStream(t *testing.T) { + m := newFakeManager(t) + defer m.close() + + g := Init(context.Background()) + + core.DefineBidiAction(g.reg, "test/bidi-quiet", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + var last string + for chunk := range inCh { + last = chunk + outCh <- chunk + } + return "got " + last, nil + }) + + ctx, cancel := startRuntime(t, g, m) + defer cancel() + + conn := m.waitForConnection(t) + m.ackRegister(t, ctx, conn) + + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "runAction", + "params": map[string]any{ + "key": "/custom/test/bidi-quiet", + "streamInput": true, + }, + "id": "quiet-1", + }) + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "sendInputStreamChunk", + "params": map[string]any{"requestId": "quiet-1", "chunk": "hello"}, + }) + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "endInputStream", + "params": map[string]any{"requestId": "quiet-1"}, + }) + + var final map[string]any + deadline := time.After(5 * time.Second) +loop: + for { + select { + case <-deadline: + t.Fatal("timed out waiting for final response") + default: + } + msg := m.read(t, ctx, conn) + switch msg["method"] { + case "streamChunk": + t.Errorf("unexpected streamChunk without stream=true: %v", msg) + case "runActionState": + // early trace-id notification; ignore + default: + final = msg + break loop + } + } + + result, ok := final["result"].(map[string]any) + if !ok { + t.Fatalf("expected result object, got %v", final) + } + var out string + if err := json.Unmarshal([]byte(toJSON(t, result["result"])), &out); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + if out != "got hello" { + t.Errorf("result = %q, want %q", out, "got hello") + } +} + +// TestReflectionServerV2_BidiInputClosedOnDisconnect verifies that when the +// manager connection drops mid-stream, in-flight bidi actions see their input +// stream end instead of hanging forever awaiting chunks. +func TestReflectionServerV2_BidiInputClosedOnDisconnect(t *testing.T) { + m := newFakeManager(t) + defer m.close() + + g := Init(context.Background()) + + inputEnded := make(chan struct{}) + core.DefineBidiAction(g.reg, "test/bidi-hang", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + close(inputEnded) + return "done", nil + }) + + ctx, cancel := startRuntime(t, g, m) + defer cancel() + + conn := m.waitForConnection(t) + m.ackRegister(t, ctx, conn) + + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "runAction", + "params": map[string]any{ + "key": "/custom/test/bidi-hang", + "streamInput": true, + }, + "id": "hang-1", + }) + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "sendInputStreamChunk", + "params": map[string]any{"requestId": "hang-1", "chunk": "hello"}, + }) + + // Drop the connection without sending endInputStream. The runtime must + // end the action's input stream so it can finish. + m.close() + + select { + case <-inputEnded: + case <-time.After(5 * time.Second): + t.Fatal("bidi action input stream was not closed on disconnect") + } +} diff --git a/go/genkit/servers.go b/go/genkit/servers.go index 99090cd4c7..dbaad093bf 100644 --- a/go/genkit/servers.go +++ b/go/genkit/servers.go @@ -166,6 +166,7 @@ func handler(a api.Action, opts *handlerOptions) func(http.ResponseWriter, *http var body struct { Data json.RawMessage `json:"data"` + Init json.RawMessage `json:"init,omitempty"` // Per-session init for bidi actions; rejected otherwise. } if r.Body != nil && r.ContentLength > 0 { defer r.Body.Close() @@ -174,6 +175,14 @@ func handler(a api.Action, opts *handlerOptions) func(http.ResponseWriter, *http } } + run := func(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) { + r, err := runActionWithOptionalInit(ctx, a, input, body.Init, cb) + if err != nil { + return nil, err + } + return r.Result, nil + } + stream, err := parseBoolQueryParam(r, "stream") if err != nil { return err @@ -219,14 +228,14 @@ func handler(a api.Action, opts *handlerOptions) func(http.ResponseWriter, *http w.Header().Set("Transfer-Encoding", "chunked") if opts.StreamManager != nil { - return runWithDurableStreaming(ctx, w, a, opts.StreamManager, body.Data) + return runWithDurableStreaming(ctx, w, run, opts.StreamManager, body.Data) } - return runWithStreaming(ctx, w, a, body.Data) + return runWithStreaming(ctx, w, run, body.Data) } w.Header().Set("Content-Type", "application/json") - out, err := a.RunJSON(ctx, body.Data, nil) + out, err := run(ctx, body.Data, nil) if err != nil { return err } @@ -234,8 +243,12 @@ func handler(a api.Action, opts *handlerOptions) func(http.ResponseWriter, *http } } +// runJSONFunc abstracts over RunJSON and RunBidiJSON for the handler's +// execution paths. +type runJSONFunc = func(context.Context, json.RawMessage, func(context.Context, json.RawMessage) error) (json.RawMessage, error) + // runWithStreaming executes the action with standard HTTP streaming (no durability). -func runWithStreaming(ctx context.Context, w http.ResponseWriter, a api.Action, input json.RawMessage) error { +func runWithStreaming(ctx context.Context, w http.ResponseWriter, run runJSONFunc, input json.RawMessage) error { callback := func(ctx context.Context, msg json.RawMessage) error { if err := writeSSEMessage(w, msg); err != nil { return err @@ -246,7 +259,7 @@ func runWithStreaming(ctx context.Context, w http.ResponseWriter, a api.Action, return nil } - out, err := a.RunJSON(ctx, input, callback) + out, err := run(ctx, input, callback) if err != nil { if werr := writeSSEError(w, err); werr != nil { return werr @@ -263,7 +276,7 @@ func runWithStreaming(ctx context.Context, w http.ResponseWriter, a api.Action, // original client disconnects, the flow continues running and writing to durable // storage. This allows other clients to subscribe to the stream and receive the // remaining chunks and final result. -func runWithDurableStreaming(ctx context.Context, w http.ResponseWriter, a api.Action, sm streaming.StreamManager, input json.RawMessage) error { +func runWithDurableStreaming(ctx context.Context, w http.ResponseWriter, run runJSONFunc, sm streaming.StreamManager, input json.RawMessage) error { streamID := uuid.New().String() durableStream, err := sm.Open(ctx, streamID) @@ -301,7 +314,7 @@ func runWithDurableStreaming(ctx context.Context, w http.ResponseWriter, a api.A return nil } - out, err := a.RunJSON(durableCtx, input, callback) + out, err := run(durableCtx, input, callback) if err != nil { durableStream.Error(durableCtx, err) select { diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index 068dc49dbf..a18a109a40 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -27,6 +27,7 @@ import ( "testing" "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/core/x/streaming" ) @@ -556,3 +557,127 @@ data: {"result":"ab-done"} } }) } + +// TestHandlerBidiInitEnvelope verifies that an HTTP POST to a bidi action +// handler can supply Init via the request envelope's "init" field, alongside +// the existing "data" field. This is the production HTTP path for bidi +// actions invoked as one-shots. +func TestHandlerBidiInitEnvelope(t *testing.T) { + g := Init(context.Background()) + + type Config struct { + Prefix string `json:"prefix"` + } + + bidiAction := core.DefineBidiAction(g.reg, "envelopeBidi", api.ActionTypeCustom, nil, + func(ctx context.Context, cfg Config, inCh <-chan string, outCh chan<- string) (string, error) { + for in := range inCh { + outCh <- cfg.Prefix + in + } + return "done", nil + }) + + t.Run("non-streaming envelope with init", func(t *testing.T) { + handler := Handler(bidiAction) + + body := `{"data":"hello","init":{"prefix":">> "}}` + req := httptest.NewRequest("POST", "/", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler(w, req) + + resp := w.Result() + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, body = %s", resp.StatusCode, string(respBody)) + } + // Result should be "done"; the prefixed chunk goes to the streaming + // callback (nil here), so the final output is the function's return. + if !strings.Contains(string(respBody), `"done"`) { + t.Errorf("response body = %q, want it to contain \"done\"", string(respBody)) + } + }) + + t.Run("streaming envelope with init delivers prefixed chunk", func(t *testing.T) { + handler := HandlerFunc(bidiAction) + + // Use an HTML-safe prefix so json.Marshal doesn't escape it; that + // way the assertion can match the prefix literally in the SSE body. + body := `{"data":"hello","init":{"prefix":"PFX:"}}` + req := httptest.NewRequest("POST", "/", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + w := httptest.NewRecorder() + + if err := handler(w, req); err != nil { + t.Fatalf("handler: %v", err) + } + + respBody, _ := io.ReadAll(w.Result().Body) + if !strings.Contains(string(respBody), "PFX:hello") { + t.Errorf("response missing prefixed chunk; body = %q", string(respBody)) + } + if !strings.Contains(string(respBody), `"done"`) { + t.Errorf("response missing final result; body = %q", string(respBody)) + } + }) + + t.Run("envelope without init uses zero value", func(t *testing.T) { + handler := Handler(bidiAction) + + body := `{"data":"hello"}` + req := httptest.NewRequest("POST", "/", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d, body = %s", resp.StatusCode, string(respBody)) + } + }) + + t.Run("envelope with malformed init returns 400", func(t *testing.T) { + handler := Handler(bidiAction) + + // Init is valid JSON but doesn't match the action's Config (prefix + // must be a string; here it's a number). + body := `{"data":"hello","init":{"prefix":42}}` + req := httptest.NewRequest("POST", "/", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusBadRequest { + respBody, _ := io.ReadAll(resp.Body) + t.Errorf("status = %d, want %d; body = %s", resp.StatusCode, http.StatusBadRequest, string(respBody)) + } + }) + + t.Run("init on non-bidi flow returns 400", func(t *testing.T) { + plainFlow := DefineFlow(g, "envelopePlain", + func(ctx context.Context, in string) (string, error) { + return "out:" + in, nil + }) + handler := Handler(plainFlow) + + body := `{"data":"hello","init":{"prefix":">> "}}` + req := httptest.NewRequest("POST", "/", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusBadRequest { + respBody, _ := io.ReadAll(resp.Body) + t.Errorf("status = %d, want %d; body = %s", resp.StatusCode, http.StatusBadRequest, string(respBody)) + } + }) +} diff --git a/go/genkit/x/genkit.go b/go/genkit/x/genkit.go index 8574d338ea..9480f75149 100644 --- a/go/genkit/x/genkit.go +++ b/go/genkit/x/genkit.go @@ -82,7 +82,7 @@ type StreamingFunc[In, Out, Stream any] = func(ctx context.Context, input In, st // fmt.Println(val.Stream) // 5, 4, 3, 2, 1 // } // } -func DefineStreamingFlow[In, Out, Stream any](g *genkit.Genkit, name string, fn StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream, struct{}] { +func DefineStreamingFlow[In, Out, Stream any](g *genkit.Genkit, name string, fn StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream] { // Wrap the channel-based function to work with the callback-based API wrappedFn := func(ctx context.Context, input In, sendChunk core.StreamCallback[Stream]) (Out, error) { if sendChunk == nil { diff --git a/go/internal/base/json.go b/go/internal/base/json.go index fea27bb00f..44fe752667 100644 --- a/go/internal/base/json.go +++ b/go/internal/base/json.go @@ -17,6 +17,7 @@ package base import ( + "bytes" "encoding/json" "errors" "fmt" @@ -29,6 +30,13 @@ import ( "github.com/invopop/jsonschema" ) +// HasJSONValue reports whether raw carries an actual JSON value: it is +// non-empty and not the JSON null literal, ignoring surrounding whitespace. +func HasJSONValue(raw json.RawMessage) bool { + trimmed := bytes.TrimSpace(raw) + return len(trimmed) > 0 && !bytes.Equal(trimmed, []byte("null")) +} + // JSONString returns json.Marshal(x) as a string. If json.Marshal returns // an error, jsonString returns the error text as a JSON string beginning "ERROR:". func JSONString(x any) string { diff --git a/go/internal/base/json_type_converter.go b/go/internal/base/json_type_converter.go index edcb2b7cc0..f097762996 100644 --- a/go/internal/base/json_type_converter.go +++ b/go/internal/base/json_type_converter.go @@ -147,6 +147,14 @@ func normalizeArrayInput(arr []any, schema map[string]any) ([]any, error) { // For 'any' types, it preserves the actual types from the normalized data. // For structured types, it marshals and unmarshals to properly populate the fields. func UnmarshalAndNormalize[T any](input json.RawMessage, schema map[string]any) (T, error) { + return UnmarshalAndNormalizeWith[T](input, schema, nil) +} + +// UnmarshalAndNormalizeWith is UnmarshalAndNormalize with an optional +// precompiled schema: when compiled is non-nil, validation uses it instead of +// recompiling schema, which matters on per-chunk streaming hot paths. schema +// is still used for normalization and must describe the same schema. +func UnmarshalAndNormalizeWith[T any](input json.RawMessage, schema map[string]any, compiled *CompiledSchema) (T, error) { var zero T if len(input) == 0 { @@ -163,7 +171,12 @@ func UnmarshalAndNormalize[T any](input json.RawMessage, schema map[string]any) return zero, fmt.Errorf("invalid input: %w", err) } - if err := ValidateValue(normalized, schema); err != nil { + if compiled != nil { + err = compiled.ValidateValue(normalized) + } else { + err = ValidateValue(normalized, schema) + } + if err != nil { return zero, err } diff --git a/go/internal/base/validation.go b/go/internal/base/validation.go index e363b95523..430c6414e7 100644 --- a/go/internal/base/validation.go +++ b/go/internal/base/validation.go @@ -66,16 +66,63 @@ func ValidateRaw(dataBytes json.RawMessage, schemaBytes json.RawMessage) error { if err != nil { return fmt.Errorf("failed to validate data against expected schema: %w", err) } + return validationResultError(result) +} + +// validationResultError converts a gojsonschema result into the package's +// standard validation error, or nil when the result is valid. +func validationResultError(result *gojsonschema.Result) error { + if result.Valid() { + return nil + } + var errs []string + for _, err := range result.Errors() { + errs = append(errs, fmt.Sprintf("- %s", err)) + } + return fmt.Errorf("data did not match expected schema:\n%s", strings.Join(errs, "\n")) +} - if !result.Valid() { - var errors []string - for _, err := range result.Errors() { - errors = append(errors, fmt.Sprintf("- %s", err)) - } - return fmt.Errorf("data did not match expected schema:\n%s", strings.Join(errors, "\n")) +// CompiledSchema is a JSON schema precompiled for repeated validation, e.g. +// per-chunk validation on streaming transports, where recompiling the schema +// for every payload would dominate the hot path. A nil *CompiledSchema (from +// a nil schema) accepts every value, matching ValidateValue's nil handling. +type CompiledSchema struct { + schema *gojsonschema.Schema +} + +// CompileSchema compiles schema for repeated validation with +// [CompiledSchema.ValidateValue]. A nil schema compiles to a nil +// CompiledSchema, which accepts every value. +func CompileSchema(schema map[string]any) (*CompiledSchema, error) { + if schema == nil { + return nil, nil + } + schemaBytes, err := json.Marshal(schema) + if err != nil { + return nil, fmt.Errorf("expected schema is not valid: %w", err) } + compiled, err := gojsonschema.NewSchema(gojsonschema.NewBytesLoader(schemaBytes)) + if err != nil { + return nil, fmt.Errorf("expected schema is not valid: %w", err) + } + return &CompiledSchema{schema: compiled}, nil +} - return nil +// ValidateValue validates data against the compiled schema, with the same +// behavior and error shape as [ValidateValue]. +func (c *CompiledSchema) ValidateValue(data any) error { + if c == nil { + return nil + } + dataBytes, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("data is not a valid JSON type: %w", err) + } + result, err := c.schema.Validate(gojsonschema.NewBytesLoader(dataBytes)) + if err != nil { + return fmt.Errorf("failed to validate data against expected schema: %w", err) + } + return validationResultError(result) } // ValidateIsJSONArray will validate if the schema represents a JSON array. From fb63565c2521f840ff53739083c9cd39546cedf1 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 12 Jun 2026 10:46:37 -0700 Subject: [PATCH 096/141] fix(go): reject init for non-bidi actions before committing to SSE A non-bidi action receiving init on a streaming HTTP request previously surfaced the rejection as an in-band SSE error event on a 200 once the init dispatch moved into the shared run path. Request-validation errors that are decidable from the request alone should fail with a real HTTP status before the response commits to a stream (status codes are immutable after headers, EventSource auto-reconnects on streams that end after an error event, and clients should not need to parse SSE to learn a request was malformed). checkInitSupported now guards the HTTP handler before the streaming branch, restoring the pre-stream 400 on every path, and also backs runActionWithOptionalInit so the contract and message stay in one place. --- go/genkit/reflection.go | 23 ++++++++++++++++++----- go/genkit/servers.go | 7 +++++++ go/genkit/servers_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 5 deletions(-) diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index 9ddff1fbc0..b28e0be48e 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -752,17 +752,30 @@ func runAction(ctx context.Context, g *Genkit, key string, input, init json.RawM }, nil } +// checkInitSupported rejects an init payload aimed at an action that cannot +// accept one: it returns INVALID_ARGUMENT when init carries a value and the +// action is not bidi, and nil otherwise. Transports call it before committing +// to a response shape (e.g. before writing SSE headers) so the rejection +// surfaces as a proper request error on every path. +func checkInitSupported(a api.Action, init json.RawMessage) error { + if base.HasJSONValue(init) { + if _, ok := a.(api.BidiAction); !ok { + return core.NewError(core.INVALID_ARGUMENT, "action %q does not accept init", a.Name()) + } + } + return nil +} + // runActionWithOptionalInit runs an action through its JSON surface, // dispatching to the bidi one-shot path when init carries a value. Init on a // non-bidi action is rejected with INVALID_ARGUMENT. Shared by the reflection // servers and the HTTP action handler so the init-acceptance contract stays // in one place. func runActionWithOptionalInit(ctx context.Context, a api.Action, input, init json.RawMessage, cb streamingCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { - if base.HasJSONValue(init) { - bidi, ok := a.(api.BidiAction) - if !ok { - return nil, core.NewError(core.INVALID_ARGUMENT, "action %q does not accept init", a.Name()) - } + if err := checkInitSupported(a, init); err != nil { + return nil, err + } + if bidi, ok := a.(api.BidiAction); ok && base.HasJSONValue(init) { return bidi.RunBidiJSON(ctx, input, cb, &api.BidiSessionOptions{Init: init}) } return a.RunJSONWithTelemetry(ctx, input, cb) diff --git a/go/genkit/servers.go b/go/genkit/servers.go index dbaad093bf..288baa95ac 100644 --- a/go/genkit/servers.go +++ b/go/genkit/servers.go @@ -175,6 +175,13 @@ func handler(a api.Action, opts *handlerOptions) func(http.ResponseWriter, *http } } + // Rejected before the streaming branch commits to SSE headers, so a + // non-bidi action receiving init fails with a proper HTTP 400 on + // every path rather than an in-band SSE error event on a 200. + if err := checkInitSupported(a, body.Init); err != nil { + return err + } + run := func(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) { r, err := runActionWithOptionalInit(ctx, a, input, body.Init, cb) if err != nil { diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index a18a109a40..629174e4b2 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -680,4 +680,32 @@ func TestHandlerBidiInitEnvelope(t *testing.T) { t.Errorf("status = %d, want %d; body = %s", resp.StatusCode, http.StatusBadRequest, string(respBody)) } }) + + t.Run("init on non-bidi flow returns 400 on streaming requests", func(t *testing.T) { + plainFlow := DefineFlow(g, "envelopePlainStream", + func(ctx context.Context, in string) (string, error) { + return "out:" + in, nil + }) + handler := Handler(plainFlow) + + // The rejection must happen before the handler commits to SSE: a + // streaming client should see an HTTP 400, not a 200 with an + // in-band error event. + body := `{"data":"hello","init":{"prefix":">> "}}` + req := httptest.NewRequest("POST", "/", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + w := httptest.NewRecorder() + + handler(w, req) + + resp := w.Result() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want %d; body = %s", resp.StatusCode, http.StatusBadRequest, string(respBody)) + } + if ct := resp.Header.Get("Content-Type"); strings.Contains(ct, "text/event-stream") { + t.Errorf("Content-Type = %q; response must not commit to SSE before rejecting init", ct) + } + }) } From ea4801cd3c63d006bce6bebcc8af48d507588c3f Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 12 Jun 2026 10:57:34 -0700 Subject: [PATCH 097/141] fix(tools): add init to Zod source schemas and regenerate Python typings genkit-schema.json had the init property added by hand without updating the Zod schemas it is exported from. CI re-exports the schema from the Zod sources before regenerating Go code, so go/genkit/gen.go lost the Init field and the Go build failed. Define init on RunActionRequestSchema and SpanMetadataSchema so the export reproduces the committed schema, and regenerate the Python _typing.py that was missing the new field. --- genkit-tools/common/src/types/apis.ts | 6 ++++++ genkit-tools/common/src/types/trace.ts | 1 + py/packages/genkit/src/genkit/_core/_typing.py | 3 +++ 3 files changed, 10 insertions(+) diff --git a/genkit-tools/common/src/types/apis.ts b/genkit-tools/common/src/types/apis.ts index ce86d407d1..2acbb913c4 100644 --- a/genkit-tools/common/src/types/apis.ts +++ b/genkit-tools/common/src/types/apis.ts @@ -142,6 +142,12 @@ export const RunActionRequestSchema = z.object({ .any() .optional() .describe('An input with the type that this action expects.'), + init: z + .any() + .optional() + .describe( + 'Initialization parameters to establish long running session states.' + ), context: z .any() .optional() diff --git a/genkit-tools/common/src/types/trace.ts b/genkit-tools/common/src/types/trace.ts index 97542abb2f..94ae18d3e5 100644 --- a/genkit-tools/common/src/types/trace.ts +++ b/genkit-tools/common/src/types/trace.ts @@ -52,6 +52,7 @@ export const SpanMetadataSchema = z.object({ state: z.enum(['success', 'error']).optional(), input: z.any().optional(), output: z.any().optional(), + init: z.any().optional(), isRoot: z.boolean().optional(), metadata: z.record(z.string(), z.string()).optional(), path: z.string().optional(), diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index c9e7b6fa70..890c1f2732 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -613,6 +613,9 @@ class ReflectionRunActionParams(GenkitModel): runtime_id: str | None = None key: str = Field(..., description='Action key that consists of the action type and ID.') input: Any | None = Field(default=None, description='An input with the type that this action expects.') + init: Any | None = Field( + default=None, description='Initialization parameters to establish long running session states.' + ) context: Any | None = Field(default=None, description='Additional runtime context data (ex. auth context data).') telemetry_labels: TelemetryLabels | None = None stream: bool | None = None From 1d011e89344b386e27d3ffdaea67a8d6c8a39849 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 12 Jun 2026 11:16:42 -0700 Subject: [PATCH 098/141] chore(ci): check schema freshness and all jsonschemagen outputs check-generated-go.sh only diffed go/ai/gen.go, so when the freshly exported genkit-schema.json regenerated go/genkit/gen.go differently, the check passed and the job failed later with a confusing build error. Now the script fails up front if export:schemas left genkit-schema.json dirty (with the export remedy, since rerunning jsonschemagen would not fix that), diffs every file the generator reports writing, and prints the offending diff. --- .../workflows/scripts/check-generated-go.sh | 50 ++++++++++++++++--- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/.github/workflows/scripts/check-generated-go.sh b/.github/workflows/scripts/check-generated-go.sh index 3185a479e7..ad66e44c50 100644 --- a/.github/workflows/scripts/check-generated-go.sh +++ b/.github/workflows/scripts/check-generated-go.sh @@ -13,15 +13,49 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Check if generated Go code is up to date +# Check that the committed JSON schema and the generated Go code are up to +# date. Must run from the repo root, after the workflow's +# `npm --prefix genkit-tools run export:schemas` step. + +set -euo pipefail + +# export:schemas rewrites genkit-schema.json from the Zod schemas in +# genkit-tools. If it left the file dirty, the committed schema does not match +# its Zod sources and everything generated from it would silently drift, so +# fail with that remedy before checking the Go code. +if ! git diff --quiet -- genkit-tools/genkit-schema.json; then + echo "::error::genkit-tools/genkit-schema.json does not match the Zod schemas it is exported from." + echo "::error::Update the schemas under genkit-tools, then run 'npm --prefix genkit-tools run export:schemas' and commit the result." + git --no-pager diff -- genkit-tools/genkit-schema.json + exit 1 +fi + cd go/core -go run ../internal/cmd/jsonschemagen -outdir .. -config schemas.config ../../genkit-tools/genkit-schema.json ai +if ! gen_output=$(go run ../internal/cmd/jsonschemagen -outdir .. -config schemas.config ../../genkit-tools/genkit-schema.json ai 2>&1); then + echo "$gen_output" + echo "::error::jsonschemagen failed." + exit 1 +fi +echo "$gen_output" + +# Check every file the generator reported writing, not a hardcoded subset. +gen_files=$(sed -n 's/^jsonschemagen: wrote //p' <<<"$gen_output") +if [ -z "$gen_files" ]; then + echo "::error::Could not determine which files jsonschemagen wrote; has its output format changed?" + exit 1 +fi + +out_of_date=0 +for f in $gen_files; do + if ! git diff --quiet -- "$f"; then + echo "::error::Generated $f is out of date." + out_of_date=1 + fi +done -# Check if git detects changes to the generated file -if git diff --quiet ../ai/gen.go; then - exit 0 -else - echo "::error::Generated Go code is out of date. Please run:" +if [ "$out_of_date" -ne 0 ]; then + echo "::error::Please run the following and commit the result:" echo "::error::cd go/core && go run ../internal/cmd/jsonschemagen -outdir .. -config schemas.config ../../genkit-tools/genkit-schema.json ai" + git --no-pager diff -- $gen_files exit 1 -fi \ No newline at end of file +fi From 5b0e9e8bddafd6a0c4ba0a2d752a38fb25acd835 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 12 Jun 2026 12:46:55 -0700 Subject: [PATCH 099/141] refactor(core): single-source the Genkit error wire shape as RuntimeError The canonical error wire shape {status, message, details} was stated three times: the AgentError Zod schema (named as if agent-specific, though it is the same form the JS runtime emits as HttpErrorWireFormat), the hand-written genkitErrorWire struct in go/core, and the schema GenkitError.JSONSchema infers from it. Nothing tied them together, and the Go copy had already drifted (status required vs optional in Zod). - Promote the schema out of agent.ts into error.ts as RuntimeErrorSchema and document the split against GenkitErrorSchema, which despite the name is the reflection API's HTTP error envelope, not the data-plane error shape. - Generate genkitErrorWire into go/core from the RuntimeError def (pkg/name directives in schemas.config) instead of hand-writing it. GenkitError's MarshalJSON/UnmarshalJSON/JSONSchema now delegate to a struct the schema freshness CI keeps in lockstep with the Zod source. Wire change: status gains omitempty, matching its optionality in the canonical schema (runtimes always set it in practice). - Python emits the class as GenkitRuntimeError via the generator's TRANSFORMATIONS map; RuntimeError would shadow the builtin exception and _typing classes are re-exported into the public genkit namespace. --- genkit-tools/common/src/types/agent.ts | 19 ++------- genkit-tools/common/src/types/error.ts | 27 +++++++++++++ genkit-tools/genkit-schema.json | 36 ++++++++--------- go/core/error.go | 11 +---- go/core/error_test.go | 5 ++- go/core/gen.go | 34 ++++++++++++++++ go/core/schemas.config | 40 +++++++++++++++++-- .../genkit/src/genkit/_core/_typing.py | 22 +++++----- py/tools/schema_to_typing/schema_to_typing.py | 2 + 9 files changed, 136 insertions(+), 60 deletions(-) create mode 100644 go/core/gen.go diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 90b77eda07..11e8b900e1 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -15,6 +15,7 @@ */ import { z } from 'zod'; +import { RuntimeErrorSchema } from './error'; import { MessageSchema, ModelResponseChunkSchema } from './model'; import { PartSchema, @@ -217,20 +218,6 @@ export const AgentResultSchema = z.object({ }); export type AgentResult = z.infer; -/** - * Zod schema for the canonical Genkit error wire shape - * (`{status, message, details}`), carried on agent outputs and snapshots. - */ -export const AgentErrorSchema = z.object({ - /** Canonical status name (e.g. `INTERNAL`, `FAILED_PRECONDITION`). */ - status: z.string().optional(), - /** Human-readable error message. */ - message: z.string(), - /** Optional structured details describing the failure. */ - details: z.any().optional(), -}); -export type AgentError = z.infer; - /** * Zod schema for agent output. */ @@ -280,7 +267,7 @@ export const AgentOutputSchema = z.object({ * category (e.g. `INVALID_ARGUMENT`, `FAILED_PRECONDITION`, * `INTERNAL`) so callers can still branch on it. */ - error: AgentErrorSchema.optional(), + error: RuntimeErrorSchema.optional(), }); export type AgentOutput = z.infer; @@ -371,7 +358,7 @@ export const SessionSnapshotSchema = z.object({ */ finishReason: AgentFinishReasonSchema.optional(), /** Structured failure information for a snapshot in `failed` status. */ - error: AgentErrorSchema.optional(), + error: RuntimeErrorSchema.optional(), /** * Conversation state captured at this point. Empty on a pending snapshot * (the live state is not yet committed); populated on terminal snapshots diff --git a/genkit-tools/common/src/types/error.ts b/genkit-tools/common/src/types/error.ts index 7de218a832..f554b5aea8 100644 --- a/genkit-tools/common/src/types/error.ts +++ b/genkit-tools/common/src/types/error.ts @@ -16,6 +16,33 @@ import { z } from 'zod'; +/** + * Zod schema for the canonical Genkit error wire shape + * (`{status, message, details}`). This is the form runtimes use when an + * error travels as data inside another value (e.g. agent outputs and + * session snapshots), matching `HttpErrorWireFormat` in the JS runtime + * and `GenkitError`'s wire form in the Go runtime. + * + * Not to be confused with {@link GenkitErrorSchema} below, which is the + * reflection API's HTTP error envelope. + */ +export const RuntimeErrorSchema = z.object({ + /** Canonical status name (e.g. `INTERNAL`, `FAILED_PRECONDITION`). */ + status: z.string().optional(), + /** Human-readable error message. */ + message: z.string(), + /** Optional structured details describing the failure. */ + details: z.any().optional(), +}); +export type RuntimeError = z.infer; + +/** + * Zod schema for the error envelope returned by a runtime's reflection + * API on failed HTTP requests, including debugging context (stack, + * trace ID) that the dev UI surfaces. Despite the name, this is a + * transport-layer shape; errors carried as data inside values use + * {@link RuntimeErrorSchema}. + */ export const GenkitErrorSchema = z.object({ message: z.string(), stack: z.string().optional(), diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 30e18a8092..65ce9aef18 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -28,22 +28,6 @@ ], "additionalProperties": false }, - "AgentError": { - "type": "object", - "properties": { - "status": { - "type": "string" - }, - "message": { - "type": "string" - }, - "details": {} - }, - "required": [ - "message" - ], - "additionalProperties": false - }, "AgentFinishReason": { "type": "string", "enum": [ @@ -144,7 +128,7 @@ "$ref": "#/$defs/AgentFinishReason" }, "error": { - "$ref": "#/$defs/AgentError" + "$ref": "#/$defs/RuntimeError" } }, "additionalProperties": false @@ -252,7 +236,7 @@ "$ref": "#/$defs/AgentFinishReason" }, "error": { - "$ref": "#/$defs/AgentError" + "$ref": "#/$defs/RuntimeError" }, "state": { "$ref": "#/$defs/SessionState" @@ -586,6 +570,22 @@ ], "additionalProperties": false }, + "RuntimeError": { + "type": "object", + "properties": { + "status": { + "type": "string" + }, + "message": { + "type": "string" + }, + "details": {} + }, + "required": [ + "message" + ], + "additionalProperties": false + }, "MiddlewareDesc": { "type": "object", "properties": { diff --git a/go/core/error.go b/go/core/error.go index c8dc20e2a3..759ae8abfa 100644 --- a/go/core/error.go +++ b/go/core/error.go @@ -55,16 +55,9 @@ type GenkitError struct { originalError error // The wrapped error, if any. } -// genkitErrorWire is the on-the-wire shape of a [GenkitError]; it -// matches the `RuntimeError` definition in the JSON schema. -type genkitErrorWire struct { - Status StatusName `json:"status"` - Message string `json:"message"` - Details map[string]any `json:"details,omitempty"` -} - // MarshalJSON encodes a GenkitError in the canonical Genkit error wire -// format: {status, message, details}. +// format: {status, message, details}. The wire shape ([genkitErrorWire]) +// is generated from the shared JSON schema's RuntimeError definition. func (e *GenkitError) MarshalJSON() ([]byte, error) { return json.Marshal(genkitErrorWire{ Status: e.Status, diff --git a/go/core/error_test.go b/go/core/error_test.go index f1f2bdefce..7366e0235a 100644 --- a/go/core/error_test.go +++ b/go/core/error_test.go @@ -177,7 +177,8 @@ func TestGenkitErrorJSONRoundtrip(t *testing.T) { if err != nil { t.Fatalf("Marshal: %v", err) } - want := `{"status":"NOT_FOUND","message":"missing","details":{"id":"abc"}}` + // Key order follows the generated wire struct's field order. + want := `{"details":{"id":"abc"},"message":"missing","status":"NOT_FOUND"}` if string(got) != want { t.Errorf("Marshal = %s, want %s", got, want) } @@ -189,7 +190,7 @@ func TestGenkitErrorJSONRoundtrip(t *testing.T) { if err != nil { t.Fatalf("Marshal: %v", err) } - want := `{"status":"NOT_FOUND","message":"missing"}` + want := `{"message":"missing","status":"NOT_FOUND"}` if string(got) != want { t.Errorf("Marshal = %s, want %s", got, want) } diff --git a/go/core/gen.go b/go/core/gen.go new file mode 100644 index 0000000000..98af8206b2 --- /dev/null +++ b/go/core/gen.go @@ -0,0 +1,34 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// This file was generated by jsonschemagen. DO NOT EDIT. + +package core + +// genkitErrorWire is the on-the-wire shape of a [GenkitError]: the +// canonical Genkit error format ({status, message, details}) shared +// across runtimes (RuntimeError in the JSON schema). GenkitError's +// MarshalJSON, UnmarshalJSON, and JSONSchema delegate to it; fields that +// exist for in-process use (HTTPCode, Source, the wrapped error) are not +// part of it. +type genkitErrorWire struct { + // Details is optional structured information describing the failure. + Details map[string]any `json:"details,omitempty"` + // Message is the human-readable error message. + Message string `json:"message"` + // Status is the canonical status name (e.g. INTERNAL, FAILED_PRECONDITION). + Status StatusName `json:"status,omitempty"` +} diff --git a/go/core/schemas.config b/go/core/schemas.config index 90effad657..8d14b3a386 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1400,10 +1400,42 @@ the original error category (e.g. INVALID_ARGUMENT, FAILED_PRECONDITION, INTERNAL) so callers can still branch on it. Nil otherwise. . -# AgentError mirrors the GenkitError wire shape. The fields that carry it -# (AgentOutput.error, SessionSnapshot.error) are overridden to -# *core.GenkitError, so the synthesized type is unused. -AgentError omit +# ---------------------------------------------------------------------------- +# RuntimeError +# ---------------------------------------------------------------------------- + +# RuntimeError is the canonical Genkit error wire shape. It is generated +# into core as the unexported struct backing GenkitError's MarshalJSON, +# UnmarshalJSON, and JSONSchema, so the wire format and the advertised +# schema are single-sourced from the Zod schema. The fields that carry +# it (AgentOutput.error, SessionSnapshot.error) are overridden to +# *core.GenkitError, which delegates to this type. +RuntimeError pkg core +RuntimeError name genkitErrorWire + +RuntimeError doc +genkitErrorWire is the on-the-wire shape of a [GenkitError]: the +canonical Genkit error format ({status, message, details}) shared +across runtimes (RuntimeError in the JSON schema). GenkitError's +MarshalJSON, UnmarshalJSON, and JSONSchema delegate to it; fields that +exist for in-process use (HTTPCode, Source, the wrapped error) are not +part of it. +. + +RuntimeError.status type StatusName +RuntimeError.status doc +Status is the canonical status name (e.g. INTERNAL, FAILED_PRECONDITION). +. + +RuntimeError.message noomitempty +RuntimeError.message doc +Message is the human-readable error message. +. + +RuntimeError.details type map[string]any +RuntimeError.details doc +Details is optional structured information describing the failure. +. # ---------------------------------------------------------------------------- # AgentStreamChunk diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index 1c43c72f8c..f91aab8b75 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -133,15 +133,6 @@ class AbortSnapshotResponse(GenkitModel): status: SnapshotStatus | None = None -class AgentError(GenkitModel): - """Model for agenterror data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - status: str | None = None - message: str = Field(...) - details: Any | None = Field(default=None) - - class AgentInit(GenkitModel): """Model for agentinit data.""" @@ -178,7 +169,7 @@ class AgentOutput(GenkitModel): message: MessageData | None = None artifacts: list[Artifact] | None = None finish_reason: AgentFinishReason | None = None - error: AgentError | None = None + error: GenkitRuntimeError | None = None class AgentResult(GenkitModel): @@ -228,7 +219,7 @@ class SessionSnapshot(GenkitModel): event: SnapshotEvent = Field(...) status: SnapshotStatus | None = None finish_reason: AgentFinishReason | None = None - error: AgentError | None = None + error: GenkitRuntimeError | None = None state: SessionState | None = None @@ -346,6 +337,15 @@ class GenkitError(GenkitModel): data: Data | None = None +class GenkitRuntimeError(GenkitModel): + """Model for genkitruntimeerror data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + status: str | None = None + message: str = Field(...) + details: Any | None = Field(default=None) + + class MiddlewareDesc(GenkitModel): """Model for middlewaredesc data.""" diff --git a/py/tools/schema_to_typing/schema_to_typing.py b/py/tools/schema_to_typing/schema_to_typing.py index 9328d4224d..a71dd545d8 100644 --- a/py/tools/schema_to_typing/schema_to_typing.py +++ b/py/tools/schema_to_typing/schema_to_typing.py @@ -41,6 +41,8 @@ TRANSFORMATIONS = { 'Message': {'output_name': 'MessageData'}, 'GenerateActionOptions': {'suffix': 'Data', 'omit': ['messages']}, + # RuntimeError would shadow Python's builtin exception. + 'RuntimeError': {'output_name': 'GenkitRuntimeError'}, } From 9a3eee6f2b57361c134de3864199f1c3df0b8f3e Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 12 Jun 2026 13:18:39 -0700 Subject: [PATCH 100/141] fix(go/core): run no-init bidi actions and keep stack traces off the wire validateInit validated the zero init unconditionally, and for a pointer Init type the zero value marshals to JSON null, which can never satisfy the inferred object schema. Every no-init path on such actions failed with INVALID_ARGUMENT before the function ran: the unary surface (HTTP handler and Dev UI runs without an init field), RunBidiJSON, and StreamBidiJSON. Agents (Init = *AgentInit) could not start a fresh conversation through any JSON transport. Skip validation for nil inits; there is no value to validate and the function applies its defaults. A struct Init still validates its zero value, so a custom init schema the zero cannot satisfy keeps surfacing instead of silently defaulting. GenkitError.MarshalJSON serialized Details verbatim, including the debug stack NewError records under "stack", so any error embedded in a value, e.g. a failed agent invocation's AgentOutput, leaked process internals to HTTP clients and into client-managed session state. Omit the stack detail from the wire form; it stays on the in-process value for the consumers that want it (the reflection error envelope and the V2 runAction error data read the error directly). Also add TestHandlerAgent pinning the agent-over-HTTP contract: agents serve one turn per request through the standard action handler, data carries the turn input and init the session source, turn-tier failures resolve as HTTP 200 with a failed output carrying the last-good state, init-tier failures map to hard 4xx errors, and streaming delivers model chunks and turn lifecycle events over SSE. --- go/core/action.go | 15 +++ go/core/bidi.go | 12 ++- go/core/bidi_test.go | 139 ++++++++++++++++++++++++ go/core/error.go | 17 ++- go/core/error_test.go | 30 ++++++ go/genkit/servers_test.go | 221 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 431 insertions(+), 3 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index f6ca53e8ed..27ad39c386 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -194,6 +194,21 @@ func isUnitType[T any]() bool { return reflect.TypeFor[T]() == reflect.TypeFor[struct{}]() } +// isNilValue reports whether v is nil or a nil pointer, map, slice, or +// interface: a value that marshals to JSON null and carries nothing to +// validate against a schema. +func isNilValue(v any) bool { + if v == nil { + return true + } + switch rv := reflect.ValueOf(v); rv.Kind() { + case reflect.Pointer, reflect.Map, reflect.Slice, reflect.Interface: + return rv.IsNil() + default: + return false + } +} + // inferSchema returns the JSON schema inferred from T's zero value, or nil // for interface types, whose zero value carries no type information to infer // from. diff --git a/go/core/bidi.go b/go/core/bidi.go index 3f200a9fe2..8d6406b5b0 100644 --- a/go/core/bidi.go +++ b/go/core/bidi.go @@ -301,12 +301,20 @@ func (b *BidiAction[In, Out, Stream, Init]) decodeInit(opts *api.BidiSessionOpti // validateInit checks an init value against the action's InitSchema (if any), // resolving schema $refs through the registry first. Validation runs whenever -// InitSchema is present, even for the zero init value, so a required field -// surfaces as INVALID_ARGUMENT rather than silently defaulting. +// InitSchema is present, even for the zero init value, so a struct Init with +// required fields surfaces as INVALID_ARGUMENT rather than silently +// defaulting. A nil init carries no value to validate and is skipped: it is +// the zero value of a pointer Init type, produced whenever the action runs +// without init (the unary surface, a JSON transport request with no init +// field), and would otherwise always fail the inferred object schema as JSON +// null. The action function receives the nil and applies its defaults. func (b *BidiAction[In, Out, Stream, Init]) validateInit(init Init) error { if b.desc.InitSchema == nil { return nil } + if isNilValue(init) { + return nil + } schema, err := ResolveSchema(b.registry, b.desc.InitSchema) if err != nil { return NewError(INVALID_ARGUMENT, "invalid init schema for action %q: %v", b.desc.Key, err) diff --git a/go/core/bidi_test.go b/go/core/bidi_test.go index 479ff2593e..ddcf7f263a 100644 --- a/go/core/bidi_test.go +++ b/go/core/bidi_test.go @@ -377,6 +377,145 @@ func TestInitSchemaValidationAcceptsGoodInit(t *testing.T) { } } +// TestBidiNilInitSkipsValidation verifies that a nil init (the zero value of +// a pointer Init type) bypasses init schema validation on every no-init path. +// The inferred init schema describes the object form, which JSON null can +// never satisfy, so without the bypass an action with a pointer Init (e.g. an +// agent's *AgentInit) could not run at all through the unary surface or a +// JSON transport request that omits init. +func TestBidiNilInitSkipsValidation(t *testing.T) { + ctx := context.Background() + + type Config struct{ Prefix string } + + r := registry.New() + action := DefineBidiAction(r, "nil-init", api.ActionTypeCustom, nil, + func(ctx context.Context, cfg *Config, inCh <-chan string, outCh chan<- string) (string, error) { + prefix := "default: " + if cfg != nil { + prefix = cfg.Prefix + } + var out string + for in := range inCh { + out = prefix + in + } + return out, nil + }) + + t.Run("unary surface runs with nil init", func(t *testing.T) { + out, err := action.RunJSON(ctx, json.RawMessage(`"hello"`), nil) + if err != nil { + t.Fatalf("RunJSON: %v", err) + } + if string(out) != `"default: hello"` { + t.Errorf("output = %s, want %q", out, "default: hello") + } + }) + + t.Run("JSON one-shot without init", func(t *testing.T) { + res, err := action.RunBidiJSON(ctx, json.RawMessage(`"hello"`), nil, nil) + if err != nil { + t.Fatalf("RunBidiJSON: %v", err) + } + if string(res.Result) != `"default: hello"` { + t.Errorf("output = %s, want %q", res.Result, "default: hello") + } + }) + + t.Run("JSON session without init", func(t *testing.T) { + conn, err := action.StreamBidiJSON(ctx, nil) + if err != nil { + t.Fatalf("StreamBidiJSON: %v", err) + } + if err := conn.Send(json.RawMessage(`"hello"`)); err != nil { + t.Fatalf("Send: %v", err) + } + conn.Close() + for _, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive: %v", err) + } + } + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if string(out) != `"default: hello"` { + t.Errorf("output = %s, want %q", out, "default: hello") + } + }) + + t.Run("typed session with explicit nil init", func(t *testing.T) { + conn, err := action.StreamBidi(ctx, nil) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + conn.Close() + for _, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive: %v", err) + } + } + if _, err := conn.Output(); err != nil { + t.Fatalf("Output: %v", err) + } + }) + + t.Run("provided init still validates", func(t *testing.T) { + _, err := action.RunBidiJSON(ctx, json.RawMessage(`"hello"`), nil, + &api.BidiSessionOptions{Init: json.RawMessage(`{"Prefix": 42}`)}) + if err == nil { + t.Fatal("expected error for mistyped init, got nil") + } + }) +} + +// TestBidiZeroStructInitValidatedWithoutInit pins the struct-Init half of the +// no-init contract: running without init still validates the zero init value, +// so an init schema the zero value cannot satisfy surfaces as +// INVALID_ARGUMENT rather than silently defaulting. Authors opt into +// omissible init by choosing a pointer Init type. +func TestBidiZeroStructInitValidatedWithoutInit(t *testing.T) { + ctx := context.Background() + + type Config struct{ Prefix string } + + initSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "Prefix": map[string]any{"type": "string", "minLength": 1}, + }, + "required": []any{"Prefix"}, + } + + action := NewBidiAction( + "required-init", api.ActionTypeCustom, + &BidiActionOptions{InitSchema: initSchema}, + func(ctx context.Context, cfg Config, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return cfg.Prefix, nil + }, + ) + + _, err := action.RunJSON(ctx, json.RawMessage(`"hello"`), nil) + if err == nil { + t.Fatal("expected validation error for zero init, got nil") + } + var gerr *GenkitError + if !errors.As(err, &gerr) || gerr.Status != INVALID_ARGUMENT { + t.Errorf("err = %v, want INVALID_ARGUMENT GenkitError", err) + } + + got, err := action.RunWithInit(ctx, Config{Prefix: ">> "}, "ignored", nil) + if err != nil { + t.Fatalf("RunWithInit: %v", err) + } + if got != ">> " { + t.Errorf("output = %q, want %q", got, ">> ") + } +} + func TestBidiConnectionSendAfterClose(t *testing.T) { ctx := context.Background() diff --git a/go/core/error.go b/go/core/error.go index 759ae8abfa..2117b3378e 100644 --- a/go/core/error.go +++ b/go/core/error.go @@ -21,6 +21,7 @@ import ( "encoding/json" "errors" "fmt" + "maps" "runtime/debug" "github.com/firebase/genkit/go/internal/base" @@ -58,11 +59,25 @@ type GenkitError struct { // MarshalJSON encodes a GenkitError in the canonical Genkit error wire // format: {status, message, details}. The wire shape ([genkitErrorWire]) // is generated from the shared JSON schema's RuntimeError definition. +// +// The stack trace [NewError] records under Details["stack"] is in-process +// diagnostics like HTTPCode and Source, not wire data: marshaling omits it +// so errors embedded in values (e.g. a failed agent invocation's output) +// do not leak process internals to clients. Consumers that want the stack +// (the reflection API's error envelope) read the error value directly. func (e *GenkitError) MarshalJSON() ([]byte, error) { + details := e.Details + if _, ok := details["stack"]; ok { + details = maps.Clone(details) + delete(details, "stack") + if len(details) == 0 { + details = nil + } + } return json.Marshal(genkitErrorWire{ Status: e.Status, Message: e.Message, - Details: e.Details, + Details: details, }) } diff --git a/go/core/error_test.go b/go/core/error_test.go index 7366e0235a..2c5959deba 100644 --- a/go/core/error_test.go +++ b/go/core/error_test.go @@ -196,6 +196,36 @@ func TestGenkitErrorJSONRoundtrip(t *testing.T) { } }) + t.Run("omits the auto-captured stack detail", func(t *testing.T) { + ge := NewError(NOT_FOUND, "missing") + ge.Details["id"] = "abc" + got, err := json.Marshal(ge) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + // The stack is in-process diagnostics; only the other details + // cross the wire. + want := `{"details":{"id":"abc"},"message":"missing","status":"NOT_FOUND"}` + if string(got) != want { + t.Errorf("Marshal = %s, want %s", got, want) + } + if _, ok := ge.Details["stack"]; !ok { + t.Error("marshaling must not mutate the in-process Details") + } + }) + + t.Run("omits details entirely when stack is the only entry", func(t *testing.T) { + ge := NewError(NOT_FOUND, "missing") + got, err := json.Marshal(ge) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + want := `{"message":"missing","status":"NOT_FOUND"}` + if string(got) != want { + t.Errorf("Marshal = %s, want %s", got, want) + } + }) + t.Run("unmarshals and derives HTTPCode", func(t *testing.T) { raw := `{"status":"NOT_FOUND","message":"missing","details":{"id":"abc"}}` var ge GenkitError diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index 6500d40f22..a83475a689 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -18,6 +18,7 @@ package genkit import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -26,6 +27,9 @@ import ( "strings" "testing" + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/ai/exp/localstore" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/core/x/streaming" @@ -709,3 +713,220 @@ func TestHandlerBidiInitEnvelope(t *testing.T) { } }) } + +// agentHTTPResult mirrors the AgentOutput fields the agent handler tests +// assert on, decoded from the handler's {"result": ...} envelope. +type agentHTTPResult struct { + FinishReason string `json:"finishReason"` + SessionID string `json:"sessionId"` + SnapshotID string `json:"snapshotId"` + Message *ai.Message `json:"message"` + State json.RawMessage `json:"state"` + Error *struct { + Status string `json:"status"` + Message string `json:"message"` + Details map[string]any `json:"details"` + } `json:"error"` +} + +// TestHandlerAgent verifies that agents, being bidi actions of type flow, +// serve one-turn-at-a-time over the standard action handler: data carries +// the turn's AgentInput, init carries the session source (state for +// client-managed agents, sessionId/snapshotId for server-managed ones), and +// the conversation resumes across requests. It also pins the error contract: +// turn-tier failures resolve as a 200 with a failed AgentOutput (so the +// caller keeps the last-good state), while init-tier failures (a rejected +// session source) are hard HTTP errors. +func TestHandlerAgent(t *testing.T) { + g := Init(context.Background()) + + // Replies "echo " where n is the number of messages the model saw, + // so resumed history is observable; fails when asked to. + DefineModel(g, "test/echo", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + last := req.Messages[len(req.Messages)-1] + if last.Role == ai.RoleUser && strings.Contains(last.Text(), "fail") { + return nil, core.NewError(core.RESOURCE_EXHAUSTED, "model on fire") + } + if cb != nil { + cb(ctx, &ai.ModelResponseChunk{Content: []*ai.Part{ai.NewTextPart("chunk")}}) + } + return &ai.ModelResponse{ + Message: ai.NewModelTextMessage(fmt.Sprintf("echo %d", len(req.Messages))), + FinishReason: ai.FinishReasonStop, + }, nil + }) + + DefineAgent[any](g, "agentClient", aix.FromInline(ai.WithModelName("test/echo"))) + + store, err := localstore.NewFileSessionStore[any](t.TempDir()) + if err != nil { + t.Fatal(err) + } + DefineAgent(g, "agentServer", aix.FromInline(ai.WithModelName("test/echo")), + aix.WithSessionStore(store), + aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), + ) + + // Agents register as flow actions, so they surface through ListFlows + // like any other handler-servable action. + handlerFor := func(t *testing.T, name string) http.HandlerFunc { + t.Helper() + for _, a := range ListFlows(g) { + if a.Name() == name { + return Handler(a) + } + } + t.Fatalf("agent %q not in ListFlows", name) + return nil + } + + post := func(t *testing.T, name, body string, stream bool) (int, string) { + t.Helper() + req := httptest.NewRequest("POST", "/", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + if stream { + req.Header.Set("Accept", "text/event-stream") + } + w := httptest.NewRecorder() + handlerFor(t, name)(w, req) + respBody, _ := io.ReadAll(w.Result().Body) + return w.Result().StatusCode, string(respBody) + } + + parseResult := func(t *testing.T, body string) agentHTTPResult { + t.Helper() + var envelope struct { + Result agentHTTPResult `json:"result"` + } + if err := json.Unmarshal([]byte(body), &envelope); err != nil { + t.Fatalf("unmarshal %q: %v", body, err) + } + return envelope.Result + } + + turn := func(text string) string { + return `{"data":{"message":{"role":"user","content":[{"text":"` + text + `"}]}}}` + } + turnWithInit := func(text, init string) string { + return `{"data":{"message":{"role":"user","content":[{"text":"` + text + `"}]}},"init":` + init + `}` + } + + t.Run("client-managed conversation across requests", func(t *testing.T) { + // Fresh turn: no init at all. + code, body := post(t, "agentClient", turn("hello"), false) + if code != http.StatusOK { + t.Fatalf("status = %d, body = %s", code, body) + } + res := parseResult(t, body) + if res.FinishReason != "stop" { + t.Errorf("finishReason = %q, want %q", res.FinishReason, "stop") + } + if res.SessionID == "" { + t.Error("missing sessionId") + } + if len(res.State) == 0 { + t.Fatal("client-managed output must carry state") + } + if got := res.Message.Text(); got != "echo 1" { + t.Errorf("message = %q, want %q", got, "echo 1") + } + + // Resume: round-trip the returned state through init. + code, body = post(t, "agentClient", turnWithInit("again", `{"state":`+string(res.State)+`}`), false) + if code != http.StatusOK { + t.Fatalf("resume status = %d, body = %s", code, body) + } + resumed := parseResult(t, body) + // The model saw the prior user/model exchange plus the new message. + if got := resumed.Message.Text(); got != "echo 3" { + t.Errorf("resumed message = %q, want %q", got, "echo 3") + } + if resumed.SessionID != res.SessionID { + t.Errorf("sessionId changed across resume: %q vs %q", resumed.SessionID, res.SessionID) + } + }) + + t.Run("server-managed conversation across requests", func(t *testing.T) { + code, body := post(t, "agentServer", turn("hello"), false) + if code != http.StatusOK { + t.Fatalf("status = %d, body = %s", code, body) + } + res := parseResult(t, body) + if len(res.State) != 0 { + t.Errorf("server-managed output must not inline state, got %s", res.State) + } + if res.SessionID == "" || res.SnapshotID == "" { + t.Fatalf("missing session/snapshot ID: %+v", res) + } + + code, body = post(t, "agentServer", turnWithInit("again", `{"sessionId":"`+res.SessionID+`"}`), false) + if code != http.StatusOK { + t.Fatalf("resume status = %d, body = %s", code, body) + } + resumed := parseResult(t, body) + if got := resumed.Message.Text(); got != "echo 3" { + t.Errorf("resumed message = %q, want %q", got, "echo 3") + } + }) + + t.Run("turn failure resolves as failed output with 200", func(t *testing.T) { + code, body := post(t, "agentClient", turn("fail"), false) + if code != http.StatusOK { + t.Fatalf("status = %d, want 200 (failure rides the output); body = %s", code, body) + } + res := parseResult(t, body) + if res.FinishReason != "failed" { + t.Errorf("finishReason = %q, want %q", res.FinishReason, "failed") + } + if res.Error == nil { + t.Fatalf("missing error in failed output: %s", body) + } + if res.Error.Status != "RESOURCE_EXHAUSTED" { + t.Errorf("error.status = %q, want RESOURCE_EXHAUSTED", res.Error.Status) + } + if res.Error.Message == "" { + t.Error("missing error.message") + } + if _, ok := res.Error.Details["stack"]; ok { + t.Error("error.details must not leak the in-process stack trace") + } + // The failed output still hands back the last-good state. + if len(res.State) == 0 { + t.Error("failed output must carry the last-good state") + } + }) + + t.Run("init-tier failures are hard HTTP errors", func(t *testing.T) { + code, body := post(t, "agentServer", turnWithInit("hi", `{"snapshotId":"nope"}`), false) + if code != http.StatusNotFound { + t.Errorf("unknown snapshot: status = %d, want %d; body = %s", code, http.StatusNotFound, body) + } + + code, body = post(t, "agentServer", turnWithInit("hi", `{"state":{"messages":[]}}`), false) + if code != http.StatusBadRequest { + t.Errorf("state on server-managed: status = %d, want %d; body = %s", code, http.StatusBadRequest, body) + } + + code, body = post(t, "agentClient", turnWithInit("hi", `{"sessionId":"abc"}`), false) + if code != http.StatusBadRequest { + t.Errorf("sessionId on client-managed: status = %d, want %d; body = %s", code, http.StatusBadRequest, body) + } + }) + + t.Run("streaming turn delivers chunks then result", func(t *testing.T) { + code, body := post(t, "agentClient", turn("stream"), true) + if code != http.StatusOK { + t.Fatalf("status = %d, body = %s", code, body) + } + if !strings.Contains(body, "modelChunk") { + t.Errorf("missing modelChunk event; body = %s", body) + } + if !strings.Contains(body, "turnEnd") { + t.Errorf("missing turnEnd event; body = %s", body) + } + if !strings.Contains(body, `"result"`) { + t.Errorf("missing final result event; body = %s", body) + } + }) +} From efe47c0057fd79fd2b804a514b34026865690605 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 12 Jun 2026 13:18:48 -0700 Subject: [PATCH 101/141] feat(go/samples): add basic-agents-server sample serving agents over HTTP Demonstrates that agents serve one turn per request through the standard action handler with no extra wiring: ListFlows surfaces them like any flow, data carries the turn's AgentInput, and the optional init field carries the session source. Two agents cover the two session-state modes: "chat" persists snapshots to a FileSessionStore and resumes by sessionId, "statelessChat" round-trips the full state through the client. The doc comment walks through fresh, resumed, and streaming turns with curl, and documents the two failure tiers (failed turns as HTTP 200 with a failed output, rejected inits as hard 4xx). --- go/samples/basic-agents-server/main.go | 132 +++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 go/samples/basic-agents-server/main.go diff --git a/go/samples/basic-agents-server/main.go b/go/samples/basic-agents-server/main.go new file mode 100644 index 0000000000..63fc91f4fa --- /dev/null +++ b/go/samples/basic-agents-server/main.go @@ -0,0 +1,132 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This sample demonstrates serving Genkit agents as plain HTTP endpoints. +// +// Agents are bidirectional streaming actions, but the standard action +// handler also runs them one turn per request: "data" carries the turn's +// input (the user message), and the optional "init" field carries the +// session source that lets a conversation span requests. +// +// Two agents show the two session-state modes: +// +// - "chat" configures a session store (server-managed state). Each turn +// persists a snapshot; the response carries sessionId and snapshotId, +// and a later request resumes the conversation by sending +// {"init": {"sessionId": ...}} (or {"snapshotId": ...} to resume from +// a specific point in history). +// - "statelessChat" has no store (client-managed state). The response +// carries the full conversation state; the client sends it back +// verbatim as {"init": {"state": ...}} on the next turn. The server +// keeps nothing between requests. +// +// To run: +// +// go run . +// +// Start a conversation (no init starts a fresh session): +// +// curl -X POST http://localhost:8080/chat \ +// -H "Content-Type: application/json" \ +// -d '{"data": {"message": {"role": "user", "content": [{"text": "My name is Alex and I am planning a trip to Japan."}]}}}' +// +// Continue it, using the sessionId from the response: +// +// curl -X POST http://localhost:8080/chat \ +// -H "Content-Type: application/json" \ +// -d '{"data": {"message": {"role": "user", "content": [{"text": "What is my name?"}]}}, "init": {"sessionId": "SESSION_ID"}}' +// +// Stream a turn's model chunks and lifecycle events as server-sent events: +// +// curl -N -X POST 'http://localhost:8080/chat?stream=true' \ +// -H "Content-Type: application/json" \ +// -d '{"data": {"message": {"role": "user", "content": [{"text": "Suggest three day trips from Tokyo."}]}}}' +// +// For statelessChat, resume by round-tripping the returned state instead: +// +// curl -X POST http://localhost:8080/statelessChat \ +// -H "Content-Type: application/json" \ +// -d '{"data": {"message": {"role": "user", "content": [{"text": "What is my name?"}]}}, "init": {"state": STATE_FROM_PREVIOUS_RESPONSE}}' +// +// Failures come in two tiers. A failed turn (e.g. the model call errors) +// still returns HTTP 200: the result reports finishReason "failed", a +// structured error ({status, message, details}), and the last-good +// conversation state (or a recovery snapshot ID), so the client can retry +// the turn without losing the conversation. A rejected init (an unknown +// session or snapshot ID, state sent to a store-backed agent) fails the +// request itself with a 4xx error before any turn runs. +package main + +import ( + "context" + "log" + "net/http" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/ai/exp/localstore" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "github.com/firebase/genkit/go/plugins/server" + "google.golang.org/genai" +) + +func main() { + ctx := context.Background() + + // Initialize Genkit with the Google AI plugin. When you pass nil for the + // Config parameter, the Google AI plugin will get the API key from the + // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the + // recommended practice. + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + model := googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + }) + + // "chat" persists every conversation to a snapshot store, so a client + // only needs to hold on to the sessionId between requests. Snapshots + // land under ./.genkit/snapshots/chat/. + store, err := localstore.NewFileSessionStore[any]("./.genkit/snapshots/chat") + if err != nil { + log.Fatalf("creating session store: %v", err) + } + genkit.DefineAgent(g, "chat", + aix.FromInline( + ai.WithModel(model), + ai.WithSystem("You are a helpful travel assistant. Keep responses to a couple of sentences."), + ), + aix.WithSessionStore(store), + ) + + // "statelessChat" keeps no state on the server: each response carries + // the full conversation state and the client round-trips it on the next + // request. This suits deployments where the server must stay stateless. + genkit.DefineAgent[any](g, "statelessChat", + aix.FromInline( + ai.WithModel(model), + ai.WithSystem("You are a helpful travel assistant. Keep responses to a couple of sentences."), + ), + ) + + // Agents register as flow actions, so ListFlows surfaces them and the + // standard action handler serves them like any flow. + mux := http.NewServeMux() + for _, a := range genkit.ListFlows(g) { + mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) + } + log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) +} From 28d59a63d85c85253688a035522333c693d39099 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 12 Jun 2026 14:00:00 -0700 Subject: [PATCH 102/141] fix(go/genkit): recover handler panics in the V2 reflection server Handlers run action functions on dispatch goroutines, so a panic in any action took down the entire dev process along with it. Recover at the dispatch point, log the panic with its stack, and report a JSON-RPC server error to the caller, matching how the bidi engine recovers its function panics. The runtime keeps serving subsequent requests, which the new test pins by running a panicking flow and then a successful listActions on the same connection. --- go/genkit/reflection_v2.go | 14 ++++++++ go/genkit/reflection_v2_test.go | 64 +++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/go/genkit/reflection_v2.go b/go/genkit/reflection_v2.go index 158dc47134..4bc7515219 100644 --- a/go/genkit/reflection_v2.go +++ b/go/genkit/reflection_v2.go @@ -23,6 +23,7 @@ import ( "fmt" "log/slog" "os" + "runtime/debug" "strconv" "sync" "sync/atomic" @@ -319,6 +320,19 @@ func (s *reflectionServerV2) readLoop(ctx context.Context) { // notifications). Unknown methods with a request ID return "method not found"; // unknown notifications are logged and ignored. func (s *reflectionServerV2) handleRequest(ctx context.Context, req *jsonRPCMessage) { + // Handlers run action functions on this goroutine, where an unrecovered + // panic would take down the whole dev process along with the failing + // action. Report it as a server error instead, matching how the bidi + // engine recovers its function panics. + defer func() { + if r := recover(); r != nil { + slog.Error("reflection V2: handler panicked", "method", req.Method, "panic", r, "stack", string(debug.Stack())) + if req.ID != "" { + s.sendErrorResponse(req.ID, jsonRPCServerError, fmt.Sprintf("handler for %s panicked: %v", req.Method, r), nil) + } + } + }() + switch req.Method { case "listActions": s.handleListActions(ctx, req) diff --git a/go/genkit/reflection_v2_test.go b/go/genkit/reflection_v2_test.go index 468c89f57c..1b454b73fa 100644 --- a/go/genkit/reflection_v2_test.go +++ b/go/genkit/reflection_v2_test.go @@ -374,6 +374,70 @@ func TestReflectionServerV2_RunAction(t *testing.T) { } } +// TestReflectionServerV2_HandlerPanicRecovered verifies that a panicking +// action surfaces as a JSON-RPC error instead of crashing the process: +// handlers run action functions on dispatch goroutines, so without recovery +// one bad action would take the whole dev runtime down. +func TestReflectionServerV2_HandlerPanicRecovered(t *testing.T) { + m := newFakeManager(t) + defer m.close() + + g := Init(context.Background()) + DefineFlow(g, "panicky", func(ctx context.Context, _ string) (string, error) { + panic("boom") + }) + + ctx, cancel := startRuntime(t, g, m) + defer cancel() + + conn := m.waitForConnection(t) + m.ackRegister(t, ctx, conn) + + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "runAction", + "params": map[string]any{ + "key": "/flow/panicky", + "input": "x", + }, + "id": "p1", + }) + + var resp map[string]any + for { + msg := m.read(t, ctx, conn) + if msg["method"] == "runActionState" { + continue + } + resp = msg + break + } + if resp["id"] != "p1" { + t.Fatalf("id = %v, want p1", resp["id"]) + } + errObj, ok := resp["error"].(map[string]any) + if !ok { + t.Fatalf("expected error response for panicking action, got %v", resp) + } + if msg, _ := errObj["message"].(string); !strings.Contains(msg, "panicked") { + t.Errorf("error message = %q, want it to mention the panic", msg) + } + + // The runtime survived: a follow-up request still gets a response. + m.write(t, ctx, conn, map[string]any{ + "jsonrpc": "2.0", + "method": "listActions", + "id": "p2", + }) + resp2 := m.read(t, ctx, conn) + if resp2["id"] != "p2" { + t.Fatalf("follow-up id = %v, want p2", resp2["id"]) + } + if resp2["result"] == nil { + t.Errorf("follow-up result missing: %v", resp2) + } +} + func TestReflectionServerV2_StreamingRunAction(t *testing.T) { m := newFakeManager(t) defer m.close() From b487782367309ced673e4bdc4dff87493190f4ae Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 12 Jun 2026 14:12:27 -0700 Subject: [PATCH 103/141] feat(go/exp): register agents under the agent action type Agents previously registered as flow actions so ListFlows would surface them for HTTP serving. Register them under api.ActionTypeAgent instead so they are their own action kind, and add genkit.ListAgents to surface them for HTTP exposure the same way ListFlows does for flows. Handler serves them unchanged since it accepts any action. --- go/ai/exp/agent.go | 12 +++++++----- go/ai/exp/agent_test.go | 6 +++--- go/genkit/genkit.go | 22 ++++++++++++++++++++-- go/genkit/servers_test.go | 14 ++++++++++---- go/samples/basic-agents-server/main.go | 7 ++++--- 5 files changed, 44 insertions(+), 17 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index bf28f4a705..f2cccfad6b 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -534,11 +534,13 @@ func DefineCustomAgent[Stream, State any]( } } - // Built on DefineBidiAction (rather than DefineBidiFlow) so the agent - // capability metadata can be set at construction time; actions must be - // immutable once registered. WithFlowContext below preserves the - // flow-context wrapping that makes core.Run work inside fn. - action := core.DefineBidiAction(r, name, api.ActionTypeFlow, + // Registered under ActionTypeAgent so agents surface as their own + // action kind rather than as flows (genkit.ListAgents vs ListFlows). + // Built on DefineBidiAction so the agent capability metadata can be + // set at construction time; actions must be immutable once registered. + // WithFlowContext below preserves the flow-context wrapping that makes + // core.Run work inside fn. + action := core.DefineBidiAction(r, name, api.ActionTypeAgent, &core.BidiActionOptions{ Metadata: map[string]any{"agent": agentMetadataFor(cfg.store)}, }, diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 5c575122ad..fbb194b8d7 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -3568,7 +3568,7 @@ func (minimalStore[State]) SaveSnapshot( } func TestAgent_AgentMetadata(t *testing.T) { - // Verify the metadata["agent"] payload on the flow's action descriptor + // Verify the metadata["agent"] payload on the agent's action descriptor // correctly reports stateManagement and abortable for each combination // of store capabilities. noopFn := func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { @@ -3616,9 +3616,9 @@ func TestAgent_AgentMetadata(t *testing.T) { tc.define(reg, flowName) act := core.ResolveBidiActionFor[*AgentInput, *AgentOutput[testState], *AgentStreamChunk[testStatus], *AgentInit[testState]]( - reg, api.ActionTypeFlow, flowName) + reg, api.ActionTypeAgent, flowName) if act == nil { - t.Fatal("flow action not registered") + t.Fatal("agent action not registered") } desc := act.Desc() raw, ok := desc.Metadata["agent"] diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index d2baf1882b..83fe816fa1 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -399,7 +399,7 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn core.StreamingFunc[In // Experimental: This API is under active development and may change in any // minor version release. // -// An Agent is a stateful, multi-turn conversational flow. It builds on +// An Agent is a stateful, multi-turn conversational action. It builds on // bidirectional streaming to enable ongoing conversations where each turn's // input and output are streamed between client and server. The framework // handles session state, conversation history, and optional snapshot @@ -456,7 +456,7 @@ func DefineAgent[State any]( } // DefineCustomAgent defines an agent with full control over the conversation -// loop, registers it as a [core.Action] of type Flow, and returns an +// loop, registers it as an action of type agent, and returns an // [aix.Agent]. // // Experimental: This API is under active development and may change in any @@ -572,6 +572,24 @@ func ListFlows(g *Genkit) []api.Action { return flows } +// ListAgents returns a slice of all [api.Action] instances that represent +// agents registered with the Genkit instance `g`. Like [ListFlows], this is +// useful for introspection or for dynamically exposing agent endpoints in an +// HTTP server; an agent served via [Handler] runs one turn per request. +// +// Experimental: This API is under active development and may change in any +// minor version release. +func ListAgents(g *Genkit) []api.Action { + acts := listActions(g) + agents := []api.Action{} + for _, act := range acts { + if act.Type == api.ActionTypeAgent { + agents = append(agents, g.reg.LookupAction(act.Key)) + } + } + return agents +} + // ListTools returns a slice of all [ai.Tool] instances that are registered // with the Genkit instance `g`. This is useful for introspection and for // exposing tools to external systems like MCP servers. diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index a83475a689..8c9f14724a 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -768,16 +768,22 @@ func TestHandlerAgent(t *testing.T) { aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), ) - // Agents register as flow actions, so they surface through ListFlows - // like any other handler-servable action. + // Agents register under their own action type, so they surface through + // ListAgents (and not ListFlows) and Handler serves them like any other + // action. + for _, a := range ListFlows(g) { + if a.Name() == "agentClient" || a.Name() == "agentServer" { + t.Fatalf("agent %q unexpectedly listed as a flow", a.Name()) + } + } handlerFor := func(t *testing.T, name string) http.HandlerFunc { t.Helper() - for _, a := range ListFlows(g) { + for _, a := range ListAgents(g) { if a.Name() == name { return Handler(a) } } - t.Fatalf("agent %q not in ListFlows", name) + t.Fatalf("agent %q not in ListAgents", name) return nil } diff --git a/go/samples/basic-agents-server/main.go b/go/samples/basic-agents-server/main.go index 63fc91f4fa..b6dd06e9f7 100644 --- a/go/samples/basic-agents-server/main.go +++ b/go/samples/basic-agents-server/main.go @@ -122,10 +122,11 @@ func main() { ), ) - // Agents register as flow actions, so ListFlows surfaces them and the - // standard action handler serves them like any flow. + // Agents register under their own "agent" action type; ListAgents + // surfaces them and the standard action handler serves them one turn + // per request. mux := http.NewServeMux() - for _, a := range genkit.ListFlows(g) { + for _, a := range genkit.ListAgents(g) { mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) } log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) From 290d136965a1a8ddee7b9803a24df5f2b6af9fa6 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 12 Jun 2026 15:58:17 -0700 Subject: [PATCH 104/141] fix(go/core): reject absent input on one-shot bidi runs with a clear error A one-shot RunBidiJSON call with no input previously fell through to input schema validation and failed with a confusing shape mismatch (JSON null vs object schema). Input is required for one-shot runs since running with none would start the function and run zero turns; only a streaming session can start up and defer its first input. Reject the absent input up front with an INVALID_ARGUMENT that says so and points the caller at streaming sessions. --- go/core/api/action.go | 3 ++- go/core/bidi.go | 10 +++++++++- go/core/bidi_test.go | 42 +++++++++++++++++++++++++++++++++++++++ go/genkit/servers_test.go | 10 ++++++++++ 4 files changed, 63 insertions(+), 2 deletions(-) diff --git a/go/core/api/action.go b/go/core/api/action.go index f82866a916..c9480e821e 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -63,7 +63,8 @@ type BidiAction interface { Action // RunBidiJSON runs the bidi action as a single one-shot call: input is // delivered as the only chunk on the input stream, outgoing chunks are - // forwarded to cb, and opts carries the session init. + // forwarded to cb, and opts carries the session init. Input is required; + // only a streaming session can defer it past startup. RunBidiJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error, opts *BidiSessionOptions) (*ActionRunResult[json.RawMessage], error) // StreamBidiJSON starts a bidirectional streaming session using // JSON-encoded messages. diff --git a/go/core/bidi.go b/go/core/bidi.go index 8d6406b5b0..b791dbf0ca 100644 --- a/go/core/bidi.go +++ b/go/core/bidi.go @@ -217,10 +217,18 @@ func (b *BidiAction[In, Out, Stream, Init]) RunWithInit(ctx context.Context, ini // RunBidiJSON runs the bidi action as a single one-shot call: input is // delivered as the only chunk on the input stream, outgoing chunks are // forwarded to cb, and opts carries the session init. Returns an error if -// init fails to decode or validate. +// input is absent or init fails to decode or validate. // // Experimental: bidirectional streaming is experimental and subject to change. func (b *BidiAction[In, Out, Stream, Init]) RunBidiJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage], opts *api.BidiSessionOptions) (*api.ActionRunResult[json.RawMessage], error) { + // A one-shot run with no input would start the function and run zero + // turns, so reject it with a clearer message than the schema failure + // it would otherwise hit (JSON null never satisfies an object input + // schema). Deferring input past startup is a streaming session + // capability; see StreamBidiJSON. + if !base.HasJSONValue(input) { + return nil, NewError(INVALID_ARGUMENT, "action %q requires input for a one-shot run; open a streaming session to defer input", b.desc.Key) + } init, hasInit, err := b.decodeInit(opts) if err != nil { return nil, err diff --git a/go/core/bidi_test.go b/go/core/bidi_test.go index ddcf7f263a..d9068d8f61 100644 --- a/go/core/bidi_test.go +++ b/go/core/bidi_test.go @@ -269,6 +269,48 @@ func TestRunBidiJSONInvalidInit(t *testing.T) { } } +// TestRunBidiJSONRequiresInput verifies that a one-shot run with absent input +// is rejected up front with a message pointing at streaming sessions, rather +// than falling through to a schema validation failure. Only a streaming +// session can start up and defer its first input. +func TestRunBidiJSONRequiresInput(t *testing.T) { + ctx := context.Background() + + type Config struct { + Prefix string `json:"prefix"` + } + action := NewBidiAction( + "input-required", api.ActionTypeCustom, nil, + func(ctx context.Context, cfg Config, inCh <-chan string, outCh chan<- string) (string, error) { + return "ran", nil + }, + ) + + for name, input := range map[string]json.RawMessage{ + "nil input": nil, + "empty input": json.RawMessage(``), + "JSON null input": json.RawMessage(`null`), + } { + t.Run(name, func(t *testing.T) { + _, err := action.RunBidiJSON(ctx, input, nil, + &api.BidiSessionOptions{Init: json.RawMessage(`{"prefix":">> "}`)}) + if err == nil { + t.Fatal("expected error for absent input, got nil") + } + gerr, ok := err.(*GenkitError) + if !ok { + t.Fatalf("expected *GenkitError, got %T: %v", err, err) + } + if gerr.Status != INVALID_ARGUMENT { + t.Errorf("status = %v, want %v", gerr.Status, INVALID_ARGUMENT) + } + if !strings.Contains(gerr.Message, "streaming session") { + t.Errorf("message %q should point the caller at streaming sessions", gerr.Message) + } + }) + } +} + // TestStreamBidiJSONNullInit verifies that nil options and a JSON-null init // payload are both treated as no init (the zero Init value). func TestStreamBidiJSONNullInit(t *testing.T) { diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index 8c9f14724a..7b9e13b537 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -918,6 +918,16 @@ func TestHandlerAgent(t *testing.T) { if code != http.StatusBadRequest { t.Errorf("sessionId on client-managed: status = %d, want %d; body = %s", code, http.StatusBadRequest, body) } + + // data is required for one-shot runs: only a streaming session can + // start up and defer its first input. + code, body = post(t, "agentClient", `{"init": {}}`, false) + if code != http.StatusBadRequest { + t.Errorf("init without data: status = %d, want %d; body = %s", code, http.StatusBadRequest, body) + } + if !strings.Contains(body, "streaming session") { + t.Errorf("init without data: body should point the caller at streaming sessions; body = %s", body) + } }) t.Run("streaming turn delivers chunks then result", func(t *testing.T) { From c1fecdc673065463112cad08c59f4b06d8dec4c8 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 16 Jun 2026 10:12:51 -0700 Subject: [PATCH 105/141] feat(go/exp): serve agents over HTTP and consolidate experimental helpers Make agents first-class bidi actions that register and serve like any other action, add an HTTP route layout for agents and flows, and fold the genkit/x helpers into genkit/exp. - Register the wrapped bidi action under the agent key rather than the Agent facade; the getSnapshot/abortSnapshot companions register independently and are recovered by key via the new genkit.LookupAction. - newSnapshotActions builds the companions without registering them, so an Agent can travel between registries as a unit. - Add WithDescription to carry a human-readable agent description on the action descriptor. - Add genkit/exp route builders (AgentRoutes/AllAgentRoutes/FlowRoutes/ AllFlowRoutes/Mount) layered over genkit.Handler, with fixed /agents and /flows base paths. - Move genkit/x's channel-based DefineStreamingFlow into genkit/exp under a shared package doc; remove the LatestSnapshot store helper. - Update the basic-agents samples to serve agents over HTTP. --- go/ai/exp/agent.go | 209 ++++++++++++++-- go/ai/exp/agent_test.go | 205 +++++++++++++++- go/ai/exp/localstore/file.go | 26 -- go/ai/exp/localstore/file_test.go | 79 ------ go/ai/exp/localstore/inmemory.go | 22 -- go/ai/exp/localstore/inmemory_test.go | 53 ----- go/ai/exp/option.go | 22 +- go/ai/exp/session.go | 30 ++- go/genkit/exp/doc.go | 35 +++ go/genkit/{x/genkit.go => exp/flow.go} | 11 +- .../{x/genkit_test.go => exp/flow_test.go} | 2 +- go/genkit/exp/routes.go | 164 +++++++++++++ go/genkit/exp/routes_test.go | 224 ++++++++++++++++++ go/genkit/genkit.go | 23 ++ go/genkit/servers.go | 59 +++-- go/genkit/servers_test.go | 197 ++++++++++++++- go/samples/basic-agents-server/main.go | 68 +++++- go/samples/basic-agents/cli.go | 137 ++++++----- go/samples/basic-agents/main.go | 141 ++++++----- 19 files changed, 1310 insertions(+), 397 deletions(-) create mode 100644 go/genkit/exp/doc.go rename go/genkit/{x/genkit.go => exp/flow.go} (91%) rename go/genkit/{x/genkit_test.go => exp/flow_test.go} (99%) create mode 100644 go/genkit/exp/routes.go create mode 100644 go/genkit/exp/routes_test.go diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index f2cccfad6b..a73ac12950 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -22,6 +22,7 @@ package exp import ( "context" + "encoding/json" "errors" "fmt" "iter" @@ -462,8 +463,28 @@ func (r Responder[Stream]) send(chunk *AgentStreamChunk[Stream]) { type AgentFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *SessionRunner[State]) (*AgentResult, error) // Agent is a bidirectional streaming agent with automatic snapshot management. +// +// Agent implements [api.BidiAction], so generic transports accept it +// directly (e.g. pass it to genkit.Handler to serve it over HTTP, one turn +// per request). The [Agent.Run], [Agent.RunText], and [Agent.StreamBidi] +// methods are typed conveniences over the same underlying action; both +// surfaces run the identical per-invocation runtime. +// +// Server-managed agents (those with a [SessionStore] configured) also +// register companion actions for the snapshot lifecycle, available via +// [Agent.GetSnapshotAction] and [Agent.AbortSnapshotAction] for serving +// alongside the agent, and expose the store itself via [Agent.Store]. type Agent[Stream, State any] struct { action *core.BidiAction[*AgentInput, *AgentOutput[State], *AgentStreamChunk[Stream], *AgentInit[State]] + // Companion actions, retained so transports can serve them without a + // registry lookup. Nil when the corresponding capability is absent; + // see newSnapshotActions. + getSnapshot api.Action + abortSnapshot api.Action + // store is the configured session store, or nil for a client-managed + // agent. Retained so callers can reach it via Store without threading + // a separate reference. + store SessionStore[State] } // Name returns the agent's registered name. This is also the name under @@ -473,6 +494,112 @@ func (a *Agent[Stream, State]) Name() string { return a.action.Name() } +// GetSnapshotAction returns the agent's getSnapshot companion action, +// which fetches a session snapshot by ID (input [GetSnapshotRequest], +// output [SessionSnapshot]). It returns nil when the agent is +// client-managed (no [SessionStore] configured): there is no server-side +// snapshot to fetch. +// +// Use it to expose snapshot polling over a transport (e.g. mount it with +// genkit.Handler next to the agent itself); local Go code should read +// from the store directly. +func (a *Agent[Stream, State]) GetSnapshotAction() api.Action { + return a.getSnapshot +} + +// AbortSnapshotAction returns the agent's abortSnapshot companion action, +// which asks the background work behind a pending snapshot (e.g. a +// detached invocation) to stop (input [AbortSnapshotRequest], output +// [AbortSnapshotResponse]). It returns nil when the agent has no +// [SessionStore] or the store does not implement [SnapshotAborter]. +// +// Use it to expose aborting over a transport (e.g. mount it with +// genkit.Handler next to the agent itself); local Go code should call the +// store's [SnapshotAborter.AbortSnapshot] directly. +func (a *Agent[Stream, State]) AbortSnapshotAction() api.Action { + return a.abortSnapshot +} + +// Store returns the [SessionStore] the agent was configured with via +// [WithSessionStore], or nil when the agent is client-managed (no store). +// It lets local Go code read and write snapshots directly given an agent +// reference, without threading a separate store variable. +// +// The store is returned as the [SessionStore] interface, not its concrete +// type; a caller needing a store-specific capability (e.g. +// [SnapshotAborter]) type-asserts for it. +func (a *Agent[Stream, State]) Store() SessionStore[State] { + return a.store +} + +// --- api.BidiAction implementation --- + +// Agent is itself an [api.BidiAction]: transports that accept an +// [api.Action] (or [api.BidiAction]) take an Agent directly. The bidi +// methods matter beyond mere interface completeness: generic transports +// type-assert to [api.BidiAction] to route session init (the wire +// counterpart of [WithSessionID], [WithSnapshotID], and [WithState]), so +// satisfying only [api.Action] would silently break session resume. +var _ api.BidiAction = (*Agent[any, any])(nil) + +// Register registers the agent's run action and any companion actions +// (getSnapshot, abortSnapshot) with the registry. Agents defined via +// [DefineAgent] or [DefineCustomAgent] are already registered; this +// exists so an agent can travel to another registry as a unit. An +// inline-defined prompt does not travel: the agent holds it directly, so +// execution is unaffected, but the prompt action stays in the registry it +// was defined in. +func (a *Agent[Stream, State]) Register(r api.Registry) { + // Register the wrapped bidi action under the agent key, the same way + // every other action registers itself; the registry holds a uniform + // api.BidiAction that the reflection servers, ListAgents, and the route + // builders consume without knowing about the Agent type. + // + // The companion actions register independently under their own keys, so + // registry consumers recover them by key (genkit.LookupAction) rather + // than by reaching through the agent action; see newSnapshotActions. + a.action.Register(r) + if a.getSnapshot != nil { + a.getSnapshot.Register(r) + } + if a.abortSnapshot != nil { + a.abortSnapshot.Register(r) + } +} + +// Desc returns the descriptor of the agent's run action. +func (a *Agent[Stream, State]) Desc() api.ActionDesc { + return a.action.Desc() +} + +// RunJSON runs a one-shot invocation with no init (a fresh session): +// input is the turn's [AgentInput] and the result is the final +// [AgentOutput]. To supply a session source, use [Agent.RunBidiJSON]. +func (a *Agent[Stream, State]) RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) { + return a.action.RunJSON(ctx, input, cb) +} + +// RunJSONWithTelemetry is [Agent.RunJSON] with trace information on the +// result. +func (a *Agent[Stream, State]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (*api.ActionRunResult[json.RawMessage], error) { + return a.action.RunJSONWithTelemetry(ctx, input, cb) +} + +// RunBidiJSON runs a one-shot invocation whose session init (the wire +// counterpart of the [InvocationOption] values) rides in opts: input is +// delivered as the only chunk on the input stream and outgoing chunks are +// forwarded to cb. +func (a *Agent[Stream, State]) RunBidiJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error, opts *api.BidiSessionOptions) (*api.ActionRunResult[json.RawMessage], error) { + return a.action.RunBidiJSON(ctx, input, cb, opts) +} + +// StreamBidiJSON starts a bidirectional streaming session using +// JSON-encoded messages. Local Go callers should prefer the typed +// [Agent.StreamBidi]. +func (a *Agent[Stream, State]) StreamBidiJSON(ctx context.Context, opts *api.BidiSessionOptions) (api.BidiJSONConnection, error) { + return a.action.StreamBidiJSON(ctx, opts) +} + // DefineAgent defines a prompt-backed agent and registers it. Each turn // renders the agent's prompt, appends conversation history, calls the // model with streaming, and updates session state. @@ -515,14 +642,24 @@ func DefineAgent[State any]( } } -// DefineCustomAgent defines an agent with full control over the -// conversation loop and registers it with the registry. fn receives a -// [Responder] for streaming output and a [SessionRunner] for turn and -// state management; call [SessionRunner.Run] to enter the per-turn loop. +// NewCustomAgent creates an agent with full control over the conversation +// loop without registering it. Register it later with the registry (e.g. +// genkit.RegisterAction), which also registers its companion actions; see +// [Agent.Register]. fn receives a [Responder] for streaming output and a +// [SessionRunner] for turn and state management; call [SessionRunner.Run] +// to enter the per-turn loop. // -// For agents backed by a prompt, use [DefineAgent] instead. -func DefineCustomAgent[Stream, State any]( - r api.Registry, +// This is the agent counterpart of [core.NewStreamingAction]: use it when +// the agent must outlive or precede a registry (e.g. built in a library, +// registered conditionally, or moved between registries). For the common +// case, [DefineCustomAgent] creates and registers in one step. +// +// There is no NewAgent counterpart for prompt-backed agents: a prompt is +// bound to the registry it renders and generates against, so a +// prompt-backed agent cannot be built before it has one. To get +// prompt-like behavior without registration, write a custom agent that +// renders and generates with your own [genkit.Genkit] inside fn. +func NewCustomAgent[Stream, State any]( name string, fn AgentFunc[Stream, State], opts ...AgentOption[State], @@ -530,20 +667,27 @@ func DefineCustomAgent[Stream, State any]( cfg := &agentOptions[State]{} for _, opt := range opts { if err := opt.applyAgent(cfg); err != nil { - panic(fmt.Errorf("DefineCustomAgent %q: %w", name, err)) + panic(fmt.Errorf("NewCustomAgent %q: %w", name, err)) } } - // Registered under ActionTypeAgent so agents surface as their own - // action kind rather than as flows (genkit.ListAgents vs ListFlows). - // Built on DefineBidiAction so the agent capability metadata can be - // set at construction time; actions must be immutable once registered. - // WithFlowContext below preserves the flow-context wrapping that makes - // core.Run work inside fn. - action := core.DefineBidiAction(r, name, api.ActionTypeAgent, - &core.BidiActionOptions{ - Metadata: map[string]any{"agent": agentMetadataFor(cfg.store)}, - }, + // Typed under ActionTypeAgent so agents surface as their own action + // kind rather than as flows (genkit.ListAgents vs ListFlows). Built on + // NewBidiAction so the agent capability metadata is set at construction + // time; actions must be immutable once registered. WithFlowContext + // below preserves the flow-context wrapping that makes core.Run work + // inside fn. + // + // metadata["agent"] carries the capability info for the Dev UI; + // metadata["description"], if set, is lifted to the descriptor's + // top-level Description by core (see core.newAction), the standard + // place reflective tooling reads an action's description. + metadata := map[string]any{"agent": agentMetadataFor(cfg.store)} + if cfg.description != "" { + metadata["description"] = cfg.description + } + action := core.NewBidiAction(name, api.ActionTypeAgent, + &core.BidiActionOptions{Metadata: metadata}, func( ctx context.Context, in *AgentInit[State], @@ -564,9 +708,34 @@ func DefineCustomAgent[Stream, State any]( return rt.run(ctx, fn) }) - registerSnapshotActions(r, name, cfg.store, cfg.transform) + getSnapshot, abortSnapshot := newSnapshotActions(name, cfg.store, cfg.transform) + + return &Agent[Stream, State]{ + action: action, + getSnapshot: getSnapshot, + abortSnapshot: abortSnapshot, + store: cfg.store, + } +} - return &Agent[Stream, State]{action: action} +// DefineCustomAgent defines an agent with full control over the +// conversation loop and registers it (and any companion actions) with the +// registry. fn receives a [Responder] for streaming output and a +// [SessionRunner] for turn and state management; call [SessionRunner.Run] +// to enter the per-turn loop. +// +// It is [NewCustomAgent] followed by [Agent.Register]. To build an agent +// without registering it, use [NewCustomAgent] directly. For agents backed +// by a prompt, use [DefineAgent] instead. +func DefineCustomAgent[Stream, State any]( + r api.Registry, + name string, + fn AgentFunc[Stream, State], + opts ...AgentOption[State], +) *Agent[Stream, State] { + a := NewCustomAgent(name, fn, opts...) + a.Register(r) + return a } // agentMetadataFor derives the [AgentMetadata] value attached to the diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index fbb194b8d7..35281ec3d6 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -3615,8 +3615,10 @@ func TestAgent_AgentMetadata(t *testing.T) { flowName := "metaFlow" tc.define(reg, flowName) - act := core.ResolveBidiActionFor[*AgentInput, *AgentOutput[testState], *AgentStreamChunk[testStatus], *AgentInit[testState]]( - reg, api.ActionTypeAgent, flowName) + // The registry holds the agent's bidi action under the agent key + // (see Agent.Register); resolve it as a plain api.Action and read + // its descriptor. + act := reg.LookupAction(api.NewKey(api.ActionTypeAgent, "", flowName)) if act == nil { t.Fatal("agent action not registered") } @@ -3685,6 +3687,205 @@ func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { }) } +// TestAgent_CompanionActionAccessors verifies the agent ref hands out the +// same companion actions the registry holds, so transports can mount them +// on custom routes without registry lookups, and that the accessors mirror +// the store's capabilities by returning nil for actions that were not +// registered. +func TestAgent_CompanionActionAccessors(t *testing.T) { + noopFn := func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, nil + } + + t.Run("no store → no companions", func(t *testing.T) { + reg := newTestRegistry(t) + af := DefineCustomAgent(reg, "noCompanions", noopFn) + if got := af.GetSnapshotAction(); got != nil { + t.Errorf("GetSnapshotAction() = %v, want nil", got) + } + if got := af.AbortSnapshotAction(); got != nil { + t.Errorf("AbortSnapshotAction() = %v, want nil", got) + } + }) + + t.Run("store without aborter → getSnapshot only", func(t *testing.T) { + reg := newTestRegistry(t) + af := DefineCustomAgent(reg, "getOnly", noopFn, + WithSessionStore[testState](minimalStore[testState]{})) + if af.GetSnapshotAction() == nil { + t.Error("GetSnapshotAction() = nil, want action") + } + if got := af.AbortSnapshotAction(); got != nil { + t.Errorf("AbortSnapshotAction() = %v, want nil", got) + } + }) + + t.Run("aborter store → both, identical to the registered actions", func(t *testing.T) { + reg := newTestRegistry(t) + af := DefineCustomAgent(reg, "bothCompanions", noopFn, + WithSessionStore(newTestInMemStore[testState]())) + if got, want := af.GetSnapshotAction(), reg.LookupAction("/agent-snapshot/bothCompanions"); got == nil || got != want { + t.Errorf("GetSnapshotAction() = %v, want registered action %v", got, want) + } + if got, want := af.AbortSnapshotAction(), reg.LookupAction("/agent-abort/bothCompanions"); got == nil || got != want { + t.Errorf("AbortSnapshotAction() = %v, want registered action %v", got, want) + } + }) +} + +// TestAgent_Store verifies the agent hands back the store it was +// configured with (so local Go code need not thread a separate reference), +// and nil when the agent is client-managed. +func TestAgent_Store(t *testing.T) { + noopFn := func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, nil + } + + t.Run("returns the configured store", func(t *testing.T) { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := DefineCustomAgent(reg, "withStore", noopFn, WithSessionStore[testState](store)) + if got := af.Store(); got != SessionStore[testState](store) { + t.Errorf("Store() = %v, want the configured store %v", got, store) + } + // The returned store is usable directly, and store-specific + // capabilities are reachable by type assertion. + if _, ok := af.Store().(SnapshotAborter); !ok { + t.Error("expected the configured store to satisfy SnapshotAborter") + } + }) + + t.Run("nil for a client-managed agent", func(t *testing.T) { + reg := newTestRegistry(t) + af := DefineCustomAgent(reg, "noStore", noopFn) + if got := af.Store(); got != nil { + t.Errorf("Store() = %v, want nil for a client-managed agent", got) + } + }) +} + +// TestAgent_Description verifies that WithDescription lands on the agent +// action's descriptor as the standard top-level Description (lifted from +// metadata["description"] by core), so reflective tooling and local +// callers read it the same way they read any other action's description. +func TestAgent_Description(t *testing.T) { + noopFn := func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, nil + } + + t.Run("set via WithDescription", func(t *testing.T) { + reg := newTestRegistry(t) + const want = "A concise test agent." + af := DefineCustomAgent(reg, "described", noopFn, WithDescription[testState](want)) + if got := af.Desc().Description; got != want { + t.Errorf("Desc().Description = %q, want %q", got, want) + } + // It must also be in the metadata map, the place core lifts it + // from and where the wire descriptor carries it. + if got, _ := af.Desc().Metadata["description"].(string); got != want { + t.Errorf("Desc().Metadata[\"description\"] = %q, want %q", got, want) + } + }) + + t.Run("empty when unset", func(t *testing.T) { + reg := newTestRegistry(t) + af := DefineCustomAgent(reg, "undescribed", noopFn) + if got := af.Desc().Description; got != "" { + t.Errorf("Desc().Description = %q, want empty", got) + } + if _, ok := af.Desc().Metadata["description"]; ok { + t.Error("expected no description key in metadata when unset") + } + }) + + t.Run("rejects a second WithDescription", func(t *testing.T) { + reg := newTestRegistry(t) + defer func() { + if recover() == nil { + t.Error("expected a panic when WithDescription is set twice") + } + }() + DefineCustomAgent(reg, "twice", noopFn, + WithDescription[testState]("first"), WithDescription[testState]("second")) + }) +} + +// TestAgent_RegisterCarriesCompanions verifies that registering an agent +// ref into another registry brings the companion actions along, so the +// agent travels as a unit (see Agent.Register). +func TestAgent_RegisterCarriesCompanions(t *testing.T) { + reg := newTestRegistry(t) + af := DefineCustomAgent(reg, "mover", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, nil + }, + WithSessionStore(newTestInMemStore[testState]()), + ) + + reg2 := newTestRegistry(t) + af.Register(reg2) + for _, key := range []string{"/agent/mover", "/agent-snapshot/mover", "/agent-abort/mover"} { + if reg2.LookupAction(key) == nil { + t.Errorf("action %q missing from the registry after Register", key) + } + } + + // The agent key resolves to the bidi run action, not a companion-bearing + // facade: consumers recover the companions by their own keys (as the loop + // above and the route builders do), not by asserting an interface on the + // agent action. + runAction := reg2.LookupAction("/agent/mover") + if _, ok := runAction.(api.BidiAction); !ok { + t.Errorf("registered agent action = %T, want an api.BidiAction", runAction) + } + if _, ok := runAction.(interface{ GetSnapshotAction() api.Action }); ok { + t.Error("registered agent action should be the bidi run action, not the Agent facade") + } +} + +// TestNewCustomAgent_UnregisteredUntilRegister verifies the non-registering +// constructor: the agent is fully usable before it touches a registry, and +// registering it later (the genkit.RegisterAction path) surfaces the run +// action and its companions together. +func TestNewCustomAgent_UnregisteredUntilRegister(t *testing.T) { + ctx := context.Background() + store := newTestInMemStore[testState]() + + af := NewCustomAgent("standalone", + func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("hi")) + return &TurnResult{FinishReason: AgentFinishReasonStop}, nil + }) + }, + WithSessionStore(store), + ) + + // Companion refs are wired at construction, before any registry. + if af.GetSnapshotAction() == nil || af.AbortSnapshotAction() == nil { + t.Fatal("companion actions should be built by NewCustomAgent before registration") + } + + // The agent runs without ever being registered: the runtime never + // consults a registry for a custom agent. + out, err := af.RunText(ctx, "hello") + if err != nil { + t.Fatalf("RunText on unregistered agent: %v", err) + } + if out.FinishReason != AgentFinishReasonStop { + t.Errorf("finishReason = %q, want %q", out.FinishReason, AgentFinishReasonStop) + } + + // Registering later brings the whole unit into the registry. + reg := newTestRegistry(t) + af.Register(reg) + for _, key := range []string{"/agent/standalone", "/agent-snapshot/standalone", "/agent-abort/standalone"} { + if reg.LookupAction(key) == nil { + t.Errorf("action %q missing after Register", key) + } + } +} + func TestAgent_AbortAction_NotFound(t *testing.T) { // The store's "not found" sentinel (empty status, nil error) must // surface as a core.NOT_FOUND GenkitError on the abort companion diff --git a/go/ai/exp/localstore/file.go b/go/ai/exp/localstore/file.go index 19eb303183..e386a6a470 100644 --- a/go/ai/exp/localstore/file.go +++ b/go/ai/exp/localstore/file.go @@ -237,32 +237,6 @@ func (s *FileSessionStore[State]) snapshotFilesNewestFirst() ([]string, error) { return names, nil } -// LatestSnapshot returns the snapshot whose backing file has the most -// recent on-disk modification time, or nil if the directory has no -// snapshots yet. It is not part of the [exp.SessionStore] interface; it -// is a convenience for UIs and CLIs that need to surface "where did I -// leave off" without indexing the directory themselves. -// -// Selecting by mtime avoids parsing every file: for snapshots written by -// this package, mtime and [exp.SessionSnapshot.UpdatedAt] advance -// together, so the result matches a sort by UpdatedAt; if a file is -// touched externally, mtime wins. Files that fail to stat or parse are -// skipped and the scan falls back to the next-newest candidate. -func (s *FileSessionStore[State]) LatestSnapshot(ctx context.Context) (*exp.SessionSnapshot[State], error) { - names, err := s.snapshotFilesNewestFirst() - if err != nil { - return nil, err - } - for _, name := range names { - snap, err := s.GetSnapshot(ctx, strings.TrimSuffix(name, ".json")) - if err != nil || snap == nil { - continue - } - return snap, nil - } - return nil, nil -} - // AbortSnapshot atomically flips a pending snapshot to aborted. If the // snapshot is already terminal the existing status is returned unchanged. // Returns an empty status if the snapshot is not found. diff --git a/go/ai/exp/localstore/file_test.go b/go/ai/exp/localstore/file_test.go index b864cd67ab..06948d8b49 100644 --- a/go/ai/exp/localstore/file_test.go +++ b/go/ai/exp/localstore/file_test.go @@ -396,85 +396,6 @@ func TestFileSessionStore(t *testing.T) { var _ exp.SessionStore[testState] = (*FileSessionStore[testState])(nil) var _ exp.SnapshotAborter = (*FileSessionStore[testState])(nil) }) - - t.Run("LatestSnapshotEmpty", func(t *testing.T) { - store := newFileStore(t) - latest, err := store.LatestSnapshot(context.Background()) - if err != nil { - t.Fatalf("LatestSnapshot: %v", err) - } - if latest != nil { - t.Errorf("expected nil on empty store, got %+v", latest) - } - }) - - t.Run("LatestSnapshotMissingDir", func(t *testing.T) { - // Construct a store but then remove the dir to simulate a - // pre-use store. LatestSnapshot should treat "no dir" as - // "no snapshots" rather than error. - dir := filepath.Join(t.TempDir(), "nope") - store, err := NewFileSessionStore[testState](dir) - if err != nil { - t.Fatalf("NewFileSessionStore: %v", err) - } - if err := os.RemoveAll(dir); err != nil { - t.Fatalf("RemoveAll: %v", err) - } - latest, err := store.LatestSnapshot(context.Background()) - if err != nil { - t.Fatalf("LatestSnapshot: %v", err) - } - if latest != nil { - t.Errorf("expected nil on missing dir, got %+v", latest) - } - }) - - t.Run("LatestSnapshotReturnsMostRecent", func(t *testing.T) { - store := newFileStore(t) - ctx := context.Background() - for _, id := range []string{"snap-1", "snap-2", "snap-3"} { - if _, err := store.SaveSnapshot(ctx, id, - func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { - return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil - }); err != nil { - t.Fatalf("SaveSnapshot %q: %v", id, err) - } - // Force monotonically increasing UpdatedAt — the store - // stamps to wall-clock time and successive writes within - // the same tick can tie. - time.Sleep(2 * time.Millisecond) - } - latest, err := store.LatestSnapshot(ctx) - if err != nil { - t.Fatalf("LatestSnapshot: %v", err) - } - if latest == nil || latest.SnapshotID != "snap-3" { - t.Errorf("expected latest=snap-3, got %+v", latest) - } - }) - - t.Run("LatestSnapshotSkipsCorruptedFiles", func(t *testing.T) { - store := newFileStore(t) - ctx := context.Background() - good, err := store.SaveSnapshot(ctx, "good", - func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { - return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil - }) - if err != nil { - t.Fatalf("SaveSnapshot good: %v", err) - } - // Drop an unparseable file alongside. - if err := os.WriteFile(filepath.Join(store.dir, "broken.json"), []byte("not json"), 0o600); err != nil { - t.Fatalf("seed broken file: %v", err) - } - latest, err := store.LatestSnapshot(ctx) - if err != nil { - t.Fatalf("LatestSnapshot: %v", err) - } - if latest == nil || latest.SnapshotID != good.SnapshotID { - t.Errorf("expected to skip broken.json and return %q, got %+v", good.SnapshotID, latest) - } - }) } // TestFileSessionStore_FinishReasonPersistsAcrossReopen verifies that a diff --git a/go/ai/exp/localstore/inmemory.go b/go/ai/exp/localstore/inmemory.go index 100eda5a70..ca73f1a8d4 100644 --- a/go/ai/exp/localstore/inmemory.go +++ b/go/ai/exp/localstore/inmemory.go @@ -95,28 +95,6 @@ func (s *InMemorySessionStore[State]) GetLatestSnapshot(_ context.Context, sessi return copySnapshot(latest) } -// LatestSnapshot returns the snapshot with the most recent -// [exp.SessionSnapshot.UpdatedAt] in the store, or nil if there are none. -// -// This is not part of the [exp.SessionStore] interface; it is an -// InMemorySessionStore-specific convenience that mirrors -// [FileSessionStore.LatestSnapshot] so callers that swap stores during -// tests don't have to special-case the in-memory implementation. -func (s *InMemorySessionStore[State]) LatestSnapshot(_ context.Context) (*exp.SessionSnapshot[State], error) { - s.mu.RLock() - defer s.mu.RUnlock() - var latest *exp.SessionSnapshot[State] - for _, snap := range s.snapshots { - if latest == nil || snap.UpdatedAt.After(latest.UpdatedAt) { - latest = snap - } - } - if latest == nil { - return nil, nil - } - return copySnapshot(latest) -} - // AbortSnapshot atomically flips a pending snapshot to aborted. If the // snapshot is already terminal the existing status is returned unchanged. // Returns an empty status if the snapshot is not found. diff --git a/go/ai/exp/localstore/inmemory_test.go b/go/ai/exp/localstore/inmemory_test.go index 5bdc4604b1..07c27eafe8 100644 --- a/go/ai/exp/localstore/inmemory_test.go +++ b/go/ai/exp/localstore/inmemory_test.go @@ -156,59 +156,6 @@ func TestInMemorySessionStore(t *testing.T) { var _ exp.SessionStore[testState] = (*InMemorySessionStore[testState])(nil) var _ exp.SnapshotAborter = (*InMemorySessionStore[testState])(nil) }) - - t.Run("LatestSnapshotEmpty", func(t *testing.T) { - store := NewInMemorySessionStore[testState]() - latest, err := store.LatestSnapshot(context.Background()) - if err != nil { - t.Fatalf("LatestSnapshot: %v", err) - } - if latest != nil { - t.Errorf("expected nil on empty store, got %+v", latest) - } - }) - - t.Run("LatestSnapshotReturnsMostRecent", func(t *testing.T) { - store := NewInMemorySessionStore[testState]() - ctx := context.Background() - for _, id := range []string{"snap-1", "snap-2", "snap-3"} { - if _, err := store.SaveSnapshot(ctx, id, - func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { - return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil - }); err != nil { - t.Fatalf("SaveSnapshot %q: %v", id, err) - } - // Wall-clock UpdatedAt can tie within a tick; force order. - time.Sleep(2 * time.Millisecond) - } - latest, err := store.LatestSnapshot(ctx) - if err != nil { - t.Fatalf("LatestSnapshot: %v", err) - } - if latest == nil || latest.SnapshotID != "snap-3" { - t.Errorf("expected latest=snap-3, got %+v", latest) - } - }) - - t.Run("LatestSnapshotReturnsCopy", func(t *testing.T) { - store := NewInMemorySessionStore[testState]() - ctx := context.Background() - if _, err := store.SaveSnapshot(ctx, "snap-1", - func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { - return &exp.SessionSnapshot[testState]{ - Status: exp.SnapshotStatusSucceeded, - State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, - }, nil - }); err != nil { - t.Fatalf("SaveSnapshot: %v", err) - } - first, _ := store.LatestSnapshot(ctx) - first.State.Custom.Counter = 999 - second, _ := store.LatestSnapshot(ctx) - if second.State.Custom.Counter != 1 { - t.Errorf("expected counter=1 (isolation), got %d", second.State.Custom.Counter) - } - }) } func TestInMemorySessionStore_SessionIDs(t *testing.T) { diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index 399c3f5735..f52945d12e 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -47,9 +47,10 @@ type AgentOption[State any] interface { type StateTransform[State any] = func(ctx context.Context, state *SessionState[State]) *SessionState[State] type agentOptions[State any] struct { - store SessionStore[State] - callback SnapshotCallback[State] - transform StateTransform[State] + store SessionStore[State] + callback SnapshotCallback[State] + transform StateTransform[State] + description string } func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { @@ -71,6 +72,12 @@ func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { } opts.transform = o.transform } + if o.description != "" { + if opts.description != "" { + return errors.New("cannot set description more than once (WithDescription)") + } + opts.description = o.description + } return nil } @@ -111,6 +118,15 @@ func WithStateTransform[State any](transform StateTransform[State]) AgentOption[ return &agentOptions[State]{transform: transform} } +// WithDescription sets a human-readable description of the agent. It is +// stored on the agent action's descriptor (read back via [Agent.Desc] and +// surfaced in the Dev UI's action listing), the same place every other +// primitive carries its description, so reflective tooling can render it +// without a separate field. +func WithDescription[State any](description string) AgentOption[State] { + return &agentOptions[State]{description: description} +} + // --- InvocationOption --- // InvocationOption configures an agent invocation (StreamBidi, Run, or RunText). diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index f426f02815..bc8d3080d6 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -216,32 +216,35 @@ func cloneArtifacts(arts []*Artifact) []*Artifact { // --- Snapshot companion actions --- -// registerSnapshotActions registers the agent's companion actions when -// the agent has a [SessionStore] configured: +// newSnapshotActions creates the agent's companion actions, without +// registering them, when the agent has a [SessionStore] configured: // // - The agent's name under [api.ActionTypeAgentSnapshot] — getSnapshot, // the remote counterpart to [SessionStore.GetSnapshot] for Dev UI and // non-Go clients. Local Go callers use the store reference directly. // // - The agent's name under [api.ActionTypeAgentAbort] — abortSnapshot, -// registered only when the store also implements [SnapshotAborter] -// (which bundles both the abort trigger and the status-change -// subscription needed for the runtime to react). +// created only when the store also implements [SnapshotAborter] (which +// bundles both the abort trigger and the status-change subscription +// needed for the runtime to react). // // When the agent is client-managed (no store configured), neither action -// is registered: there is no server-side snapshot to fetch or abort. +// is created: there is no server-side snapshot to fetch or abort. // Surfacing actions only when the underlying capabilities exist keeps the // reflected API aligned with what the agent can actually do. -func registerSnapshotActions[State any]( - r api.Registry, +// +// The [Agent] retains the returned actions (an absent one is nil) and +// registers them alongside its run action; see [Agent.Register], +// [Agent.GetSnapshotAction], and [Agent.AbortSnapshotAction]. +func newSnapshotActions[State any]( agentName string, store SessionStore[State], transform StateTransform[State], -) { +) (getSnapshot, abortSnapshot api.Action) { if store == nil { - return + return nil, nil } - core.DefineAction(r, agentName, api.ActionTypeAgentSnapshot, nil, nil, + getSnapshotAction := core.NewAction(agentName, api.ActionTypeAgentSnapshot, nil, nil, func(ctx context.Context, req *GetSnapshotRequest) (*SessionSnapshot[State], error) { if req == nil || req.SnapshotID == "" { return nil, core.NewError(core.INVALID_ARGUMENT, "getSnapshot: snapshotId is required") @@ -286,9 +289,9 @@ func registerSnapshotActions[State any]( if !ok { // Store doesn't support the abort lifecycle. Don't surface the // action. - return + return getSnapshotAction, nil } - core.DefineAction(r, agentName, api.ActionTypeAgentAbort, nil, nil, + abortSnapshotAction := core.NewAction(agentName, api.ActionTypeAgentAbort, nil, nil, func(ctx context.Context, req *AbortSnapshotRequest) (*AbortSnapshotResponse, error) { if req == nil || req.SnapshotID == "" { return nil, core.NewError(core.INVALID_ARGUMENT, "abortSnapshot: snapshotId is required") @@ -302,6 +305,7 @@ func registerSnapshotActions[State any]( } return &AbortSnapshotResponse{SnapshotID: req.SnapshotID, Status: status}, nil }) + return getSnapshotAction, abortSnapshotAction } // --- Session --- diff --git a/go/genkit/exp/doc.go b/go/genkit/exp/doc.go new file mode 100644 index 0000000000..7f2e6e4169 --- /dev/null +++ b/go/genkit/exp/doc.go @@ -0,0 +1,35 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +/* +Package exp holds experimental genkit helpers that are still taking shape. +It currently provides: + + - An HTTP route layout for serving agents and flows: the [Route] value, the + [AgentRoutes] / [AllAgentRoutes] / [FlowRoutes] / [AllFlowRoutes] builders, + and [Mount]. The handlers themselves come from the stable genkit package + ([genkit.Handler]); this package only lays out which paths map to which + actions, so the routing layer can evolve without touching genkit's stable + surface. + + - A channel-based streaming flow constructor, [DefineStreamingFlow]: an + alternative to the callback-based [genkit.DefineStreamingFlow] for logic + that is more naturally expressed by writing chunks to a channel. + +APIs in this package are under active development and may change in any minor +version release. +*/ +package exp diff --git a/go/genkit/x/genkit.go b/go/genkit/exp/flow.go similarity index 91% rename from go/genkit/x/genkit.go rename to go/genkit/exp/flow.go index 9480f75149..f2987483e9 100644 --- a/go/genkit/x/genkit.go +++ b/go/genkit/exp/flow.go @@ -14,14 +14,7 @@ // // SPDX-License-Identifier: Apache-2.0 -// Package x provides experimental Genkit APIs. -// -// APIs in this package are under active development and may change in any -// minor version release. Use with caution in production environments. -// -// When these APIs stabilize, they will be moved to the genkit package -// and these exports will be deprecated. -package x +package exp import ( "context" @@ -59,7 +52,7 @@ type StreamingFunc[In, Out, Stream any] = func(ctx context.Context, input In, st // // Example: // -// countdown := x.DefineStreamingFlow(g, "countdown", +// countdown := exp.DefineStreamingFlow(g, "countdown", // func(ctx context.Context, start int, streamCh chan<- int) (string, error) { // for i := start; i > 0; i-- { // select { diff --git a/go/genkit/x/genkit_test.go b/go/genkit/exp/flow_test.go similarity index 99% rename from go/genkit/x/genkit_test.go rename to go/genkit/exp/flow_test.go index 81b4d15105..cde23419e3 100644 --- a/go/genkit/x/genkit_test.go +++ b/go/genkit/exp/flow_test.go @@ -14,7 +14,7 @@ // // SPDX-License-Identifier: Apache-2.0 -package x +package exp import ( "context" diff --git a/go/genkit/exp/routes.go b/go/genkit/exp/routes.go new file mode 100644 index 0000000000..36c268e173 --- /dev/null +++ b/go/genkit/exp/routes.go @@ -0,0 +1,164 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "log/slog" + "net/http" + + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/genkit" +) + +// Base paths for the built-in serving layouts. +const ( + agentBasePath = "/agents" + flowBasePath = "/flows" +) + +// Route is one HTTP route in a primitive's default serving layout: the +// method and path to mount and the action to serve. +// +// [AgentRoutes], [AllAgentRoutes], [FlowRoutes], and [AllFlowRoutes] +// produce Routes; [Mount] wires them onto an [http.ServeMux]. The fields +// are exported so other routers (Gin, Chi, Echo) can mount the same +// layout: read Method and Path and serve Action with [genkit.Handler]. +// Every route is a POST that speaks the {"data": ...} / {"result": ...} +// envelope of the reflection API (the agent turn route also streams via +// ?stream=true), so a single client transport reaches all of them. +type Route struct { + // Method is the HTTP method; always "POST" for the built-in layouts. + Method string + // Path is the URL path to mount, e.g. "/agents/chat/getSnapshot". + Path string + // Action is the action served at this route via [genkit.Handler]. + Action api.Action +} + +// Pattern returns the "METHOD /path" pattern for [http.ServeMux.HandleFunc]. +func (r Route) Pattern() string { + return r.Method + " " + r.Path +} + +// Handler builds the HTTP handler for this route with [genkit.Handler], +// applying opts. +func (r Route) Handler(opts ...genkit.HandlerOption) http.HandlerFunc { + return genkit.Handler(r.Action, opts...) +} + +// AllAgentRoutes returns the default serving layout for every agent +// registered with g, the iterate-over-all counterpart to [AgentRoutes]. +// Pass the result to [Mount], or to a router of your choice. See +// [AgentRoutes] for the per-agent layout and the route set each agent +// contributes. +func AllAgentRoutes(g *genkit.Genkit) []Route { + var routes []Route + for _, act := range genkit.ListAgents(g) { + name := act.Name() + // The snapshot-lifecycle companions register independently under + // their own action types, keyed by the agent's name (see the agent + // package's snapshot companions). Recover them by key so the layout + // depends only on the registry, not on the concrete agent type. + // LookupAction returns nil for a capability the agent lacks (a + // client-managed agent has no companions), and buildAgentRoutes omits + // the route for a nil companion. + snapshot := genkit.LookupAction(g, api.KeyFromName(api.ActionTypeAgentSnapshot, name)) + abort := genkit.LookupAction(g, api.KeyFromName(api.ActionTypeAgentAbort, name)) + routes = append(routes, buildAgentRoutes(name, act, snapshot, abort)...) + } + return routes +} + +// AgentRoutes returns the default serving layout for a single agent, so you +// can mount specific agents rather than every registered one. Pass the +// result to [Mount], or to a router of your choice. +// +// The route set mirrors what the agent can do: +// +// - POST /agents/{name} the agent, one turn per request +// - POST /agents/{name}/getSnapshot getSnapshot (store-backed agents) +// - POST /agents/{name}/abortSnapshot abortSnapshot (abortable stores) +// +// Each takes the {"data": ...} request envelope and returns {"result": +// ...}; the snapshot ID rides in the body ({"data": {"snapshotId": ...}}), +// the same as the reflection API. Companion routes are omitted for +// capabilities the agent lacks; a client-managed agent contributes only +// its turn route. +func AgentRoutes[Stream, State any](a *aix.Agent[Stream, State]) []Route { + return buildAgentRoutes(a.Name(), a, a.GetSnapshotAction(), a.AbortSnapshotAction()) +} + +// buildAgentRoutes builds an agent's route set from its run action and the +// companion actions it has (either may be nil). Shared by AllAgentRoutes +// (companions looked up by key) and AgentRoutes (companions from the typed +// ref's accessors). +func buildAgentRoutes(name string, run, snapshot, abort api.Action) []Route { + routes := []Route{{Method: http.MethodPost, Path: agentBasePath + "/" + name, Action: run}} + if snapshot != nil { + routes = append(routes, Route{ + Method: http.MethodPost, + Path: agentBasePath + "/" + name + "/getSnapshot", + Action: snapshot, + }) + } + if abort != nil { + routes = append(routes, Route{ + Method: http.MethodPost, + Path: agentBasePath + "/" + name + "/abortSnapshot", + Action: abort, + }) + } + return routes +} + +// AllFlowRoutes returns the default serving layout for every flow +// registered with g, the iterate-over-all counterpart to [FlowRoutes]. +// Pass the result to [Mount], or to a router of your choice. +func AllFlowRoutes(g *genkit.Genkit) []Route { + var routes []Route + for _, f := range genkit.ListFlows(g) { + routes = append(routes, buildFlowRoute(f)) + } + return routes +} + +// FlowRoutes returns the default serving layout for a single flow: one +// route, POST /flows/{name}, taking its input from the request body. It +// returns a slice for symmetry with [AgentRoutes], so route lists compose +// with append. +func FlowRoutes(f api.Action) []Route { + return []Route{buildFlowRoute(f)} +} + +func buildFlowRoute(f api.Action) Route { + return Route{Method: http.MethodPost, Path: flowBasePath + "/" + f.Name(), Action: f} +} + +// Mount registers routes on mux, building each route's handler with opts +// (e.g. [genkit.WithContextProviders] shared across all of them). It is +// the stdlib convenience over [Route.Handler]; routers other than +// [http.ServeMux] can range over the routes and mount them directly. +// +// Compose layouts by concatenating: Mount(mux, append(AllAgentRoutes(g), +// AllFlowRoutes(g)...), opts...). +func Mount(mux *http.ServeMux, routes []Route, opts ...genkit.HandlerOption) { + for _, rt := range routes { + mux.HandleFunc(rt.Pattern(), rt.Handler(opts...)) + slog.Debug("genkit/exp.Mount", "method", rt.Method, "path", rt.Path) + } +} diff --git a/go/genkit/exp/routes_test.go b/go/genkit/exp/routes_test.go new file mode 100644 index 0000000000..7d19a53a3b --- /dev/null +++ b/go/genkit/exp/routes_test.go @@ -0,0 +1,224 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "slices" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/ai/exp/localstore" + "github.com/firebase/genkit/go/genkit" +) + +// routeKey is a compact "METHOD path" identity for asserting on a route set +// without depending on order. +func routeKey(r Route) string { return r.Method + " " + r.Path } + +func routeKeys(routes []Route) []string { + keys := make([]string, len(routes)) + for i, r := range routes { + keys[i] = routeKey(r) + } + slices.Sort(keys) + return keys +} + +// newRouteTestGenkit defines a server-managed agent (with abortable store), +// a client-managed agent, and a flow, the mix the route builders must +// distinguish. +func newRouteTestGenkit(t *testing.T) *genkit.Genkit { + t.Helper() + g := genkit.Init(context.Background()) + + genkit.DefineModel(g, "test/echo", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return &ai.ModelResponse{ + Message: ai.NewModelTextMessage(fmt.Sprintf("echo %d", len(req.Messages))), + FinishReason: ai.FinishReasonStop, + }, nil + }) + + store, err := localstore.NewFileSessionStore[any](t.TempDir()) + if err != nil { + t.Fatal(err) + } + genkit.DefineAgent(g, "serverChat", aix.FromInline(ai.WithModelName("test/echo")), + aix.WithSessionStore(store), + aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), + ) + genkit.DefineAgent[any](g, "clientChat", aix.FromInline(ai.WithModelName("test/echo"))) + genkit.DefineFlow(g, "greet", func(ctx context.Context, name string) (string, error) { + return "hi " + name, nil + }) + + return g +} + +func TestAllAgentRoutes(t *testing.T) { + g := newRouteTestGenkit(t) + + got := routeKeys(AllAgentRoutes(g)) + want := []string{ + "POST /agents/clientChat", + "POST /agents/serverChat", + "POST /agents/serverChat/abortSnapshot", + "POST /agents/serverChat/getSnapshot", + } + if !slices.Equal(got, want) { + t.Errorf("AllAgentRoutes layout =\n %v\nwant\n %v", got, want) + } + + // Every route is a POST carrying a non-nil action; companions are plain + // subpaths served with the same enveloped Handler as the turn route. + for _, r := range AllAgentRoutes(g) { + if r.Action == nil { + t.Errorf("route %q has nil Action", routeKey(r)) + } + if r.Method != http.MethodPost { + t.Errorf("route %q method = %q, want POST", routeKey(r), r.Method) + } + } +} + +func TestAgentRoutes_PicksOneAgentAndMirrorsCapabilities(t *testing.T) { + g := genkit.Init(context.Background()) + genkit.DefineModel(g, "test/echo", &ai.ModelOptions{}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return &ai.ModelResponse{Message: ai.NewModelTextMessage("ok"), FinishReason: ai.FinishReasonStop}, nil + }) + store, err := localstore.NewFileSessionStore[any](t.TempDir()) + if err != nil { + t.Fatal(err) + } + server := genkit.DefineAgent(g, "srv", aix.FromInline(ai.WithModelName("test/echo")), aix.WithSessionStore(store)) + client := genkit.DefineAgent[any](g, "cli", aix.FromInline(ai.WithModelName("test/echo"))) + + if got, want := routeKeys(AgentRoutes(server)), []string{ + "POST /agents/srv", + "POST /agents/srv/abortSnapshot", + "POST /agents/srv/getSnapshot", + }; !slices.Equal(got, want) { + t.Errorf("AgentRoutes(server) = %v, want %v", got, want) + } + + // Client-managed: just the turn route, no companions. + if got, want := routeKeys(AgentRoutes(client)), []string{"POST /agents/cli"}; !slices.Equal(got, want) { + t.Errorf("AgentRoutes(client) = %v, want %v", got, want) + } +} + +func TestAllFlowRoutes(t *testing.T) { + g := newRouteTestGenkit(t) + + got := routeKeys(AllFlowRoutes(g)) + if want := []string{"POST /flows/greet"}; !slices.Equal(got, want) { + t.Errorf("AllFlowRoutes = %v, want %v", got, want) + } + + greet := genkit.NewFlow("standalone", func(ctx context.Context, s string) (string, error) { return s, nil }) + single := FlowRoutes(greet) + if len(single) != 1 || routeKey(single[0]) != "POST /flows/standalone" { + t.Errorf("FlowRoutes = %v, want one POST /flows/standalone", routeKeys(single)) + } +} + +// TestMount exercises the full path: build the all-agents layout, mount it +// on a ServeMux, and drive the resulting endpoints. It proves every route +// speaks the same enveloped Handler transport (the turn and the getSnapshot +// companion alike) and that a client-managed agent has only its turn route. +func TestMount(t *testing.T) { + g := newRouteTestGenkit(t) + + mux := http.NewServeMux() + Mount(mux, AllAgentRoutes(g)) + + do := func(t *testing.T, method, path, body string) (int, string) { + t.Helper() + var rdr io.Reader + if body != "" { + rdr = strings.NewReader(body) + } + req := httptest.NewRequest(method, path, rdr) + if body != "" { + req.Header.Set("Content-Type", "application/json") + } + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + b, _ := io.ReadAll(w.Result().Body) + return w.Result().StatusCode, string(b) + } + + // A turn on the server-managed agent goes through the enveloped Handler + // transport and yields a snapshot. + code, body := do(t, "POST", "/agents/serverChat", + `{"data":{"message":{"role":"user","content":[{"text":"hi"}]}}}`) + if code != http.StatusOK { + t.Fatalf("turn status = %d, body = %s", code, body) + } + var env struct { + Result struct { + SnapshotID string `json:"snapshotId"` + } `json:"result"` + } + if err := json.Unmarshal([]byte(body), &env); err != nil { + t.Fatalf("turn not enveloped as expected: %v; body = %s", err, body) + } + if env.Result.SnapshotID == "" { + t.Fatalf("no snapshotId from turn; body = %s", body) + } + + // The mounted getSnapshot route is a POST taking the snapshot ID in the + // {"data": ...} body and returning the snapshot in the {"result": ...} + // envelope, exactly like the turn route. + code, body = do(t, "POST", "/agents/serverChat/getSnapshot", + `{"data":{"snapshotId":"`+env.Result.SnapshotID+`"}}`) + if code != http.StatusOK { + t.Fatalf("getSnapshot status = %d, body = %s", code, body) + } + var snapEnv struct { + Result struct { + SnapshotID string `json:"snapshotId"` + } `json:"result"` + } + if err := json.Unmarshal([]byte(body), &snapEnv); err != nil { + t.Fatalf("getSnapshot not enveloped as expected: %v; body = %s", err, body) + } + if snapEnv.Result.SnapshotID != env.Result.SnapshotID { + t.Errorf("snapshot id = %q, want %q", snapEnv.Result.SnapshotID, env.Result.SnapshotID) + } + + // The client-managed agent is reachable... + code, body = do(t, "POST", "/agents/clientChat", + `{"data":{"message":{"role":"user","content":[{"text":"hi"}]}}}`) + if code != http.StatusOK { + t.Fatalf("client turn status = %d, body = %s", code, body) + } + // ...but has no companion route mounted. + code, _ = do(t, "POST", "/agents/clientChat/getSnapshot", `{"data":{"snapshotId":"whatever"}}`) + if code != http.StatusNotFound { + t.Errorf("client-managed agent should have no getSnapshot route; status = %d, want 404", code) + } +} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 83fe816fa1..df395a4bb4 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -304,6 +304,19 @@ func RegisterAction(g *Genkit, action api.Registerable) { action.Register(g.reg) } +// LookupAction returns the action registered with g under key, or nil if +// none is registered. key is an action's fully qualified +// "/type/provider/name" identifier; build it with [api.NewKey] or +// [api.KeyFromName]. For example, an agent's getSnapshot companion is keyed +// by api.KeyFromName(api.ActionTypeAgentSnapshot, agentName). +// +// This is the generic, type-agnostic lookup. Prefer a typed accessor +// ([LookupModel], [LookupPrompt], etc.) when one exists for the kind of +// action you need. +func LookupAction(g *Genkit, key string) api.Action { + return g.reg.LookupAction(key) +} + // DefineFlow defines a non-streaming flow, registers it as a [core.Action] of type Flow, // and returns a [core.Flow] runner. // The provided function `fn` takes an input of type `In` and returns an output of type `Out`. @@ -417,6 +430,12 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn core.StreamingFunc[In // (e.g. [aix.WithSessionStore], [aix.WithSnapshotOn]); pass an explicit // [State] only when no typed option is provided. // +// The returned agent is an [api.BidiAction]; pass it to [Handler] to +// serve it over HTTP, one turn per request. Server-managed agents also +// register companion actions for the snapshot lifecycle; serve them +// alongside the agent via [aix.Agent.GetSnapshotAction] and +// [aix.Agent.AbortSnapshotAction]. +// // For full control over the per-turn loop, use [DefineCustomAgent]. // // # Options @@ -467,6 +486,10 @@ func DefineAgent[State any]( // Call [aix.SessionRunner.Run] to enter the turn loop, which blocks until the // client sends the next message. // +// Like [DefineAgent], the returned agent is an [api.BidiAction] servable +// via [Handler], with companion actions on [aix.Agent.GetSnapshotAction] +// and [aix.Agent.AbortSnapshotAction]. +// // For agents backed by a prompt, use [DefineAgent] with [aix.FromInline] // (inline prompt) or [aix.FromPrompt] (existing prompt) instead. // diff --git a/go/genkit/servers.go b/go/genkit/servers.go index 288baa95ac..1eb34f4a4c 100644 --- a/go/genkit/servers.go +++ b/go/genkit/servers.go @@ -196,30 +196,9 @@ func handler(a api.Action, opts *handlerOptions) func(http.ResponseWriter, *http } stream = stream || r.Header.Get("Accept") == "text/event-stream" - ctx := r.Context() - if opts.ContextProviders != nil { - for _, ctxProvider := range opts.ContextProviders { - headers := make(map[string]string, len(r.Header)) - for k, v := range r.Header { - headers[strings.ToLower(k)] = strings.Join(v, " ") - } - - actionCtx, err := ctxProvider(ctx, core.RequestData{ - Method: r.Method, - Headers: headers, - Input: body.Data, - }) - if err != nil { - logger.FromContext(ctx).Error("error providing action context from request", "err", err) - return err - } - - if existing := core.FromContext(ctx); existing != nil { - maps.Copy(existing, actionCtx) - actionCtx = existing - } - ctx = core.WithActionContext(ctx, actionCtx) - } + ctx, err := applyContextProviders(r.Context(), r, opts.ContextProviders, body.Data) + if err != nil { + return err } if stream { @@ -250,6 +229,38 @@ func handler(a api.Action, opts *handlerOptions) func(http.ResponseWriter, *http } } +// applyContextProviders runs the configured context providers against the +// request and folds their results into ctx, so request-derived action +// context (e.g. auth from headers) is available to the action. input is +// handed to each provider as the request's decoded input +// ([core.RequestData.Input]). A nil or empty providers slice returns ctx +// unchanged. +func applyContextProviders(ctx context.Context, r *http.Request, providers []core.ContextProvider, input json.RawMessage) (context.Context, error) { + for _, ctxProvider := range providers { + headers := make(map[string]string, len(r.Header)) + for k, v := range r.Header { + headers[strings.ToLower(k)] = strings.Join(v, " ") + } + + actionCtx, err := ctxProvider(ctx, core.RequestData{ + Method: r.Method, + Headers: headers, + Input: input, + }) + if err != nil { + logger.FromContext(ctx).Error("error providing action context from request", "err", err) + return ctx, err + } + + if existing := core.FromContext(ctx); existing != nil { + maps.Copy(existing, actionCtx) + actionCtx = existing + } + ctx = core.WithActionContext(ctx, actionCtx) + } + return ctx, nil +} + // runJSONFunc abstracts over RunJSON and RunBidiJSON for the handler's // execution paths. type runJSONFunc = func(context.Context, json.RawMessage, func(context.Context, json.RawMessage) error) (json.RawMessage, error) diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index 7b9e13b537..fb6d325c38 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -26,6 +26,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/firebase/genkit/go/ai" aix "github.com/firebase/genkit/go/ai/exp" @@ -729,8 +730,8 @@ type agentHTTPResult struct { } `json:"error"` } -// TestHandlerAgent verifies that agents, being bidi actions of type flow, -// serve one-turn-at-a-time over the standard action handler: data carries +// TestHandlerAgent verifies that agents, being bidi actions, serve +// one-turn-at-a-time over the standard action handler: data carries // the turn's AgentInput, init carries the session source (state for // client-managed agents, sessionId/snapshotId for server-managed ones), and // the conversation resumes across requests. It also pins the error contract: @@ -946,3 +947,195 @@ func TestHandlerAgent(t *testing.T) { } }) } + +// TestHandlerAgentRef verifies that the typed agent ref is servable +// directly: the ref satisfies api.BidiAction, so Handler routes init +// through the bidi interface (session resume keeps working), and the +// companion actions plucked off the ref (GetSnapshotAction, +// AbortSnapshotAction) serve the snapshot lifecycle on caller-chosen +// routes. Together they pin the detach → poll → abort story over plain +// HTTP. +func TestHandlerAgentRef(t *testing.T) { + g := Init(context.Background()) + + DefineModel(g, "test/echo", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return &ai.ModelResponse{ + Message: ai.NewModelTextMessage(fmt.Sprintf("echo %d", len(req.Messages))), + FinishReason: ai.FinishReasonStop, + }, nil + }) + + store, err := localstore.NewFileSessionStore[any](t.TempDir()) + if err != nil { + t.Fatal(err) + } + agent := DefineAgent(g, "agentRef", aix.FromInline(ai.WithModelName("test/echo")), + aix.WithSessionStore(store), + aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), + ) + + // Handlers come straight off the ref; no registry iteration involved. + runHandler := Handler(agent) + getSnapshotHandler := Handler(agent.GetSnapshotAction()) + abortHandler := Handler(agent.AbortSnapshotAction()) + + post := func(t *testing.T, h http.HandlerFunc, body string) (int, string) { + t.Helper() + req := httptest.NewRequest("POST", "/", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + h(w, req) + respBody, _ := io.ReadAll(w.Result().Body) + return w.Result().StatusCode, string(respBody) + } + + parseResult := func(t *testing.T, body string) agentHTTPResult { + t.Helper() + var envelope struct { + Result agentHTTPResult `json:"result"` + } + if err := json.Unmarshal([]byte(body), &envelope); err != nil { + t.Fatalf("unmarshal %q: %v", body, err) + } + return envelope.Result + } + + // snapshotHTTPResult covers both companion responses: getSnapshot + // returns the snapshot row, abortSnapshot echoes {snapshotId, status}. + type snapshotHTTPResult struct { + SnapshotID string `json:"snapshotId"` + SessionID string `json:"sessionId"` + Status string `json:"status"` + State json.RawMessage `json:"state"` + } + parseSnapshot := func(t *testing.T, body string) snapshotHTTPResult { + t.Helper() + var envelope struct { + Result snapshotHTTPResult `json:"result"` + } + if err := json.Unmarshal([]byte(body), &envelope); err != nil { + t.Fatalf("unmarshal %q: %v", body, err) + } + return envelope.Result + } + + turn := func(text string) string { + return `{"data":{"message":{"role":"user","content":[{"text":"` + text + `"}]}}}` + } + + t.Run("turns resume through the ref's bidi interface", func(t *testing.T) { + code, body := post(t, runHandler, turn("hello")) + if code != http.StatusOK { + t.Fatalf("status = %d, body = %s", code, body) + } + res := parseResult(t, body) + if res.SessionID == "" || res.SnapshotID == "" { + t.Fatalf("missing session/snapshot ID: %+v", res) + } + + // Resume rides the init field, which the handler only accepts + // because the ref satisfies api.BidiAction. + code, body = post(t, runHandler, + `{"data":{"message":{"role":"user","content":[{"text":"again"}]}},"init":{"sessionId":"`+res.SessionID+`"}}`) + if code != http.StatusOK { + t.Fatalf("resume status = %d, body = %s", code, body) + } + if got := parseResult(t, body).Message.Text(); got != "echo 3" { + t.Errorf("resumed message = %q, want %q", got, "echo 3") + } + }) + + t.Run("getSnapshot serves the persisted snapshot", func(t *testing.T) { + code, body := post(t, runHandler, turn("snapshot me")) + if code != http.StatusOK { + t.Fatalf("turn status = %d, body = %s", code, body) + } + res := parseResult(t, body) + + code, body = post(t, getSnapshotHandler, `{"data":{"snapshotId":"`+res.SnapshotID+`"}}`) + if code != http.StatusOK { + t.Fatalf("getSnapshot status = %d, body = %s", code, body) + } + snap := parseSnapshot(t, body) + if snap.SnapshotID != res.SnapshotID || snap.SessionID != res.SessionID { + t.Errorf("snapshot identity = %q/%q, want %q/%q", snap.SnapshotID, snap.SessionID, res.SnapshotID, res.SessionID) + } + // The action normalizes the implicit empty status to "succeeded" + // so remote clients don't reimplement the default. + if snap.Status != "succeeded" { + t.Errorf("status = %q, want %q", snap.Status, "succeeded") + } + if len(snap.State) == 0 { + t.Error("snapshot must carry state") + } + }) + + t.Run("unknown snapshot IDs map to 404", func(t *testing.T) { + code, body := post(t, getSnapshotHandler, `{"data":{"snapshotId":"nope"}}`) + if code != http.StatusNotFound { + t.Errorf("getSnapshot: status = %d, want %d; body = %s", code, http.StatusNotFound, body) + } + code, body = post(t, abortHandler, `{"data":{"snapshotId":"nope"}}`) + if code != http.StatusNotFound { + t.Errorf("abortSnapshot: status = %d, want %d; body = %s", code, http.StatusNotFound, body) + } + }) + + t.Run("abort on a terminal snapshot echoes its status", func(t *testing.T) { + code, body := post(t, runHandler, turn("terminal")) + if code != http.StatusOK { + t.Fatalf("turn status = %d, body = %s", code, body) + } + res := parseResult(t, body) + + code, body = post(t, abortHandler, `{"data":{"snapshotId":"`+res.SnapshotID+`"}}`) + if code != http.StatusOK { + t.Fatalf("abortSnapshot status = %d, body = %s", code, body) + } + snap := parseSnapshot(t, body) + if snap.Status != "succeeded" { + t.Errorf("status = %q, want %q (abort of a terminal snapshot is a no-op)", snap.Status, "succeeded") + } + }) + + t.Run("detached turn finalizes in the background", func(t *testing.T) { + code, body := post(t, runHandler, + `{"data":{"detach":true,"message":{"role":"user","content":[{"text":"work in background"}]}}}`) + if code != http.StatusOK { + t.Fatalf("detach status = %d, body = %s", code, body) + } + res := parseResult(t, body) + if res.FinishReason != "detached" { + t.Fatalf("finishReason = %q, want %q; body = %s", res.FinishReason, "detached", body) + } + if res.SnapshotID == "" { + t.Fatal("detached output missing the pending snapshotId") + } + + // Poll the companion route until the background turn finalizes the + // pending row, the way a remote client would. + deadline := time.After(10 * time.Second) + for { + code, body := post(t, getSnapshotHandler, `{"data":{"snapshotId":"`+res.SnapshotID+`"}}`) + if code != http.StatusOK { + t.Fatalf("getSnapshot status = %d, body = %s", code, body) + } + snap := parseSnapshot(t, body) + if snap.Status != "pending" { + if snap.Status != "succeeded" { + t.Fatalf("final status = %q, want %q; body = %s", snap.Status, "succeeded", body) + } + if len(snap.State) == 0 { + t.Error("finalized snapshot must carry the cumulative state") + } + break + } + select { + case <-deadline: + t.Fatal("snapshot still pending after 10s") + case <-time.After(10 * time.Millisecond): + } + } + }) +} diff --git a/go/samples/basic-agents-server/main.go b/go/samples/basic-agents-server/main.go index b6dd06e9f7..c340bbb70f 100644 --- a/go/samples/basic-agents-server/main.go +++ b/go/samples/basic-agents-server/main.go @@ -25,7 +25,9 @@ // persists a snapshot; the response carries sessionId and snapshotId, // and a later request resumes the conversation by sending // {"init": {"sessionId": ...}} (or {"snapshotId": ...} to resume from -// a specific point in history). +// a specific point in history). The store also gives the agent +// snapshot companion actions, served here at +// /agents/chat/getSnapshot and /agents/chat/abortSnapshot. // - "statelessChat" has no store (client-managed state). The response // carries the full conversation state; the client sends it back // verbatim as {"init": {"state": ...}} on the next turn. The server @@ -37,28 +39,55 @@ // // Start a conversation (no init starts a fresh session): // -// curl -X POST http://localhost:8080/chat \ +// curl -X POST http://localhost:8080/agents/chat \ // -H "Content-Type: application/json" \ // -d '{"data": {"message": {"role": "user", "content": [{"text": "My name is Alex and I am planning a trip to Japan."}]}}}' // // Continue it, using the sessionId from the response: // -// curl -X POST http://localhost:8080/chat \ +// curl -X POST http://localhost:8080/agents/chat \ // -H "Content-Type: application/json" \ // -d '{"data": {"message": {"role": "user", "content": [{"text": "What is my name?"}]}}, "init": {"sessionId": "SESSION_ID"}}' // // Stream a turn's model chunks and lifecycle events as server-sent events: // -// curl -N -X POST 'http://localhost:8080/chat?stream=true' \ +// curl -N -X POST 'http://localhost:8080/agents/chat?stream=true' \ // -H "Content-Type: application/json" \ // -d '{"data": {"message": {"role": "user", "content": [{"text": "Suggest three day trips from Tokyo."}]}}}' // // For statelessChat, resume by round-tripping the returned state instead: // -// curl -X POST http://localhost:8080/statelessChat \ +// curl -X POST http://localhost:8080/agents/statelessChat \ // -H "Content-Type: application/json" \ // -d '{"data": {"message": {"role": "user", "content": [{"text": "What is my name?"}]}}, "init": {"state": STATE_FROM_PREVIOUS_RESPONSE}}' // +// Server-managed state also unlocks background continuation. Send a turn +// with "detach": true and the response comes back immediately with +// finishReason "detached" and a pending snapshotId, while the turn keeps +// running on the server: +// +// curl -X POST http://localhost:8080/agents/chat \ +// -H "Content-Type: application/json" \ +// -d '{"data": {"message": {"role": "user", "content": [{"text": "Plan a two-week Japan itinerary."}]}, "detach": true}}' +// +// The companion endpoints follow the conversation from there. Each is a +// POST that carries the snapshotId in the {"data": ...} body and returns +// the {"result": ...} envelope, the same convention as the turn route. +// Poll the pending snapshot until its status leaves "pending" (the final +// state carries the result), using the snapshotId from the detach +// response: +// +// curl -X POST http://localhost:8080/agents/chat/getSnapshot \ +// -H "Content-Type: application/json" \ +// -d '{"data": {"snapshotId": "SNAPSHOT_ID"}}' +// +// Or abort the background work instead; an aborted snapshot finalizes +// with status "aborted": +// +// curl -X POST http://localhost:8080/agents/chat/abortSnapshot \ +// -H "Content-Type: application/json" \ +// -d '{"data": {"snapshotId": "SNAPSHOT_ID"}}' +// // Failures come in two tiers. A failed turn (e.g. the model call errors) // still returns HTTP 200: the result reports finishReason "failed", a // structured error ({status, message, details}), and the last-good @@ -77,6 +106,7 @@ import ( aix "github.com/firebase/genkit/go/ai/exp" "github.com/firebase/genkit/go/ai/exp/localstore" "github.com/firebase/genkit/go/genkit" + genkitx "github.com/firebase/genkit/go/genkit/exp" "github.com/firebase/genkit/go/plugins/googlegenai" "github.com/firebase/genkit/go/plugins/server" "google.golang.org/genai" @@ -122,12 +152,28 @@ func main() { ), ) - // Agents register under their own "agent" action type; ListAgents - // surfaces them and the standard action handler serves them one turn - // per request. + // genkitx.AllAgentRoutes lays out a default HTTP surface for every + // registered agent, and genkitx.Mount wires it onto the mux. The layout + // follows each agent's capabilities, so server-managed and + // client-managed agents can be deployed side by side from one call: + // + // "chat" (store-backed): + // POST /agents/chat one turn per request + // POST /agents/chat/getSnapshot read a snapshot by ID + // POST /agents/chat/abortSnapshot abort background work + // "statelessChat" (client-managed): + // POST /agents/statelessChat one turn per request + // + // Every route is a POST taking the standard {"data": ...} envelope and + // returning {"result": ...}; the companions read the snapshotId from + // that body. HandlerOptions passed here (e.g. context providers for + // auth) apply to every route. + // + // To serve specific agents instead of all of them, use + // genkitx.AgentRoutes(agent); to expose flows, genkitx.AllFlowRoutes(g). + // Mix them by concatenating the route slices. The genkitx (genkit/exp) + // package holds these helpers while the routing layer is experimental. mux := http.NewServeMux() - for _, a := range genkit.ListAgents(g) { - mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) - } + genkitx.Mount(mux, genkitx.AllAgentRoutes(g)) log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) } diff --git a/go/samples/basic-agents/cli.go b/go/samples/basic-agents/cli.go index 4f403a5ea0..c8a7defa6f 100644 --- a/go/samples/basic-agents/cli.go +++ b/go/samples/basic-agents/cli.go @@ -46,25 +46,8 @@ import ( "github.com/firebase/genkit/go/ai" aix "github.com/firebase/genkit/go/ai/exp" - "github.com/firebase/genkit/go/ai/exp/localstore" ) -// sampleAgent pairs an agent with the store it persists to and a -// one-line description for the CLI list view. The embedded -// *aix.Agent[any, any] makes Name(), StreamBidi() etc. callable -// directly on a sampleAgent value, so the CLI does not need a -// separate field-threading layer. -// -// Store is tracked alongside the agent (rather than fished out of it) -// because we use FileSessionStore-specific helpers like -// LatestSnapshot and OnSnapshotStatusChange; carrying the concrete -// type avoids a type assertion at every call site. -type sampleAgent struct { - *aix.Agent[any, any] - Store *localstore.FileSessionStore[any] - Description string -} - // errQuit signals that the user typed /quit somewhere in the CLI; it // bubbles up through openAgent and breaks runCLI's outer loop. var errQuit = errors.New("quit") @@ -73,7 +56,7 @@ var errQuit = errors.New("quit") // between two screens forever: the agent list and a per-agent chat. // Returning from a chat brings the user back to the agent list. /quit // (anywhere) and Ctrl-C both unwind back here and exit cleanly. -func runCLI(ctx context.Context, agents []sampleAgent) error { +func runCLI(ctx context.Context, agents []*aix.Agent[any, any]) error { fmt.Println("Genkit Basic Agents") fmt.Println("===================") fmt.Println() @@ -89,31 +72,42 @@ func runCLI(ctx context.Context, agents []sampleAgent) error { fmt.Println(" /back return to the agent list") fmt.Println(" /quit exit the program") + // lastSession remembers, per agent, the session ID of the most recent + // conversation this process ran with it. That is all a client needs to + // resume: re-entering an agent resolves the session's latest snapshot + // from the store (see SnapshotReader.GetLatestSnapshot). It lives in + // memory, so a fresh run starts clean rather than rediscovering prior + // conversations from the store. + lastSession := map[string]string{} + inputCh := readLines(ctx) for { - choice, ok := pickAgent(ctx, inputCh, agents) + choice, ok := pickAgent(ctx, inputCh, agents, lastSession) if !ok { return nil } - if err := openAgent(ctx, inputCh, agents[choice]); err != nil { + a := agents[choice] + sessionID, err := openAgent(ctx, inputCh, a, lastSession[a.Name()]) + if err != nil { if errors.Is(err, errQuit) { return nil } return err } + lastSession[a.Name()] = sessionID } } // pickAgent renders the agent list and reads the user's choice. The // list is re-rendered between selections so the user can see updated // pending/terminal status after returning from a chat. -func pickAgent(ctx context.Context, inputCh <-chan string, agents []sampleAgent) (int, bool) { +func pickAgent(ctx context.Context, inputCh <-chan string, agents []*aix.Agent[any, any], lastSession map[string]string) (int, bool) { for { fmt.Println() fmt.Println("Agents:") for i, a := range agents { - fmt.Printf(" %d. %s — %s\n", i+1, a.Name(), a.Description) - if summary := summarizeLatest(ctx, a); summary != "" { + fmt.Printf(" %d. %s — %s\n", i+1, a.Name(), a.Desc().Description) + if summary := summarizeLatest(ctx, a, lastSession[a.Name()]); summary != "" { fmt.Printf(" last: %s\n", summary) } } @@ -147,29 +141,37 @@ func pickAgent(ctx context.Context, inputCh <-chan string, agents []sampleAgent) // so the rest of the flow is uniform: ok=false means the user backed // out, otherwise hand the chosen snapshot (or nil for fresh) to // runChat. -func openAgent(ctx context.Context, inputCh <-chan string, a sampleAgent) error { - latest, err := a.Store.LatestSnapshot(ctx) - if err != nil { - return fmt.Errorf("read snapshots for %q: %w", a.Name(), err) +func openAgent(ctx context.Context, inputCh <-chan string, a *aix.Agent[any, any], lastSessionID string) (string, error) { + // Resolve where the last conversation left off. With no tracked + // session (a first visit this run) there is nothing to resume; + // otherwise the store resolves the session's latest snapshot, which + // also surfaces a still-pending detached invocation. + var tip *aix.SessionSnapshot[any] + if lastSessionID != "" { + var err error + tip, err = a.Store().GetLatestSnapshot(ctx, lastSessionID) + if err != nil { + return lastSessionID, fmt.Errorf("read snapshot for %q: %w", a.Name(), err) + } } var ( resume *aix.SessionSnapshot[any] ok bool ) - if latest != nil && latest.Status == aix.SnapshotStatusPending { - // Background invocation still in flight. handlePending makes - // the final decision itself (wait & resume, new, or back), so - // we don't fall through to pickSession — the user already - // chose; asking again would just be noise. - resume, ok = handlePending(ctx, inputCh, a, latest) + if tip != nil && tip.Status == aix.SnapshotStatusPending { + // Background invocation still in flight. handlePending makes the + // final decision itself (wait & resume, new, or back), so we don't + // fall through to pickSession: the user already chose, and asking + // again would just be noise. + resume, ok = handlePending(ctx, inputCh, a, tip) } else { - resume, ok = pickSession(ctx, inputCh, a, latest) + resume, ok = pickSession(ctx, inputCh, a, tip) } if !ok { - return nil + return lastSessionID, nil } - return runChat(ctx, inputCh, a, resume) + return runChat(ctx, inputCh, a, resume, lastSessionID) } // handlePending offers the three reasonable responses when a previous @@ -187,7 +189,7 @@ func openAgent(ctx context.Context, inputCh <-chan string, a sampleAgent) error // snapshot directly so the caller can skip the resume / new prompt: // the user already committed to the choice by waiting, and re-asking // would be redundant. -func handlePending(ctx context.Context, inputCh <-chan string, a sampleAgent, pending *aix.SessionSnapshot[any]) (*aix.SessionSnapshot[any], bool) { +func handlePending(ctx context.Context, inputCh <-chan string, a *aix.Agent[any, any], pending *aix.SessionSnapshot[any]) (*aix.SessionSnapshot[any], bool) { for { fmt.Printf("\nThe last %s session is still running in the background (%s).\n", a.Name(), shortID(pending.SnapshotID)) fmt.Println(" 1. Wait for it to finalize") @@ -202,7 +204,7 @@ func handlePending(ctx context.Context, inputCh <-chan string, a sampleAgent, pe switch strings.TrimSpace(line) { case "1": fmt.Println("Waiting for it to finalize...") - final, err := waitForFinalize(ctx, a.Store, pending.SnapshotID) + final, err := waitForFinalize(ctx, a, pending.SnapshotID) if err != nil { fmt.Fprintf(os.Stderr, "Wait error: %v\n", err) return nil, false @@ -237,7 +239,7 @@ func handlePending(ctx context.Context, inputCh <-chan string, a sampleAgent, pe // offers two paths so the demo stays focused: resume from the most // recent terminal snapshot (returns the snapshot pointer), or start // fresh (returns nil). -func pickSession(ctx context.Context, inputCh <-chan string, a sampleAgent, latest *aix.SessionSnapshot[any]) (*aix.SessionSnapshot[any], bool) { +func pickSession(ctx context.Context, inputCh <-chan string, a *aix.Agent[any, any], latest *aix.SessionSnapshot[any]) (*aix.SessionSnapshot[any], bool) { if latest == nil || latest.Status != aix.SnapshotStatusSucceeded { fmt.Printf("\nStarting a new conversation with %s.\n", a.Name()) return nil, true @@ -272,9 +274,9 @@ func pickSession(ctx context.Context, inputCh <-chan string, a sampleAgent, late // snapshot the user was just shown. Validating up front keeps the chat // from opening on a connection whose invocation already failed, which // would surface the error only after the user types a message. -func resumeOption(ctx context.Context, a sampleAgent, resume *aix.SessionSnapshot[any]) aix.InvocationOption[any] { +func resumeOption(ctx context.Context, a *aix.Agent[any, any], resume *aix.SessionSnapshot[any]) aix.InvocationOption[any] { if resume.SessionID != "" { - tip, err := a.Store.GetLatestSnapshot(ctx, resume.SessionID) + tip, err := a.Store().GetLatestSnapshot(ctx, resume.SessionID) if err == nil && tip != nil && tip.Status != aix.SnapshotStatusPending { return aix.WithSessionID[any](resume.SessionID) } @@ -287,10 +289,11 @@ func resumeOption(ctx context.Context, a sampleAgent, resume *aix.SessionSnapsho // snapshot) and runs the per-turn REPL. When resuming, the prior // conversation is replayed first so the user sees the context they're // picking up, then the REPL takes over. /detach is the one interesting -// branch — it sends the optional trailing text as the final input and -// detaches the connection, returning the pending snapshot ID for the -// user to observe. -func runChat(ctx context.Context, inputCh <-chan string, a sampleAgent, resume *aix.SessionSnapshot[any]) error { +// branch: it sends the optional trailing text as the final input and +// detaches the connection, leaving a pending snapshot for the user to +// observe. It returns the session ID the chat ran under (falling back to +// prevSessionID) so the caller can offer to resume it later. +func runChat(ctx context.Context, inputCh <-chan string, a *aix.Agent[any, any], resume *aix.SessionSnapshot[any], prevSessionID string) (string, error) { fmt.Printf("\n=== Chatting with %s ===\n", a.Name()) if resume != nil { fmt.Printf("Resumed from %s\n", shortID(resume.SnapshotID)) @@ -310,7 +313,7 @@ func runChat(ctx context.Context, inputCh <-chan string, a sampleAgent, resume * } conn, err := a.StreamBidi(ctx, opts...) if err != nil { - return fmt.Errorf("open agent %q: %w", a.Name(), err) + return prevSessionID, fmt.Errorf("open agent %q: %w", a.Name(), err) } var ( @@ -417,17 +420,28 @@ repl: fmt.Printf("Done (%s). Final snapshot: %s.\n", out.FinishReason, shortID(out.SnapshotID)) } + // Hand back the session this chat ran under so the caller can offer to + // resume it. An invocation that never produced output (e.g. the stream + // errored before any turn) keeps the prior session. + sessionID := prevSessionID + if out != nil && out.SessionID != "" { + sessionID = out.SessionID + } if quit { - return errQuit + return sessionID, errQuit } - return nil + return sessionID, nil } -// summarizeLatest is the one-line summary printed under each agent in -// the list. Empty if there is no snapshot yet, so a freshly-installed -// sample doesn't show clutter. -func summarizeLatest(ctx context.Context, a sampleAgent) string { - latest, err := a.Store.LatestSnapshot(ctx) +// summarizeLatest is the one-line summary printed under each agent in the +// list: the tip of the conversation last run with it this session. Empty +// when none has run yet, or the session has no resumable snapshot, so a +// fresh list shows no clutter. +func summarizeLatest(ctx context.Context, a *aix.Agent[any, any], sessionID string) string { + if sessionID == "" { + return "" + } + latest, err := a.Store().GetLatestSnapshot(ctx, sessionID) if err != nil || latest == nil { return "" } @@ -442,12 +456,21 @@ func summarizeLatest(ctx context.Context, a sampleAgent) string { // waitForFinalize subscribes to a snapshot's status and blocks until it // transitions out of pending. The returned snapshot is the final one (or // nil if it disappeared). OnSnapshotStatusChange yields the current -// status first, so a snapshot that finalized between the directory scan -// and the subscription is observed immediately. -func waitForFinalize(ctx context.Context, store *localstore.FileSessionStore[any], snapshotID string) (*aix.SessionSnapshot[any], error) { +// status first, so a snapshot that finalized just before the +// subscription is observed immediately. +// +// The status subscription is the optional SnapshotAborter half of the +// store contract. A store without it cannot stream background progress, +// so we fall back to reading the snapshot once and returning it as-is. +func waitForFinalize(ctx context.Context, a *aix.Agent[any, any], snapshotID string) (*aix.SessionSnapshot[any], error) { + store := a.Store() + aborter, ok := store.(aix.SnapshotAborter) + if !ok { + return store.GetSnapshot(ctx, snapshotID) + } subCtx, cancel := context.WithCancel(ctx) defer cancel() - statusCh := store.OnSnapshotStatusChange(subCtx, snapshotID) + statusCh := aborter.OnSnapshotStatusChange(subCtx, snapshotID) for { select { case <-ctx.Done(): diff --git a/go/samples/basic-agents/main.go b/go/samples/basic-agents/main.go index 11e69867b1..e508c874db 100644 --- a/go/samples/basic-agents/main.go +++ b/go/samples/basic-agents/main.go @@ -79,12 +79,13 @@ func main() { g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) genkit.DefineSchemaFor[ChatPromptInput](g) - // Each define function returns a fully-populated sampleAgent that - // pairs the registered agent with its FileSessionStore and a - // one-line description for the CLI list view. The CLI then calls - // a.Name(), a.StreamBidi(...) on the embedded agent and a.Store - // for snapshot operations. - agents := []sampleAgent{ + // Each define function registers an agent and returns it. The CLI + // drives all three through the same surface: a.Name() and + // a.Desc().Description for the list view, a.StreamBidi(...) to chat, + // and a.Store() for snapshot reads. Nothing the CLI does is tied to a + // concrete store type, so swapping in a different SessionStore would + // not touch a line of it. + agents := []*aix.Agent[any, any]{ defineInlineAgent(g), definePromptAgent(g), defineCustomAgent(g), @@ -102,25 +103,21 @@ func main() { // prompt, appends the conversation history, calls the model, and updates // session state. This is the shortest path from "I want a chat agent" to // a working one. -func defineInlineAgent(g *genkit.Genkit) sampleAgent { +func defineInlineAgent(g *genkit.Genkit) *aix.Agent[any, any] { const name = "pirate" - store := mustStore(name) - return sampleAgent{ - Agent: genkit.DefineAgent(g, name, - aix.FromInline( - ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ - ThinkingConfig: &genai.ThinkingConfig{ - ThinkingBudget: genai.Ptr[int32](0), - }, - })), - ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), - ), - aix.WithSessionStore(store), - aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), + return genkit.DefineAgent(g, name, + aix.FromInline( + ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), ), - Store: store, - Description: "Sarcastic pirate (inline-defined prompt)", - } + aix.WithSessionStore(mustStore(name)), + aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), + aix.WithDescription[any]("Sarcastic pirate (inline-defined prompt)"), + ) } // definePromptAgent demonstrates DefineAgent with aix.FromPrompt. The @@ -132,18 +129,14 @@ func defineInlineAgent(g *genkit.Genkit) sampleAgent { // FromPrompt's argument is the default input passed to the prompt's // Render on every turn; the inline-prompt variant has no per-turn input // of its own. -func definePromptAgent(g *genkit.Genkit) sampleAgent { +func definePromptAgent(g *genkit.Genkit) *aix.Agent[any, any] { const name = "chef" - store := mustStore(name) - return sampleAgent{ - Agent: genkit.DefineAgent(g, name, - aix.FromPrompt(ChatPromptInput{Personality: "a Michelin-starred chef who loves explaining technique"}), - aix.WithSessionStore(store), - aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), - ), - Store: store, - Description: "Michelin-starred chef (prompt loaded from ./prompts/chef.prompt)", - } + return genkit.DefineAgent(g, name, + aix.FromPrompt(ChatPromptInput{Personality: "a Michelin-starred chef who loves explaining technique"}), + aix.WithSessionStore(mustStore(name)), + aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), + aix.WithDescription[any]("Michelin-starred chef (prompt loaded from ./prompts/chef.prompt)"), + ) } // defineCustomAgent demonstrates DefineCustomAgent. The per-turn function @@ -155,56 +148,54 @@ func definePromptAgent(g *genkit.Genkit) sampleAgent { // // Even with full control over the loop, the framework still owns session // state, snapshot writes, and the detach lifecycle. -func defineCustomAgent(g *genkit.Genkit) sampleAgent { +func defineCustomAgent(g *genkit.Genkit) *aix.Agent[any, any] { const name = "coder" - store := mustStore(name) - return sampleAgent{ - Agent: genkit.DefineCustomAgent(g, name, - func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { - if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) (*aix.TurnResult, error) { - for chunk, err := range genkit.GenerateStream(ctx, g, - ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ - ThinkingConfig: &genai.ThinkingConfig{ - ThinkingBudget: genai.Ptr[int32](0), - }, - })), - ai.WithSystem("You are a senior software engineer. Answer briefly. Use fenced code blocks when showing code."), - ai.WithMessages(sess.Messages()...), - ) { - if err != nil { - return nil, err - } - if chunk.Done { - sess.AddMessages(chunk.Response.Message) - // Report how the turn ended so the framework can - // forward it on the TurnEnd chunk and persist it - // on the snapshot. - return &aix.TurnResult{ - FinishReason: aix.AgentFinishReason(chunk.Response.FinishReason), - }, nil - } - resp.SendModelChunk(chunk.Chunk) + return genkit.DefineCustomAgent(g, name, + func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) (*aix.TurnResult, error) { + for chunk, err := range genkit.GenerateStream(ctx, g, + ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are a senior software engineer. Answer briefly. Use fenced code blocks when showing code."), + ai.WithMessages(sess.Messages()...), + ) { + if err != nil { + return nil, err } - return nil, nil - }); err != nil { - return nil, err + if chunk.Done { + sess.AddMessages(chunk.Response.Message) + // Report how the turn ended so the framework can + // forward it on the TurnEnd chunk and persist it + // on the snapshot. + return &aix.TurnResult{ + FinishReason: aix.AgentFinishReason(chunk.Response.FinishReason), + }, nil + } + resp.SendModelChunk(chunk.Chunk) } - return sess.Result(), nil - }, - aix.WithSessionStore(store), - aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), - ), - Store: store, - Description: "Concise code helper (custom per-turn loop)", - } + return nil, nil + }); err != nil { + return nil, err + } + return sess.Result(), nil + }, + aix.WithSessionStore(mustStore(name)), + aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), + aix.WithDescription[any]("Concise code helper (custom per-turn loop)"), + ) } // mustStore creates a FileSessionStore rooted at the per-agent dir under // ./.genkit/snapshots/, or exits the process on failure. Used during // agent setup where there's nowhere sensible to return an error. // -// Per-agent dirs keep one agent's history out of another's listing -// (FileSessionStore.LatestSnapshot scans the directory it was given). +// A dir per agent keeps each agent's snapshots on disk separately, which +// is tidy for browsing but not required: resumes are resolved by session +// ID (see SnapshotReader.GetLatestSnapshot), so one shared store would +// work the same. func mustStore(agentName string) *localstore.FileSessionStore[any] { store, err := localstore.NewFileSessionStore[any]("./.genkit/snapshots/" + agentName) if err != nil { From 78648c9dba86d44d9ec90e5cd7c79b04ed4c3d00 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 16 Jun 2026 15:46:25 -0700 Subject: [PATCH 106/141] feat(go/exp): stream custom state as JSON Patch deltas Replace the agent Stream type parameter and the Responder.SendStatus / AgentStreamChunk.Status mechanism with automatic streaming of custom-state mutations as RFC 6902 JSON Patches, porting the custom-state flow from the JS defineAgent implementation (genkit-ai/genkit#5251). Every Session.UpdateCustom mutation now emits an AgentStreamChunk.CustomPatch describing the delta: the first patch of each turn is a whole-document replace that re-bases the client, and subsequent patches are incremental diffs against the last sent value. The diff honors WithStateTransform, so streamed deltas match the full state in turn-end snapshots and final output. AgentConnection applies each delta and exposes the live value via Custom(). Add a small dependency-free RFC 6902 / RFC 6901 implementation (jsonpatch.go) with Diff and ApplyPatch, and collapse Agent[Stream, State] to Agent[State] across the API, routes, and samples. Correctness notes: - JSONPatchOperation.value is not omitempty, so an explicit null operand survives the wire; omitempty cannot tell null from absent and would make a peer applier drop the member instead of setting it to null. - onChange copies only the custom value (not the whole session state) on the no-transform hot path, honoring the transform on full state otherwise. --- genkit-tools/common/src/types/agent.ts | 47 ++- genkit-tools/genkit-schema.json | 41 ++- go/ai/exp/agent.go | 356 ++++++++++++++------ go/ai/exp/agent_test.go | 208 ++++++------ go/ai/exp/custompatch_test.go | 311 +++++++++++++++++ go/ai/exp/gen.go | 57 +++- go/ai/exp/jsonpatch.go | 447 +++++++++++++++++++++++++ go/ai/exp/jsonpatch_test.go | 352 +++++++++++++++++++ go/ai/exp/session.go | 30 +- go/core/schemas.config | 80 ++++- go/genkit/exp/routes.go | 2 +- go/genkit/genkit.go | 16 +- go/samples/basic-agents/cli.go | 18 +- go/samples/basic-agents/main.go | 10 +- 14 files changed, 1741 insertions(+), 234 deletions(-) create mode 100644 go/ai/exp/custompatch_test.go create mode 100644 go/ai/exp/jsonpatch.go create mode 100644 go/ai/exp/jsonpatch_test.go diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 11e8b900e1..0289cb3d6e 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -300,14 +300,57 @@ export const TurnEndSchema = z.object({ }); export type TurnEnd = z.infer; +/** + * Zod schema for the operation kind of a JSON Patch operation (RFC 6902). + */ +export const JsonPatchOpSchema = z.enum([ + 'add', + 'remove', + 'replace', + 'move', + 'copy', + 'test', +]); +export type JsonPatchOp = z.infer; + +/** + * Zod schema for a single RFC 6902 (JSON Patch) operation. + */ +export const JsonPatchOperationSchema = z.object({ + op: JsonPatchOpSchema, + /** A JSON Pointer (RFC 6901) to the target location, e.g. `"/agentStatus"`. */ + path: z.string(), + /** Source pointer; required for `move` and `copy`. */ + from: z.string().optional(), + /** New value; required for `add`, `replace`, and `test`. */ + value: z.any().optional(), +}); +export type JsonPatchOperation = z.infer; + +/** + * Zod schema for an RFC 6902 JSON Patch: an ordered list of operations + * applied in sequence. + */ +export const JsonPatchSchema = z.array(JsonPatchOperationSchema); +export type JsonPatch = z.infer; + /** * Zod schema for agent stream chunk. */ export const AgentStreamChunkSchema = z.object({ /** Generation tokens from the model. */ modelChunk: ModelResponseChunkSchema.optional(), - /** User-defined structured status information. */ - status: z.any().optional(), + /** + * An RFC 6902 JSON Patch describing a delta applied to the session's custom + * state. Emitted automatically whenever the agent mutates custom state, so + * the client can apply it to its tracked copy and keep custom live as the + * turn streams. Pointers are rooted at the custom document (e.g. + * `/agentStatus`), with no `/custom` prefix. The first patch of every turn is + * a whole-document replace at the root pointer (`""`) that re-bases clients + * which may not share the server's baseline; subsequent patches are + * incremental diffs against the last sent value. + */ + customPatch: JsonPatchSchema.optional(), /** A newly produced artifact. */ artifact: ArtifactSchema.optional(), /** diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 65ce9aef18..c1d689caf5 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -164,7 +164,9 @@ "modelChunk": { "$ref": "#/$defs/ModelResponseChunk" }, - "status": {}, + "customPatch": { + "$ref": "#/$defs/JsonPatch" + }, "artifact": { "$ref": "#/$defs/Artifact" }, @@ -208,6 +210,43 @@ ], "additionalProperties": false }, + "JsonPatchOp": { + "type": "string", + "enum": [ + "add", + "remove", + "replace", + "move", + "copy", + "test" + ] + }, + "JsonPatchOperation": { + "type": "object", + "properties": { + "op": { + "$ref": "#/$defs/JsonPatchOp" + }, + "path": { + "type": "string" + }, + "from": { + "type": "string" + }, + "value": {} + }, + "required": [ + "op", + "path" + ], + "additionalProperties": false + }, + "JsonPatch": { + "type": "array", + "items": { + "$ref": "#/$defs/JsonPatchOperation" + } + }, "SessionSnapshot": { "type": "object", "properties": { diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index a73ac12950..5b9e7c6cdb 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -58,6 +58,7 @@ type SessionRunner[State any] struct { TurnIndex int snapshotCallback SnapshotCallback[State] + onStartTurn func() onEndTurn func(ctx context.Context) collectTurnOutput func() any @@ -164,6 +165,11 @@ func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont // (trace marshaling, session state, snapshot writes) must work // on private memory rather than race the caller. input = jsonClone(input) + // Mark the start of the turn so the first customPatch emitted during + // it is a whole-document replace that re-bases the client. + if s.onStartTurn != nil { + s.onStartTurn() + } spanMeta := &tracing.SpanMetadata{ Name: fmt.Sprintf("agent/turn/%d", s.TurnIndex), Type: "flowStep", @@ -402,24 +408,19 @@ func (s *SessionRunner[State]) recoverySnapshotID(ctx context.Context) string { // wire; the session-level side effects still apply. Send itself remains // fire-and-forget and returns no error; the user fn is expected to // observe cancellation through its own ctx check and stop producing. -type Responder[Stream any] struct { - in chan<- *AgentStreamChunk[Stream] +type Responder struct { + in chan<- *AgentStreamChunk ctx context.Context // effects applies the chunk's in-process side effects (session // artifact add, turn-chunk accumulation) synchronously in send, in // the sender's goroutine, so reads and snapshots that follow a Send // cannot miss the chunk. - effects func(*AgentStreamChunk[Stream]) + effects func(*AgentStreamChunk) } // SendModelChunk sends a generation chunk (token-level streaming). -func (r Responder[Stream]) SendModelChunk(chunk *ai.ModelResponseChunk) { - r.send(&AgentStreamChunk[Stream]{ModelChunk: chunk}) -} - -// SendStatus sends a user-defined status update. -func (r Responder[Stream]) SendStatus(status Stream) { - r.send(&AgentStreamChunk[Stream]{Status: status}) +func (r Responder) SendModelChunk(chunk *ai.ModelResponseChunk) { + r.send(&AgentStreamChunk{ModelChunk: chunk}) } // SendArtifact sends an artifact to the stream and adds it to the session. @@ -430,8 +431,8 @@ func (r Responder[Stream]) SendStatus(status Stream) { // The session-level side effect happens whether or not detach has landed; // only the wire forward to the client is suppressed post-detach, when // there is no longer a client to receive it. -func (r Responder[Stream]) SendArtifact(artifact *Artifact) { - r.send(&AgentStreamChunk[Stream]{Artifact: artifact}) +func (r Responder) SendArtifact(artifact *Artifact) { + r.send(&AgentStreamChunk{Artifact: artifact}) } // send applies chunk's in-process side effects, then delivers it to the @@ -444,7 +445,7 @@ func (r Responder[Stream]) SendArtifact(artifact *Artifact) { // after workCtx cancellation completes immediately rather than blocking // on a router that has not yet been put into drain mode by a terminal // path. -func (r Responder[Stream]) send(chunk *AgentStreamChunk[Stream]) { +func (r Responder) send(chunk *AgentStreamChunk) { if r.effects != nil { r.effects(chunk) } @@ -456,11 +457,12 @@ func (r Responder[Stream]) send(chunk *AgentStreamChunk[Stream]) { // --- Agent --- -// AgentFunc is the function signature for custom agents. -// Type parameters: -// - Stream: Type for status updates sent via the responder -// - State: Type for user-defined state in snapshots -type AgentFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *SessionRunner[State]) (*AgentResult, error) +// AgentFunc is the function signature for custom agents. The State type +// parameter is the shape of the conversation's custom state (see +// [SessionState.Custom]). The agent streams output through resp and reads or +// mutates state through sess; mutating custom state via [Session.UpdateCustom] +// automatically streams a [AgentStreamChunk.CustomPatch] delta to the client. +type AgentFunc[State any] = func(ctx context.Context, resp Responder, sess *SessionRunner[State]) (*AgentResult, error) // Agent is a bidirectional streaming agent with automatic snapshot management. // @@ -474,8 +476,8 @@ type AgentFunc[Stream, State any] = func(ctx context.Context, resp Responder[Str // register companion actions for the snapshot lifecycle, available via // [Agent.GetSnapshotAction] and [Agent.AbortSnapshotAction] for serving // alongside the agent, and expose the store itself via [Agent.Store]. -type Agent[Stream, State any] struct { - action *core.BidiAction[*AgentInput, *AgentOutput[State], *AgentStreamChunk[Stream], *AgentInit[State]] +type Agent[State any] struct { + action *core.BidiAction[*AgentInput, *AgentOutput[State], *AgentStreamChunk, *AgentInit[State]] // Companion actions, retained so transports can serve them without a // registry lookup. Nil when the corresponding capability is absent; // see newSnapshotActions. @@ -490,7 +492,7 @@ type Agent[Stream, State any] struct { // Name returns the agent's registered name. This is also the name under // which any inline-defined prompt and companion actions (getSnapshot, // abortSnapshot) are registered. -func (a *Agent[Stream, State]) Name() string { +func (a *Agent[State]) Name() string { return a.action.Name() } @@ -503,7 +505,7 @@ func (a *Agent[Stream, State]) Name() string { // Use it to expose snapshot polling over a transport (e.g. mount it with // genkit.Handler next to the agent itself); local Go code should read // from the store directly. -func (a *Agent[Stream, State]) GetSnapshotAction() api.Action { +func (a *Agent[State]) GetSnapshotAction() api.Action { return a.getSnapshot } @@ -516,7 +518,7 @@ func (a *Agent[Stream, State]) GetSnapshotAction() api.Action { // Use it to expose aborting over a transport (e.g. mount it with // genkit.Handler next to the agent itself); local Go code should call the // store's [SnapshotAborter.AbortSnapshot] directly. -func (a *Agent[Stream, State]) AbortSnapshotAction() api.Action { +func (a *Agent[State]) AbortSnapshotAction() api.Action { return a.abortSnapshot } @@ -528,7 +530,7 @@ func (a *Agent[Stream, State]) AbortSnapshotAction() api.Action { // The store is returned as the [SessionStore] interface, not its concrete // type; a caller needing a store-specific capability (e.g. // [SnapshotAborter]) type-asserts for it. -func (a *Agent[Stream, State]) Store() SessionStore[State] { +func (a *Agent[State]) Store() SessionStore[State] { return a.store } @@ -540,7 +542,7 @@ func (a *Agent[Stream, State]) Store() SessionStore[State] { // type-assert to [api.BidiAction] to route session init (the wire // counterpart of [WithSessionID], [WithSnapshotID], and [WithState]), so // satisfying only [api.Action] would silently break session resume. -var _ api.BidiAction = (*Agent[any, any])(nil) +var _ api.BidiAction = (*Agent[any])(nil) // Register registers the agent's run action and any companion actions // (getSnapshot, abortSnapshot) with the registry. Agents defined via @@ -549,7 +551,7 @@ var _ api.BidiAction = (*Agent[any, any])(nil) // inline-defined prompt does not travel: the agent holds it directly, so // execution is unaffected, but the prompt action stays in the registry it // was defined in. -func (a *Agent[Stream, State]) Register(r api.Registry) { +func (a *Agent[State]) Register(r api.Registry) { // Register the wrapped bidi action under the agent key, the same way // every other action registers itself; the registry holds a uniform // api.BidiAction that the reflection servers, ListAgents, and the route @@ -568,20 +570,20 @@ func (a *Agent[Stream, State]) Register(r api.Registry) { } // Desc returns the descriptor of the agent's run action. -func (a *Agent[Stream, State]) Desc() api.ActionDesc { +func (a *Agent[State]) Desc() api.ActionDesc { return a.action.Desc() } // RunJSON runs a one-shot invocation with no init (a fresh session): // input is the turn's [AgentInput] and the result is the final // [AgentOutput]. To supply a session source, use [Agent.RunBidiJSON]. -func (a *Agent[Stream, State]) RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) { +func (a *Agent[State]) RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) { return a.action.RunJSON(ctx, input, cb) } // RunJSONWithTelemetry is [Agent.RunJSON] with trace information on the // result. -func (a *Agent[Stream, State]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (*api.ActionRunResult[json.RawMessage], error) { +func (a *Agent[State]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (*api.ActionRunResult[json.RawMessage], error) { return a.action.RunJSONWithTelemetry(ctx, input, cb) } @@ -589,14 +591,14 @@ func (a *Agent[Stream, State]) RunJSONWithTelemetry(ctx context.Context, input j // counterpart of the [InvocationOption] values) rides in opts: input is // delivered as the only chunk on the input stream and outgoing chunks are // forwarded to cb. -func (a *Agent[Stream, State]) RunBidiJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error, opts *api.BidiSessionOptions) (*api.ActionRunResult[json.RawMessage], error) { +func (a *Agent[State]) RunBidiJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error, opts *api.BidiSessionOptions) (*api.ActionRunResult[json.RawMessage], error) { return a.action.RunBidiJSON(ctx, input, cb, opts) } // StreamBidiJSON starts a bidirectional streaming session using // JSON-encoded messages. Local Go callers should prefer the typed // [Agent.StreamBidi]. -func (a *Agent[Stream, State]) StreamBidiJSON(ctx context.Context, opts *api.BidiSessionOptions) (api.BidiJSONConnection, error) { +func (a *Agent[State]) StreamBidiJSON(ctx context.Context, opts *api.BidiSessionOptions) (api.BidiJSONConnection, error) { return a.action.StreamBidiJSON(ctx, opts) } @@ -623,7 +625,7 @@ func DefineAgent[State any]( name string, source AgentSource, opts ...AgentOption[State], -) *Agent[any, State] { +) *Agent[State] { switch s := source.(type) { case inlineSource: prompt := ai.DefinePrompt(r, name, s.opts...) @@ -659,11 +661,11 @@ func DefineAgent[State any]( // prompt-backed agent cannot be built before it has one. To get // prompt-like behavior without registration, write a custom agent that // renders and generates with your own [genkit.Genkit] inside fn. -func NewCustomAgent[Stream, State any]( +func NewCustomAgent[State any]( name string, - fn AgentFunc[Stream, State], + fn AgentFunc[State], opts ...AgentOption[State], -) *Agent[Stream, State] { +) *Agent[State] { cfg := &agentOptions[State]{} for _, opt := range opts { if err := opt.applyAgent(cfg); err != nil { @@ -692,7 +694,7 @@ func NewCustomAgent[Stream, State any]( ctx context.Context, in *AgentInit[State], inCh <-chan *AgentInput, - outCh chan<- *AgentStreamChunk[Stream], + outCh chan<- *AgentStreamChunk, ) (*AgentOutput[State], error) { ctx = core.WithFlowContext(ctx, name) rt, err := newAgentRuntime(ctx, name, cfg, in, inCh, outCh) @@ -710,7 +712,7 @@ func NewCustomAgent[Stream, State any]( getSnapshot, abortSnapshot := newSnapshotActions(name, cfg.store, cfg.transform) - return &Agent[Stream, State]{ + return &Agent[State]{ action: action, getSnapshot: getSnapshot, abortSnapshot: abortSnapshot, @@ -727,12 +729,12 @@ func NewCustomAgent[Stream, State any]( // It is [NewCustomAgent] followed by [Agent.Register]. To build an agent // without registering it, use [NewCustomAgent] directly. For agents backed // by a prompt, use [DefineAgent] instead. -func DefineCustomAgent[Stream, State any]( +func DefineCustomAgent[State any]( r api.Registry, name string, - fn AgentFunc[Stream, State], + fn AgentFunc[State], opts ...AgentOption[State], -) *Agent[Stream, State] { +) *Agent[State] { a := NewCustomAgent(name, fn, opts...) a.Register(r) return a @@ -761,13 +763,14 @@ func agentMetadataFor[State any](store SessionStore[State]) AgentMetadata { // session, runner, output router, input intake, and the goroutine that runs // the user fn. Its methods implement the three terminal paths the agent can // take: detach, fn-completion, and client-cancel. -type agentRuntime[Stream, State any] struct { +type agentRuntime[State any] struct { name string cfg *agentOptions[State] session *Session[State] sess *SessionRunner[State] - router *chunkRouter[Stream, State] + router *chunkRouter[State] + patcher *customPatcher[State] intake *detachIntake fnDone chan fnDoneResult[State] @@ -780,14 +783,14 @@ type fnDoneResult[State any] struct { err error } -func newAgentRuntime[Stream, State any]( +func newAgentRuntime[State any]( ctx context.Context, name string, cfg *agentOptions[State], in *AgentInit[State], inCh <-chan *AgentInput, - outCh chan<- *AgentStreamChunk[Stream], -) (*agentRuntime[Stream, State], error) { + outCh chan<- *AgentStreamChunk, +) (*agentRuntime[State], error) { session, parent, err := loadSession(ctx, in, cfg.store) if err != nil { return nil, err @@ -818,7 +821,7 @@ func newAgentRuntime[Stream, State any]( session.state.SessionID = uuid.New().String() } - rt := &agentRuntime[Stream, State]{ + rt := &agentRuntime[State]{ name: name, cfg: cfg, session: session, @@ -835,6 +838,15 @@ func newAgentRuntime[Stream, State any]( } rt.sess.collectTurnOutput = func() any { return rt.router.collectTurnChunks() } rt.sess.onEndTurn = rt.emitTurnEnd + // Stream custom-state mutations as customPatch chunks. beginTurn is armed + // per turn by the runner; the session's onCustomChange hook is wired in + // run, once the work context and responder exist. + rt.patcher = &customPatcher[State]{ + transform: cfg.transform, + session: session, + firstInTurn: true, + } + rt.sess.onStartTurn = rt.patcher.beginTurn // The initial state (fresh, client-provided, or loaded from a // snapshot) is the last-good recovery point until a turn completes. rt.sess.recordLastGood() @@ -857,14 +869,14 @@ func newAgentRuntime[Stream, State any]( // observes the suspension under snapMu; the pending row already captures // the invocation and a single finalize rewrite records the cumulative // state once the queued inputs drain). -func (rt *agentRuntime[Stream, State]) emitTurnEnd(ctx context.Context) { +func (rt *agentRuntime[State]) emitTurnEnd(ctx context.Context) { rt.intake.releaseForward() reason := rt.sess.lastTurnFinishReason var snapshotID string if !rt.sess.lastTurnFailed { snapshotID = rt.sess.maybeSnapshot(ctx, SnapshotEventTurnEnd, reason) } - rt.router.sendChunk(ctx, &AgentStreamChunk[Stream]{TurnEnd: &TurnEnd{ + rt.router.sendChunk(ctx, &AgentStreamChunk{TurnEnd: &TurnEnd{ SnapshotID: snapshotID, FinishReason: reason, }}) @@ -875,13 +887,22 @@ func (rt *agentRuntime[Stream, State]) emitTurnEnd(ctx context.Context) { // workCtx carries the session and is decoupled from clientCtx: pre-detach a // watcher mirrors clientCtx so a disconnect cancels the work; on detach the // watcher exits and the finalizer goroutine owns workCtx until fn returns. -func (rt *agentRuntime[Stream, State]) run( +func (rt *agentRuntime[State]) run( clientCtx context.Context, - fn AgentFunc[Stream, State], + fn AgentFunc[State], ) (*AgentOutput[State], error) { workCtx, cancelWork := context.WithCancel(context.WithoutCancel(clientCtx)) workCtx = NewSessionContext(workCtx, rt.session) + // Wire custom-state streaming now that the work context exists: every + // UpdateCustom mutation during the invocation emits a customPatch chunk + // through the same responder fn uses (so the chunk is accumulated for the + // turn span and forwarded on the wire, dropping post-detach like any + // other chunk). The session mutation itself still applies regardless. + resp := rt.router.responder(workCtx) + rt.patcher.bind(workCtx, resp.send) + rt.session.onCustomChange = rt.patcher.onChange + var detachOnce sync.Once detached := make(chan struct{}) markDetached := func() { detachOnce.Do(func() { close(detached) }) } @@ -920,7 +941,7 @@ func (rt *agentRuntime[Stream, State]) run( fnErr = core.NewError(core.INTERNAL, "agent fn panicked: %v", r) } }() - result, fnErr = fn(workCtx, rt.router.responder(workCtx), rt.sess) + result, fnErr = fn(workCtx, resp, rt.sess) }() rt.fnDone <- fnDoneResult[State]{result: result, err: fnErr} }() @@ -950,7 +971,7 @@ func (rt *agentRuntime[Stream, State]) run( // pending snapshot) and a [SnapshotAborter] (which bundles both abort // triggering and status-change subscription so the runtime can react to // the abort without polling). -func (rt *agentRuntime[Stream, State]) checkDetachCapabilities() error { +func (rt *agentRuntime[State]) checkDetachCapabilities() error { if rt.cfg.store == nil { return core.NewError(core.FAILED_PRECONDITION, "agent %q: detach requires a session store", rt.name) @@ -967,12 +988,12 @@ func (rt *agentRuntime[Stream, State]) checkDetachCapabilities() error { // gone), wait for the intake reader/forwarder to finish, drain fnDone, // and close the router. Returns the fn's result for callers that need // to surface its error. -func (rt *agentRuntime[Stream, State]) drainAndWait(cancelWork context.CancelFunc) fnDoneResult[State] { +func (rt *agentRuntime[State]) drainAndWait(cancelWork context.CancelFunc) fnDoneResult[State] { cancelWork() // Switch the router to discard mode before waiting on fn. Without - // this, a fn mid-SendStatus blocks on the router's r.in receive while - // the router blocks on r.out send (consumer is gone), so fn never - // observes ctx and we deadlock waiting on fnDone. + // this, a fn mid-send (or a customPatch emit) blocks on the router's + // r.in receive while the router blocks on r.out send (consumer is + // gone), so fn never observes ctx and we deadlock waiting on fnDone. rt.router.stopAndWait() rt.intake.stopAndWait() res := <-rt.fnDone @@ -996,7 +1017,7 @@ func (rt *agentRuntime[Stream, State]) drainAndWait(cancelWork context.CancelFun // client drains it, a disconnected client's ctx cancellation trips // forward's ctx arm, and a client that stopped receiving unparks it when // its Output call drains the stream. -func (rt *agentRuntime[Stream, State]) handleFnDone( +func (rt *agentRuntime[State]) handleFnDone( ctx context.Context, cancelWork context.CancelFunc, res fnDoneResult[State], @@ -1055,7 +1076,7 @@ func (rt *agentRuntime[Stream, State]) handleFnDone( // framework-owned SessionID, so the state handed to a client-managed // caller always carries the conversation's identity even if a transform // rewrote or dropped it. Returns nil if state is nil. -func (rt *agentRuntime[Stream, State]) outboundState(ctx context.Context, state *SessionState[State]) *SessionState[State] { +func (rt *agentRuntime[State]) outboundState(ctx context.Context, state *SessionState[State]) *SessionState[State] { out := applyTransform(ctx, rt.cfg.transform, state) if out != nil { out.SessionID = rt.session.SessionID() @@ -1068,7 +1089,7 @@ func (rt *agentRuntime[Stream, State]) outboundState(ctx context.Context, state // and the last-good state (inline when client-managed, behind a recovery // snapshot ID when server-managed). Message and Artifacts are left empty; // they describe the result of a completed run. -func (rt *agentRuntime[Stream, State]) failedOutput(ctx context.Context, cause error) *AgentOutput[State] { +func (rt *agentRuntime[State]) failedOutput(ctx context.Context, cause error) *AgentOutput[State] { out := &AgentOutput[State]{ FinishReason: AgentFinishReasonFailed, Error: core.AsGenkitError(cause), @@ -1089,7 +1110,7 @@ func (rt *agentRuntime[Stream, State]) failedOutput(ctx context.Context, cause e // stops writing to outCh and discards further chunks, whose in-process // side effects (e.g. artifacts added via Responder.SendArtifact) still // apply at Send time, so user code does not have to branch on detach. -func (rt *agentRuntime[Stream, State]) handleDetach( +func (rt *agentRuntime[State]) handleDetach( clientCtx, workCtx context.Context, cancelWork context.CancelFunc, markDetached func(), @@ -1168,7 +1189,7 @@ func (rt *agentRuntime[Stream, State]) handleDetach( // so the read-and-rewrite is one atomic step: if the row has already // transitioned to aborted (a late abort racing this finalize), // SaveSnapshot sees it inside fn and we leave the row untouched. -func (rt *agentRuntime[Stream, State]) finalizePendingSnapshot( +func (rt *agentRuntime[State]) finalizePendingSnapshot( ctx context.Context, pending *SessionSnapshot[State], result *AgentResult, @@ -1359,28 +1380,28 @@ func resumeSessionFrom[State any](s *Session[State], snap *SessionSnapshot[State // keeps draining its input so the user fn never blocks on a responder // send. -type chunkRouter[Stream, State any] struct { +type chunkRouter[State any] struct { ctx context.Context // action context; ends on client disconnect (or completion) - in chan *AgentStreamChunk[Stream] - out chan<- *AgentStreamChunk[Stream] + in chan *AgentStreamChunk + out chan<- *AgentStreamChunk session *Session[State] turnMu sync.Mutex - turnChunks []*AgentStreamChunk[Stream] + turnChunks []*AgentStreamChunk done chan struct{} stopWriting chan struct{} writerStopped chan struct{} } -func startChunkRouter[Stream, State any]( +func startChunkRouter[State any]( ctx context.Context, session *Session[State], - out chan<- *AgentStreamChunk[Stream], -) *chunkRouter[Stream, State] { - r := &chunkRouter[Stream, State]{ + out chan<- *AgentStreamChunk, +) *chunkRouter[State] { + r := &chunkRouter[State]{ ctx: ctx, - in: make(chan *AgentStreamChunk[Stream]), + in: make(chan *AgentStreamChunk), out: out, session: session, done: make(chan struct{}), @@ -1391,7 +1412,7 @@ func startChunkRouter[Stream, State any]( return r } -func (r *chunkRouter[Stream, State]) run() { +func (r *chunkRouter[State]) run() { defer close(r.done) if !r.forward() { // r.in closed while writes were still allowed; nothing left to do. @@ -1414,7 +1435,7 @@ func (r *chunkRouter[Stream, State]) run() { // artifact. The artifact is deep-copied on its way into the session so // the sender's retained pointer (which also rides the wire chunk) cannot // alias live session state. -func (r *chunkRouter[Stream, State]) applySideEffects(chunk *AgentStreamChunk[Stream]) { +func (r *chunkRouter[State]) applySideEffects(chunk *AgentStreamChunk) { if chunk.Artifact != nil { r.session.AddArtifacts(jsonClone(chunk.Artifact)) } @@ -1428,7 +1449,7 @@ func (r *chunkRouter[Stream, State]) applySideEffects(chunk *AgentStreamChunk[St // forward delivers chunks to outCh until told to stop writing, the // action context ends, or r.in closes. Returns true if the router must // keep draining (writes stopped), false if r.in closed. -func (r *chunkRouter[Stream, State]) forward() bool { +func (r *chunkRouter[State]) forward() bool { for { select { case chunk, ok := <-r.in: @@ -1456,8 +1477,8 @@ func (r *chunkRouter[Stream, State]) forward() bool { // synchronously and sends chunks into the router for the wire forward. // The returned Responder's Send methods drop the forward (returning // promptly) when ctx is cancelled. -func (r *chunkRouter[Stream, State]) responder(ctx context.Context) Responder[Stream] { - return Responder[Stream]{in: r.in, ctx: ctx, effects: r.applySideEffects} +func (r *chunkRouter[State]) responder(ctx context.Context) Responder { + return Responder{in: r.in, ctx: ctx, effects: r.applySideEffects} } // sendChunk delivers chunk to the router for producers other than the @@ -1466,7 +1487,7 @@ func (r *chunkRouter[Stream, State]) responder(ctx context.Context) Responder[St // which has none: no artifact, and TurnEnd is excluded from turn-chunk // accumulation) and returns promptly if ctx is cancelled, dropping the // chunk. -func (r *chunkRouter[Stream, State]) sendChunk(ctx context.Context, chunk *AgentStreamChunk[Stream]) { +func (r *chunkRouter[State]) sendChunk(ctx context.Context, chunk *AgentStreamChunk) { select { case r.in <- chunk: case <-ctx.Done(): @@ -1474,7 +1495,7 @@ func (r *chunkRouter[Stream, State]) sendChunk(ctx context.Context, chunk *Agent } // collectTurnChunks returns and resets accumulated turn chunks. -func (r *chunkRouter[Stream, State]) collectTurnChunks() []*AgentStreamChunk[Stream] { +func (r *chunkRouter[State]) collectTurnChunks() []*AgentStreamChunk { r.turnMu.Lock() defer r.turnMu.Unlock() result := r.turnChunks @@ -1485,17 +1506,100 @@ func (r *chunkRouter[Stream, State]) collectTurnChunks() []*AgentStreamChunk[Str // stopAndWait tells the router to stop writing to out and blocks until it // has committed. After it returns, it is safe for the framework to close // out without risking a write-to-closed-channel panic. -func (r *chunkRouter[Stream, State]) stopAndWait() { +func (r *chunkRouter[State]) stopAndWait() { close(r.stopWriting) <-r.writerStopped } // close signals end-of-input and waits for the router to drain. -func (r *chunkRouter[Stream, State]) close() { +func (r *chunkRouter[State]) close() { close(r.in) <-r.done } +// --- customPatcher --- + +// customPatcher streams the agent's custom state to the client as RFC 6902 +// JSON Patches. The runtime wires it to the session's onCustomChange hook so +// every [Session.UpdateCustom] mutation emits a [AgentStreamChunk.CustomPatch] +// describing the delta, exactly as adding an artifact emits an artifact chunk. +// +// The diff is computed on the client-facing custom value (after the configured +// [StateTransform]), so streamed deltas honor redaction and stay consistent +// with the full state in turn-end snapshots and final output. Because a client +// may begin a turn without having loaded the full state, the first patch of +// each turn is a whole-document replace at the root pointer that re-bases it; +// subsequent patches are incremental diffs against the last sent value. +type customPatcher[State any] struct { + transform StateTransform[State] + session *Session[State] + + ctx context.Context // invocation work context, for the transform + send func(*AgentStreamChunk) // forwards the chunk (accumulate + wire) + + mu sync.Mutex + firstInTurn bool + baseline any // last sent custom, normalized; the diff baseline +} + +// bind attaches the invocation's work context and chunk sink. Called once in +// run, before the agent fn (the only producer of custom mutations) starts. +func (p *customPatcher[State]) bind(ctx context.Context, send func(*AgentStreamChunk)) { + p.ctx = ctx + p.send = send +} + +// beginTurn arms the next emitted patch to be a whole-document replace, +// re-basing a client that may not share the server's baseline. Called by the +// runner at the start of every turn. +func (p *customPatcher[State]) beginTurn() { + p.mu.Lock() + p.firstInTurn = true + p.mu.Unlock() +} + +// onChange computes and emits the patch for the current custom state. It is +// invoked (outside the session lock) after every UpdateCustom mutation. The +// state read, diff, baseline update, and send all happen under p.mu so +// concurrent mutations produce a single, consistently ordered patch stream. +func (p *customPatcher[State]) onChange() { + if p.send == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + + // Diff the client-facing custom value (after the transform), matching what + // turn-end snapshots and final output expose. With no transform we only need + // the custom value, so take a custom-only normalized copy instead of + // deep-copying the whole session state (messages and artifacts included) on + // every mutation. With a transform we honor it on the full state, exactly as + // the snapshot and output paths do, so the streamed delta stays consistent + // with them. + var next any + if p.transform == nil { + next = p.session.customJSON() + } else { + var custom any + if t := applyTransform(p.ctx, p.transform, p.session.State()); t != nil { + custom = t.Custom + } + next = normalizeJSON(custom) + } + + var patch JSONPatch + if p.firstInTurn { + patch = JSONPatch{{Op: JSONPatchOpReplace, Path: "", Value: cloneJSON(next)}} + p.firstInTurn = false + } else { + patch = diffValues(p.baseline, next) + } + p.baseline = next + if len(patch) > 0 { + p.send(&AgentStreamChunk{CustomPatch: patch}) + } +} + // --- detachIntake --- // // detachIntake separates eager src reading from runner-paced forwarding, @@ -1786,8 +1890,8 @@ func validateUserMessage(m *ai.Message) error { // defaultInput is the prompt input passed to Render on every turn. It is // nil for inline-defined prompts ([FromInline]), which take no per-turn // input. -func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) AgentFunc[any, State] { - return func(ctx context.Context, resp Responder[any], sess *SessionRunner[State]) (*AgentResult, error) { +func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) AgentFunc[State] { + return func(ctx context.Context, resp Responder, sess *SessionRunner[State]) (*AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { if err := validateUserMessage(input.Message); err != nil { return nil, err @@ -1887,10 +1991,10 @@ func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) Ag // StreamBidi starts a new agent invocation with bidirectional streaming. // Use this for multi-turn interactions where you need to send multiple inputs // and receive streaming chunks. For single-turn usage, see Run and RunText. -func (a *Agent[Stream, State]) StreamBidi( +func (a *Agent[State]) StreamBidi( ctx context.Context, opts ...InvocationOption[State], -) (*AgentConnection[Stream, State], error) { +) (*AgentConnection[State], error) { init, err := a.resolveOptions(opts) if err != nil { return nil, err @@ -1899,7 +2003,7 @@ func (a *Agent[Stream, State]) StreamBidi( if err != nil { return nil, err } - return &AgentConnection[Stream, State]{conn: conn}, nil + return &AgentConnection[State]{conn: conn}, nil } // Run starts a single-turn agent invocation with the given input. @@ -1909,7 +2013,7 @@ func (a *Agent[Stream, State]) StreamBidi( // In-band failures (e.g. a failed turn) resolve as a failed [AgentOutput] // rather than an error; a rejected init payload fails with an error, since // the invocation never starts. See [AgentConnection.Output]. -func (a *Agent[Stream, State]) Run( +func (a *Agent[State]) Run( ctx context.Context, input *AgentInput, opts ...InvocationOption[State], @@ -1930,7 +2034,7 @@ func (a *Agent[Stream, State]) Run( // RunText is a convenience method that starts a single-turn agent invocation // with a user text message. It is equivalent to calling Run with an // AgentInput whose Message is a user text message. -func (a *Agent[Stream, State]) RunText( +func (a *Agent[State]) RunText( ctx context.Context, text string, opts ...InvocationOption[State], @@ -1946,7 +2050,7 @@ func (a *Agent[Stream, State]) RunText( // client-managed conversation's identity rides inside the state itself), // while WithSessionID and WithSnapshotID compose as an assertion. // Per-option duplicate checks live in applyInvocation. -func (a *Agent[Stream, State]) resolveOptions(opts []InvocationOption[State]) (*AgentInit[State], error) { +func (a *Agent[State]) resolveOptions(opts []InvocationOption[State]) (*AgentInit[State], error) { invOpts := &invocationOptions[State]{} for _, opt := range opts { if err := opt.applyInvocation(invOpts); err != nil { @@ -1974,8 +2078,16 @@ func (a *Agent[Stream, State]) resolveOptions(opts []InvocationOption[State]) (* // (SendMessage / SendText / SendResume / Detach) and an Output that // always waits for finalization (so detached invocations see the // pending snapshot ID rather than a context-cancellation error). -type AgentConnection[Stream, State any] struct { - conn *core.BidiConnection[*AgentInput, *AgentOutput[State], *AgentStreamChunk[Stream]] +// +// It also tracks the conversation's custom state live: as [AgentConnection.Receive] +// yields chunks, it applies each chunk's [AgentStreamChunk.CustomPatch] to an +// internal copy, exposed by [AgentConnection.Custom], so callers observe custom +// state as it streams without applying patches themselves. +type AgentConnection[State any] struct { + conn *core.BidiConnection[*AgentInput, *AgentOutput[State], *AgentStreamChunk] + + mu sync.Mutex + custom any // live custom state (normalized JSON), updated as patches stream } // Send sends an AgentInput to the agent. The input must not be nil. @@ -1984,7 +2096,7 @@ type AgentConnection[Stream, State any] struct { // fails with an error matching [core.ErrActionCompleted]; the outcome is // on [AgentConnection.Output]. The same applies to the SendMessage, // SendText, SendResume, and Detach helpers. -func (c *AgentConnection[Stream, State]) Send(input *AgentInput) error { +func (c *AgentConnection[State]) Send(input *AgentInput) error { if input == nil { return core.NewError(core.INVALID_ARGUMENT, "agent input must not be nil") } @@ -1992,12 +2104,12 @@ func (c *AgentConnection[Stream, State]) Send(input *AgentInput) error { } // SendMessage sends a message to the agent for one turn. -func (c *AgentConnection[Stream, State]) SendMessage(message *ai.Message) error { +func (c *AgentConnection[State]) SendMessage(message *ai.Message) error { return c.conn.Send(&AgentInput{Message: message}) } // SendText sends a user text message to the agent. -func (c *AgentConnection[Stream, State]) SendText(text string) error { +func (c *AgentConnection[State]) SendText(text string) error { return c.conn.Send(&AgentInput{ Message: ai.NewUserTextMessage(text), }) @@ -2006,7 +2118,7 @@ func (c *AgentConnection[Stream, State]) SendText(text string) error { // SendResume sends a resume payload to continue an interrupted generation. // Construct the payload with [ai.ToolDef.RestartWith] or // [ai.ToolDef.RespondWith] parts. -func (c *AgentConnection[Stream, State]) SendResume(resume *ToolResume) error { +func (c *AgentConnection[State]) SendResume(resume *ToolResume) error { return c.conn.Send(&AgentInput{Resume: resume}) } @@ -2025,12 +2137,12 @@ func (c *AgentConnection[Stream, State]) SendResume(resume *ToolResume) error { // // To send a final input as part of the same wire message, use // Send(&AgentInput{Detach: true, Message: ...}) directly. -func (c *AgentConnection[Stream, State]) Detach() error { +func (c *AgentConnection[State]) Detach() error { return c.conn.Send(&AgentInput{Detach: true}) } // Close signals that no more inputs will be sent. -func (c *AgentConnection[Stream, State]) Close() error { +func (c *AgentConnection[State]) Close() error { return c.conn.Close() } @@ -2040,8 +2152,54 @@ func (c *AgentConnection[Stream, State]) Close() error { // again to consume the next batch. Call [AgentConnection.Output] to // finish the invocation, or cancel the ctx passed to StreamBidi to // abort it. -func (c *AgentConnection[Stream, State]) Receive() iter.Seq2[*AgentStreamChunk[Stream], error] { - return c.conn.Receive() +// +// Each yielded chunk's [AgentStreamChunk.CustomPatch] is applied to the +// connection's tracked custom state before the chunk is yielded, so +// [AgentConnection.Custom] reflects every delta observed so far. +func (c *AgentConnection[State]) Receive() iter.Seq2[*AgentStreamChunk, error] { + return func(yield func(*AgentStreamChunk, error) bool) { + for chunk, err := range c.conn.Receive() { + if err == nil && chunk != nil && len(chunk.CustomPatch) > 0 { + c.applyCustomPatch(chunk.CustomPatch) + } + if !yield(chunk, err) { + return + } + } + } +} + +// applyCustomPatch applies a streamed patch to the tracked custom state. A +// malformed patch (only possible from a non-conforming server) leaves the last +// good value in place; the next turn's whole-document replace re-bases it. +func (c *AgentConnection[State]) applyCustomPatch(patch JSONPatch) { + c.mu.Lock() + defer c.mu.Unlock() + if next, err := applyOps(cloneJSON(c.custom), patch); err == nil { + c.custom = next + } +} + +// Custom returns the conversation's custom state as tracked from the streamed +// patches observed via [AgentConnection.Receive]. It reflects the deltas +// consumed so far, so reading it as a turn streams shows the live state; before +// any patch arrives it returns the zero value. The authoritative final state is +// on [AgentOutput.State] (client-managed) or the turn-end snapshot +// (server-managed). +func (c *AgentConnection[State]) Custom() (State, error) { + c.mu.Lock() + tree := cloneJSON(c.custom) + c.mu.Unlock() + + var out State + b, err := json.Marshal(tree) + if err != nil { + return out, err + } + if err := json.Unmarshal(b, &out); err != nil { + return out, err + } + return out, nil } // Output finalizes the connection and returns the agent's result. @@ -2067,7 +2225,7 @@ func (c *AgentConnection[Stream, State]) Receive() iter.Seq2[*AgentStreamChunk[S // Do not call Output concurrently with a goroutine iterating Receive; // both consume from the same stream and chunks would be split between // them. Finish Receive first, then call Output. -func (c *AgentConnection[Stream, State]) Output() (*AgentOutput[State], error) { +func (c *AgentConnection[State]) Output() (*AgentOutput[State], error) { _ = c.conn.Close() // The core connection applies backpressure and its Output does not // consume the stream, so drain the chunks the caller did not Receive; @@ -2080,6 +2238,6 @@ func (c *AgentConnection[Stream, State]) Output() (*AgentOutput[State], error) { } // Done returns a channel closed when the connection completes. -func (c *AgentConnection[Stream, State]) Done() <-chan struct{} { +func (c *AgentConnection[State]) Done() <-chan struct{} { return c.conn.Done() } diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 35281ec3d6..74381c35e3 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -37,10 +37,6 @@ type testState struct { Topics []string `json:"topics,omitempty"` } -type testStatus struct { - Phase string `json:"phase"` -} - func newTestRegistry(t *testing.T) *registry.Registry { t.Helper() return registry.New() @@ -51,19 +47,18 @@ func TestAgent_BasicMultiTurn(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "basicFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - resp.SendStatus(testStatus{Phase: "generating"}) // Echo back the user's message. if input.Message != nil { reply := ai.NewModelTextMessage("echo: " + input.Message.Content[0].Text) sess.AddMessages(reply) } + // Mutating custom state streams a customPatch chunk. sess.UpdateCustom(func(s testState) testState { s.Counter++ return s }) - resp.SendStatus(testStatus{Phase: "complete"}) return nil, nil }) }, @@ -88,7 +83,7 @@ func TestAgent_BasicMultiTurn(t *testing.T) { break } } - if turn1Chunks < 2 { // at least status + TurnEnd + if turn1Chunks < 2 { // at least customPatch + TurnEnd t.Errorf("expected at least 2 chunks in turn 1, got %d", turn1Chunks) } @@ -127,7 +122,7 @@ func TestAgent_WithSessionStore(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "snapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("reply")) @@ -197,7 +192,7 @@ func TestAgent_ResumeFromSnapshot(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "resumeFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("reply")) @@ -286,7 +281,7 @@ func TestAgent_ClientManagedState(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "clientStateFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("reply")) @@ -349,7 +344,7 @@ func TestAgent_Artifacts(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "artifactFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { resp.SendArtifact(&Artifact{ @@ -424,7 +419,7 @@ func TestAgent_ClientManagedState_CallerStateIsolated(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "stateIsolationFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { // Replace the artifact the caller's state carried (the // in-place replace path) and extend history. @@ -481,7 +476,7 @@ func TestAgent_InputMessageCloned(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "inputCloneFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { return nil, nil }) @@ -536,7 +531,7 @@ func TestAgent_SendArtifact_SynchronousAndCloned(t *testing.T) { sessionContent string ) af := DefineCustomAgent(reg, "syncArtifactFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { a := &Artifact{Name: "out.txt", Parts: []*ai.Part{ai.NewTextPart("original")}} resp.SendArtifact(a) @@ -575,7 +570,7 @@ func TestAgent_TurnEndSnapshot_IncludesSameTurnArtifact(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "turnEndArtifactFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { resp.SendArtifact(&Artifact{ Name: "report.md", @@ -621,7 +616,7 @@ func TestAgent_SnapshotCallback(t *testing.T) { // Only snapshot on even turns. callbackCalls := 0 af := DefineCustomAgent(reg, "callbackFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { @@ -677,7 +672,7 @@ func TestAgent_SendMessage(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "sendMsgFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { return nil, nil }) @@ -721,7 +716,7 @@ func TestAgent_SessionContext(t *testing.T) { var retrievedCounter int af := DefineCustomAgent(reg, "contextFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { // Session should be retrievable from context. ctxSess := SessionFromContext[testState](ctx) @@ -766,7 +761,7 @@ func TestAgent_ErrorInTurn(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "errorFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { return nil, fmt.Errorf("turn failed") }) @@ -807,9 +802,9 @@ func TestAgent_ErrorInTurn(t *testing.T) { // defineLastGoodTestAgent defines a client- or server-managed echo agent // whose turn fails (with partial session mutations) when the user sends // "boom". Successful turns report [AgentFinishReasonStop]. -func defineLastGoodTestAgent(reg api.Registry, name string, opts ...AgentOption[testState]) *Agent[testStatus, testState] { +func defineLastGoodTestAgent(reg api.Registry, name string, opts ...AgentOption[testState]) *Agent[testState] { return DefineCustomAgent(reg, name, - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { text := input.Message.Content[0].Text if text == "boom" { @@ -1123,7 +1118,7 @@ func TestAgent_FailedTurn_EmitsFailedTurnEnd(t *testing.T) { // forwarding chunks once fn returns with an error). turnEndSeen := make(chan struct{}) af := DefineCustomAgent(reg, "failedTurnEnd", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { return nil, fmt.Errorf("boom") }) @@ -1170,7 +1165,7 @@ func TestAgent_CustomAgentContinuesAfterFailedTurn(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "continueAfterFail", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { for { err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { text := input.Message.Content[0].Text @@ -1227,7 +1222,7 @@ func TestAgent_InitFailure_FailsActionWithStatus(t *testing.T) { reg := newTestRegistry(t) store := newTestInMemStore[testState]() - echo := func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + echo := func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { return nil, nil }) @@ -1291,7 +1286,7 @@ func TestAgent_SetMessages(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "setMsgsFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { // Replace all messages with just one. sess.SetMessages([]*ai.Message{ai.NewModelTextMessage("replaced")}) @@ -1334,7 +1329,7 @@ func TestAgent_TurnSpanOutput(t *testing.T) { var capturedOutputs []any af := DefineCustomAgent(reg, "turnOutputFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { // Wrap collectTurnOutput to capture what each turn produces. originalCollect := sess.collectTurnOutput sess.collectTurnOutput = func() any { @@ -1344,7 +1339,10 @@ func TestAgent_TurnSpanOutput(t *testing.T) { } return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - resp.SendStatus(testStatus{Phase: "thinking"}) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) resp.SendModelChunk(&ai.ModelResponseChunk{ Content: []*ai.Part{ai.NewTextPart("reply")}, }) @@ -1389,11 +1387,11 @@ func TestAgent_TurnSpanOutput(t *testing.T) { } for i, output := range capturedOutputs { - chunks, ok := output.([]*AgentStreamChunk[testStatus]) + chunks, ok := output.([]*AgentStreamChunk) if !ok { - t.Fatalf("turn %d: expected []*AgentStreamChunk[testStatus], got %T", i, output) + t.Fatalf("turn %d: expected []*AgentStreamChunk, got %T", i, output) } - // 3 content chunks per turn: status + model chunk + artifact. + // 3 content chunks per turn: customPatch + model chunk + artifact. if len(chunks) != 3 { t.Errorf("turn %d: expected 3 chunks, got %d", i, len(chunks)) } @@ -1413,7 +1411,7 @@ func TestAgent_TurnSpanOutput_WithSnapshots(t *testing.T) { var capturedOutputs []any af := DefineCustomAgent(reg, "turnOutputSnapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { originalCollect := sess.collectTurnOutput sess.collectTurnOutput = func() any { output := originalCollect() @@ -1422,7 +1420,10 @@ func TestAgent_TurnSpanOutput_WithSnapshots(t *testing.T) { } return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - resp.SendStatus(testStatus{Phase: "working"}) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) sess.AddMessages(ai.NewModelTextMessage("reply")) return nil, nil }) @@ -1455,16 +1456,17 @@ func TestAgent_TurnSpanOutput_WithSnapshots(t *testing.T) { t.Fatal("expected a snapshot ID on the turn-end chunk") } - // Turn output should contain only the status chunk, not the TurnEnd signal. + // Turn output should contain only the customPatch chunk, not the TurnEnd signal. if len(capturedOutputs) != 1 { t.Fatalf("expected 1 captured output, got %d", len(capturedOutputs)) } - chunks := capturedOutputs[0].([]*AgentStreamChunk[testStatus]) + chunks := capturedOutputs[0].([]*AgentStreamChunk) if len(chunks) != 1 { t.Errorf("expected 1 content chunk, got %d", len(chunks)) } - if chunks[0].Status.Phase != "working" { - t.Errorf("expected status phase 'working', got %q", chunks[0].Status.Phase) + // The first (and only) patch of the turn is a whole-document replace. + if got := chunks[0].CustomPatch; len(got) != 1 || got[0].Op != JSONPatchOpReplace || got[0].Path != "" { + t.Errorf("expected a whole-document replace customPatch, got %+v", got) } } @@ -1902,7 +1904,7 @@ func TestAgent_RunText(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "runTextFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("echo: " + input.Message.Content[0].Text)) @@ -1935,7 +1937,7 @@ func TestAgent_Run(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "runFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("reply")) @@ -1965,7 +1967,7 @@ func TestAgent_RunText_WithState(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "runStateFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { @@ -2006,7 +2008,7 @@ func TestAgent_RunText_WithSnapshot(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "runSnapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { @@ -2161,7 +2163,7 @@ func TestAgent_SingleTurnSnapshotDedup(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "dedupFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { @@ -2205,7 +2207,7 @@ func TestAgent_MultiTurnSnapshotDedup(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "multiDedupFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { @@ -2267,7 +2269,7 @@ func TestAgent_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "postRunMutateFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) return nil, nil @@ -2320,9 +2322,11 @@ func TestAgent_FnPanicResolvesAsFailedOutput(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "panicFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - resp.SendStatus(testStatus{Phase: "before-panic"}) + resp.SendModelChunk(&ai.ModelResponseChunk{ + Content: []*ai.Part{ai.NewTextPart("before-panic")}, + }) panic("boom") }) }, @@ -2362,7 +2366,7 @@ func TestAgent_CancelDuringStreamReleasesGoroutine(t *testing.T) { emitting := make(chan struct{}) fnDone := make(chan struct{}) af := DefineCustomAgent(reg, "cancelFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { defer close(fnDone) return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { close(emitting) @@ -2375,7 +2379,9 @@ func TestAgent_CancelDuringStreamReleasesGoroutine(t *testing.T) { return nil, ctx.Err() default: } - resp.SendStatus(testStatus{Phase: "tick"}) + resp.SendModelChunk(&ai.ModelResponseChunk{ + Content: []*ai.Part{ai.NewTextPart("tick")}, + }) } }) }, @@ -2431,7 +2437,7 @@ func waitForSnapshot[State any]( // a copy of it, failing the test if the stream errors or ends first. Use it // in tests that only need to advance to a turn boundary; tests that must // inspect intermediate chunks should range over Receive directly. -func nextTurnEnd[Stream, State any](t *testing.T, conn *AgentConnection[Stream, State]) *TurnEnd { +func nextTurnEnd[State any](t *testing.T, conn *AgentConnection[State]) *TurnEnd { t.Helper() for chunk, err := range conn.Receive() { if err != nil { @@ -2449,7 +2455,7 @@ func nextTurnEnd[Stream, State any](t *testing.T, conn *AgentConnection[Stream, // outputWithin finalizes conn and returns its output, failing the test if // finalization does not complete within d. Use it in tests where a // regression would make Output hang rather than fail. -func outputWithin[Stream, State any](t *testing.T, conn *AgentConnection[Stream, State], d time.Duration) (*AgentOutput[State], error) { +func outputWithin[State any](t *testing.T, conn *AgentConnection[State], d time.Duration) (*AgentOutput[State], error) { t.Helper() type outcome struct { out *AgentOutput[State] @@ -2476,7 +2482,7 @@ func TestAgent_TurnEnd_CarriesSnapshotID(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "turnEndSnapshotFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) return nil, nil @@ -2542,7 +2548,7 @@ func TestAgent_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { release := make(chan struct{}) af := DefineCustomAgent(reg, "detachInFlight", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { entered <- struct{}{} <-release @@ -2641,7 +2647,7 @@ func TestAgent_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { release := make(chan struct{}, 4) af := DefineCustomAgent(reg, "detachChainParent", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { enter <- struct{}{} <-release @@ -2719,7 +2725,7 @@ func TestAgent_Detach_RequiresStore(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "detachNoStore", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { return nil, nil }) @@ -2760,7 +2766,7 @@ func TestAgent_Detach_PendingThenComplete(t *testing.T) { entered := make(chan struct{}) af := DefineCustomAgent(reg, "detachComplete", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: @@ -2855,7 +2861,7 @@ func TestAgent_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { release := make(chan struct{}) af := DefineCustomAgent(reg, "detachArtifact", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { resp.SendArtifact(&Artifact{ Name: "before.txt", @@ -2932,7 +2938,7 @@ func TestAgent_Detach_FlowErrorsBecomesError(t *testing.T) { boom := errors.New("kaboom") af := DefineCustomAgent(reg, "detachErr", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: @@ -3003,7 +3009,7 @@ func TestAgent_Detach_AbortSnapshotStopsFlow(t *testing.T) { entered := make(chan struct{}) af := DefineCustomAgent(reg, "detachAbort", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: @@ -3074,7 +3080,7 @@ func TestAgent_Detach_NormalCompletionStillEmitsTurnEnd(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "syncStillWorks", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) return nil, nil @@ -3132,7 +3138,7 @@ func TestAgent_Detach_ClientDisconnectBeforeDetachCancels(t *testing.T) { exited := make(chan error, 1) af := DefineCustomAgent(reg, "syncCancel", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: @@ -3197,7 +3203,7 @@ func TestAgent_ResumeFromErrorSnapshot_Rejected(t *testing.T) { } af := DefineCustomAgent(reg, "resumeErrored", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil }, WithSessionStore(store), @@ -3236,7 +3242,7 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { } af := DefineCustomAgent(reg, "transformedFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("the secret is out")) return nil, nil @@ -3314,7 +3320,7 @@ func TestAgent_GetSnapshotAction_ReturnsFinishReason(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "finishReasonActionFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("done")) return &TurnResult{FinishReason: AgentFinishReasonStop}, nil @@ -3361,7 +3367,7 @@ func TestAgent_GetSnapshotAction_NoStore(t *testing.T) { reg := newTestRegistry(t) DefineCustomAgent(reg, "noStoreFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil }, ) @@ -3571,7 +3577,7 @@ func TestAgent_AgentMetadata(t *testing.T) { // Verify the metadata["agent"] payload on the agent's action descriptor // correctly reports stateManagement and abortable for each combination // of store capabilities. - noopFn := func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + noopFn := func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil } @@ -3649,7 +3655,7 @@ func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { reg := newTestRegistry(t) store := newTestInMemStore[testState]() // implements SnapshotAborter DefineCustomAgent(reg, "fullCaps", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil }, WithSessionStore(store), @@ -3669,7 +3675,7 @@ func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { t.Run("no aborter capability → abort not registered", func(t *testing.T) { reg := newTestRegistry(t) DefineCustomAgent(reg, "minCaps", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil }, WithSessionStore[testState](minimalStore[testState]{}), @@ -3693,7 +3699,7 @@ func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { // the store's capabilities by returning nil for actions that were not // registered. func TestAgent_CompanionActionAccessors(t *testing.T) { - noopFn := func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + noopFn := func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil } @@ -3737,7 +3743,7 @@ func TestAgent_CompanionActionAccessors(t *testing.T) { // configured with (so local Go code need not thread a separate reference), // and nil when the agent is client-managed. func TestAgent_Store(t *testing.T) { - noopFn := func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + noopFn := func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil } @@ -3769,7 +3775,7 @@ func TestAgent_Store(t *testing.T) { // metadata["description"] by core), so reflective tooling and local // callers read it the same way they read any other action's description. func TestAgent_Description(t *testing.T) { - noopFn := func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + noopFn := func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil } @@ -3816,7 +3822,7 @@ func TestAgent_Description(t *testing.T) { func TestAgent_RegisterCarriesCompanions(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "mover", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil }, WithSessionStore(newTestInMemStore[testState]()), @@ -3852,7 +3858,7 @@ func TestNewCustomAgent_UnregisteredUntilRegister(t *testing.T) { store := newTestInMemStore[testState]() af := NewCustomAgent("standalone", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("hi")) return &TurnResult{FinishReason: AgentFinishReasonStop}, nil @@ -3892,7 +3898,7 @@ func TestAgent_AbortAction_NotFound(t *testing.T) { // action so callers (Dev UI, remote clients) see a proper status. reg := newTestRegistry(t) DefineCustomAgent(reg, "missingFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil }, WithSessionStore(newTestInMemStore[testState]()), @@ -3928,7 +3934,7 @@ func TestAgent_StateTransform_ClientManagedState(t *testing.T) { } af := DefineCustomAgent(reg, "clientXformFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.UpdateCustom(func(s testState) testState { s.Counter = 7 @@ -3959,7 +3965,7 @@ func TestAgent_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "resumeDetachedFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) sess.UpdateCustom(func(s testState) testState { @@ -4104,7 +4110,7 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { entered := make(chan struct{}) af := DefineCustomAgent(reg, "raceFinalize", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: @@ -4236,7 +4242,7 @@ func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "abortNoop", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) return nil, nil @@ -4287,7 +4293,7 @@ func TestAgent_ResultAndOutput_IsolatedFromSession(t *testing.T) { ) af := DefineCustomAgent(reg, "isolation", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("session-msg")) sess.AddArtifacts(&Artifact{ @@ -4361,7 +4367,7 @@ func TestAgent_ResultAndOutput_IsolatedFromSession(t *testing.T) { func TestAgent_Name(t *testing.T) { reg := newTestRegistry(t) a := DefineCustomAgent(reg, "name-accessor", - func(ctx context.Context, _ Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, _ Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return sess.Result(), nil }) if got := a.Name(); got != "name-accessor" { @@ -4380,7 +4386,7 @@ func TestAgent_FinishReason_TurnAndInvocation(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "finishReasonFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) return &TurnResult{FinishReason: AgentFinishReasonStop}, nil @@ -4430,7 +4436,7 @@ func TestAgent_FinishReason_OmittedWhenNil(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "noReasonFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) return nil, nil @@ -4466,7 +4472,7 @@ func TestAgent_FinishReason_InvocationOverride(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "overrideReasonFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) return &TurnResult{FinishReason: AgentFinishReasonStop}, nil @@ -4510,7 +4516,7 @@ func TestAgent_FinishReason_MultiTurnDistinct(t *testing.T) { reasons := []AgentFinishReason{AgentFinishReasonStop, AgentFinishReasonInterrupted} af := DefineCustomAgent(reg, "multiReasonFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) return &TurnResult{FinishReason: reasons[sess.TurnIndex]}, nil @@ -4605,7 +4611,7 @@ func TestAgent_Detach_BackgroundWorkSurvivesActionReturn(t *testing.T) { fnSaw := make(chan string, 1) af := DefineCustomAgent(reg, "detachSurviveFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case <-release: @@ -4668,7 +4674,7 @@ func TestAgent_Detach_FinishReasons(t *testing.T) { entered := make(chan struct{}) af := DefineCustomAgent(reg, "detachReasonComplete", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: @@ -4725,7 +4731,7 @@ func TestAgent_Detach_FinishReasons(t *testing.T) { entered := make(chan struct{}) af := DefineCustomAgent(reg, "detachReasonFailed", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: @@ -4780,7 +4786,7 @@ func TestAgent_Detach_FinishReasons(t *testing.T) { entered := make(chan struct{}) af := DefineCustomAgent(reg, "detachReasonAborted", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: @@ -4841,7 +4847,7 @@ func TestAgent_FinishReason_InvocationOverride_Persisted(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "overridePersistedFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) return &TurnResult{FinishReason: AgentFinishReasonStop}, nil @@ -4888,7 +4894,7 @@ func TestAgent_FinishReason_MultiTurnDistinct_Persisted(t *testing.T) { reasons := []AgentFinishReason{AgentFinishReasonStop, AgentFinishReasonInterrupted} af := DefineCustomAgent(reg, "multiReasonPersistedFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) return &TurnResult{FinishReason: reasons[sess.TurnIndex]}, nil @@ -4939,7 +4945,7 @@ func TestAgent_FinishReason_OmittedPersisted(t *testing.T) { store := newTestInMemStore[testState]() af := DefineCustomAgent(reg, "noReasonPersistedFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) return nil, nil @@ -5053,7 +5059,7 @@ func TestAgent_Detach_SucceededHonorsResultOverride(t *testing.T) { entered := make(chan struct{}) af := DefineCustomAgent(reg, "detachOverride", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: @@ -5172,7 +5178,7 @@ func TestAgent_SessionID_AssignedBeforeFirstSnapshot(t *testing.T) { var fnSawSessionID, ctxSawSessionID string af := DefineCustomAgent(reg, "sessionAlwaysAssigned", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { fnSawSessionID = sess.SessionID() // The ID lives on the session itself, so code holding only the // context-carried session (e.g. a tool) can read it too. @@ -5526,7 +5532,7 @@ func TestAgent_ClientManagedState_SessionIDRoundTrip(t *testing.T) { var fnSawSessionID string af := DefineCustomAgent(reg, "clientPassthroughFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { fnSawSessionID = sess.SessionID() return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.UpdateCustom(func(s testState) testState { @@ -5635,7 +5641,7 @@ func TestAgent_Detach_AssignsSessionID(t *testing.T) { release := make(chan struct{}) entered := make(chan struct{}) af := DefineCustomAgent(reg, "detachSessionFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { select { case entered <- struct{}{}: @@ -6012,7 +6018,7 @@ func TestAgent_SendNilInput_Rejected(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "nilInputFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { if input.Message != nil { sess.AddMessages(ai.NewModelTextMessage("echo: " + input.Message.Text())) @@ -6114,10 +6120,14 @@ func TestAgent_ClientCancelMidStream(t *testing.T) { reg := newTestRegistry(t) af := DefineCustomAgent(reg, "cancelFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - resp.SendStatus(testStatus{Phase: "step0"}) - resp.SendStatus(testStatus{Phase: "step1"}) + resp.SendModelChunk(&ai.ModelResponseChunk{ + Content: []*ai.Part{ai.NewTextPart("step0")}, + }) + resp.SendModelChunk(&ai.ModelResponseChunk{ + Content: []*ai.Part{ai.NewTextPart("step1")}, + }) return nil, nil }) }, @@ -6182,7 +6192,7 @@ func TestAgent_OutputUnblocksOnCancel(t *testing.T) { t.Cleanup(func() { close(block) }) // let the stubborn fn unwind af := DefineCustomAgent(reg, "stubbornFlow", - func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*AgentResult, error) { + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { <-block // ignores ctx return nil, nil }, diff --git a/go/ai/exp/custompatch_test.go b/go/ai/exp/custompatch_test.go new file mode 100644 index 0000000000..0a35600cd6 --- /dev/null +++ b/go/ai/exp/custompatch_test.go @@ -0,0 +1,311 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "context" + "sync" + "testing" + + "github.com/firebase/genkit/go/ai" +) + +// collectTurnPatches consumes one turn's chunks, returning the customPatch from +// each chunk that carries one (in stream order). Consuming via Receive also +// updates the connection's tracked custom state. +func collectTurnPatches(t *testing.T, conn *AgentConnection[testState]) []JSONPatch { + t.Helper() + var patches []JSONPatch + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive: %v", err) + } + if len(chunk.CustomPatch) > 0 { + patches = append(patches, chunk.CustomPatch) + } + if chunk.TurnEnd != nil { + break + } + } + return patches +} + +// TestCustomPatch_PerTurnRebaseAndIncremental verifies that the first patch of +// every turn is a whole-document replace and later patches in the same turn are +// minimal incremental diffs. +func TestCustomPatch_PerTurnRebaseAndIncremental(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "cp", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + // Two mutations: first emits a whole-document replace, the + // second an incremental diff against the first. + sess.UpdateCustom(func(s testState) testState { s.Counter++; return s }) + sess.UpdateCustom(func(s testState) testState { s.Counter++; return s }) + return nil, nil + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + defer conn.Output() + + // Turn 1. + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText: %v", err) + } + patches := collectTurnPatches(t, conn) + if len(patches) != 2 { + t.Fatalf("turn 1: expected 2 customPatch chunks, got %d", len(patches)) + } + if op := patches[0][0]; len(patches[0]) != 1 || op.Op != JSONPatchOpReplace || op.Path != "" { + t.Errorf("turn 1 first patch: want whole-document replace, got %s", patchString(patches[0])) + } + if op := patches[1][0]; len(patches[1]) != 1 || op.Op != JSONPatchOpReplace || op.Path != "/counter" { + t.Errorf("turn 1 second patch: want replace /counter, got %s", patchString(patches[1])) + } + + // Turn 2: the first patch re-bases the client with another whole-document + // replace even though only /counter changed since the previous turn. + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText: %v", err) + } + patches = collectTurnPatches(t, conn) + if len(patches) != 2 { + t.Fatalf("turn 2: expected 2 customPatch chunks, got %d", len(patches)) + } + if op := patches[0][0]; op.Path != "" { + t.Errorf("turn 2 first patch: want whole-document replace (path \"\"), got %s", patchString(patches[0])) + } +} + +// TestCustomPatch_ClientTracksLiveCustom verifies the connection applies the +// streamed patches so Custom reflects the server's state. +func TestCustomPatch_ClientTracksLiveCustom(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "cp", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + s.Topics = append(s.Topics, input.Message.Text()) + return s + }) + return nil, nil + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + + // Before any patch, Custom is the zero value. + if got, err := conn.Custom(); err != nil || got.Counter != 0 { + t.Errorf("initial Custom = %+v (err %v), want zero value", got, err) + } + + conn.SendText("alpha") + collectTurnPatches(t, conn) + conn.SendText("beta") + collectTurnPatches(t, conn) + + got, err := conn.Custom() + if err != nil { + t.Fatalf("Custom: %v", err) + } + if got.Counter != 2 { + t.Errorf("tracked counter = %d, want 2", got.Counter) + } + if want := []string{"alpha", "beta"}; len(got.Topics) != 2 || got.Topics[0] != want[0] || got.Topics[1] != want[1] { + t.Errorf("tracked topics = %v, want %v", got.Topics, want) + } + + conn.Close() + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + // The live-tracked custom agrees with the authoritative final state. + if out.State.Custom.Counter != got.Counter { + t.Errorf("tracked counter %d != final state counter %d", got.Counter, out.State.Custom.Counter) + } +} + +// TestCustomPatch_HonorsStateTransform verifies the diff is computed on the +// client-facing custom value (after WithStateTransform), so redaction reaches +// the wire. +func TestCustomPatch_HonorsStateTransform(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "cp", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.UpdateCustom(func(s testState) testState { + s.Counter = 5 + s.Topics = []string{"secret"} + return s + }) + return nil, nil + }) + }, + // Redact Topics on the way out to the client. + WithStateTransform(func(ctx context.Context, st *SessionState[testState]) *SessionState[testState] { + st.Custom.Topics = nil + return st + }), + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + defer conn.Output() + + conn.SendText("hi") + collectTurnPatches(t, conn) + + got, err := conn.Custom() + if err != nil { + t.Fatalf("Custom: %v", err) + } + if got.Counter != 5 { + t.Errorf("counter = %d, want 5", got.Counter) + } + if len(got.Topics) != 0 { + t.Errorf("topics should be redacted on the wire, got %v", got.Topics) + } +} + +// TestCustomPatch_ConcurrentMutations exercises the patcher's locking when +// custom state is mutated from several goroutines at once: the streamed patches +// must converge on the final state, and there must be no data race (run with +// -race). The last patcher emit observes all completed increments, so the +// client's tracked custom equals the final counter. +func TestCustomPatch_ConcurrentMutations(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + const n = 30 + + af := DefineCustomAgent(reg, "cp", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + var wg sync.WaitGroup + for range n { + wg.Add(1) + go func() { + defer wg.Done() + sess.UpdateCustom(func(s testState) testState { s.Counter++; return s }) + }() + } + wg.Wait() + return nil, nil + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + + conn.SendText("go") + collectTurnPatches(t, conn) // drains concurrently with the producers + + got, err := conn.Custom() + if err != nil { + t.Fatalf("Custom: %v", err) + } + if got.Counter != n { + t.Errorf("tracked counter = %d, want %d", got.Counter, n) + } + + conn.Close() + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.State.Custom.Counter != n { + t.Errorf("final state counter = %d, want %d", out.State.Custom.Counter, n) + } +} + +// TestCustomPatch_NoMutationNoPatch verifies a turn that does not mutate custom +// state emits no customPatch chunk. +func TestCustomPatch_NoMutationNoPatch(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "cp", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("reply")) + return nil, nil + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + defer conn.Output() + + conn.SendText("hi") + if patches := collectTurnPatches(t, conn); len(patches) != 0 { + t.Errorf("expected no customPatch chunks, got %d", len(patches)) + } +} + +// TestCustomPatch_EmptyDiffEmitsNothing verifies that a no-op mutation after +// the first patch of the turn produces no chunk (an empty incremental diff). +func TestCustomPatch_EmptyDiffEmitsNothing(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "cp", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.UpdateCustom(func(s testState) testState { s.Counter = 1; return s }) + // Re-applies the same value: the incremental diff is empty. + sess.UpdateCustom(func(s testState) testState { s.Counter = 1; return s }) + return nil, nil + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + defer conn.Output() + + conn.SendText("hi") + patches := collectTurnPatches(t, conn) + if len(patches) != 1 { + t.Errorf("expected 1 customPatch chunk (the whole-document replace), got %d", len(patches)) + } +} diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index 24fa3f976d..b7b70521b1 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -242,14 +242,24 @@ const ( // AgentStreamChunk represents a single item in the agent's output stream. // Multiple fields can be populated in a single chunk. -type AgentStreamChunk[Stream any] struct { +type AgentStreamChunk struct { // Artifact contains a newly produced artifact. Artifact *Artifact `json:"artifact,omitempty"` + // CustomPatch is an RFC 6902 JSON Patch describing a delta applied to the + // session's custom state. The runtime emits it automatically whenever the + // agent mutates custom state (e.g. via [Session.UpdateCustom]); agents do not + // hand-craft patches. Pointers are rooted at the custom document (e.g. + // "/agentStatus"), with no "/custom" prefix. The first patch of every turn is a + // whole-document replace at the root pointer ("") that re-bases clients which + // may not share the server's baseline; subsequent patches are incremental diffs + // against the last sent value. The diff is computed on the client-facing custom + // state (after any [WithStateTransform]), so streamed deltas honor redaction and + // stay consistent with the full state in turn-end snapshots and final output. + // Apply it with [ApplyPatch] to keep a local copy of custom live as the turn + // streams. + CustomPatch JSONPatch `json:"customPatch,omitempty"` // ModelChunk contains generation tokens from the model. ModelChunk *ai.ModelResponseChunk `json:"modelChunk,omitempty"` - // Status contains user-defined structured status information. - // The Stream type parameter defines the shape of this data. - Status Stream `json:"status,omitempty"` // TurnEnd is non-nil when the agent has finished processing the current // input. It groups all turn-end signals (snapshot ID, etc.) so callers can // check a single field. When set, the client should stop iterating and may @@ -279,6 +289,45 @@ type GetSnapshotRequest struct { SnapshotID string `json:"snapshotId"` } +// JSONPatch is an RFC 6902 JSON Patch: an ordered list of operations applied in +// sequence. Use [Diff] to compute the patch between two values and [ApplyPatch] +// to apply one to a document. +type JSONPatch []*JSONPatchOperation + +// JSONPatchOp is the kind of a JSON Patch operation (RFC 6902): one of +// [JSONPatchOpAdd], [JSONPatchOpRemove], [JSONPatchOpReplace], [JSONPatchOpMove], +// [JSONPatchOpCopy], or [JSONPatchOpTest]. +type JSONPatchOp string + +const ( + JSONPatchOpAdd JSONPatchOp = "add" + JSONPatchOpRemove JSONPatchOp = "remove" + JSONPatchOpReplace JSONPatchOp = "replace" + JSONPatchOpMove JSONPatchOp = "move" + JSONPatchOpCopy JSONPatchOp = "copy" + JSONPatchOpTest JSONPatchOp = "test" +) + +// JSONPatchOperation is a single RFC 6902 (JSON Patch) operation. A [JSONPatch] +// applies an ordered list of these to transform one JSON document into another. +type JSONPatchOperation struct { + // From is a JSON Pointer to the source location; required for "move" and "copy". + From string `json:"from,omitempty"` + // Op is the operation to perform. + Op JSONPatchOp `json:"op"` + // Path is a JSON Pointer (RFC 6901) to the target location, e.g. "/agentStatus". + // The empty pointer "" refers to the whole document. It must always be present on + // the wire (a whole-document replace carries path ""), so it is not omitted when + // empty. + Path string `json:"path"` + // Value is the operand for "add", "replace", and "test". It is not omitted when + // null so an explicit null operand survives the wire (omitempty cannot tell a + // null operand from an absent one, and dropping it makes a peer applier set the + // member to undefined or remove it instead of null); for "remove", "move", and + // "copy" it is null and ignored. + Value any `json:"value"` +} + // SessionSnapshot is a persisted point-in-time capture of session state. It // is the canonical record written to and read from a [SessionStore]. type SessionSnapshot[State any] struct { diff --git a/go/ai/exp/jsonpatch.go b/go/ai/exp/jsonpatch.go new file mode 100644 index 0000000000..4806cb27a8 --- /dev/null +++ b/go/ai/exp/jsonpatch.go @@ -0,0 +1,447 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "encoding/json" + "fmt" + "reflect" + "sort" + "strconv" + "strings" +) + +// A small, dependency-free RFC 6902 (JSON Patch) / RFC 6901 (JSON Pointer) +// implementation. Genkit uses it to stream incremental changes to a session's +// custom state (see [AgentStreamChunk.CustomPatch]): the runtime [Diff]s the +// client-facing custom value against the last sent one and emits the result, +// and a client (or [AgentConnection]) [ApplyPatch]es each delta to keep a live +// local copy. +// +// [Diff] emits a valid RFC 6902 subset (only add / remove / replace; move / +// copy are optional optimizations we skip). [ApplyPatch] understands the full +// operation set for interoperability and is deliberately lenient so a stream of +// deltas stays robust. +// +// Both operate on JSON-shaped values: the map[string]any / []any / float64 / +// string / bool / nil tree produced by unmarshaling into an any. Inputs are +// normalized (round-tripped through JSON) on the way in, so any +// JSON-serializable Go value may be passed. + +// Diff computes an RFC 6902 JSON Patch that transforms from into to. +// +// The diff is rooted at the document, so pointers are bare (e.g. +// "/agentStatus", "/items/0"). Only add, remove, and replace operations are +// emitted. Object members recurse (add for new keys, remove for deleted keys); +// arrays diff by index, appending with the end-of-array token "/-" and removing +// from the tail backwards so indices stay valid. A root-level change that +// cannot be expressed member-by-member (object↔array↔primitive) collapses to a +// single whole-document replace at path "". +// +// Object keys are visited in sorted order, so the patch is deterministic. +func Diff(from, to any) JSONPatch { + return diffValues(normalizeJSON(from), normalizeJSON(to)) +} + +// diffValues diffs two already-normalized JSON values. The runtime uses it +// directly with a cached, already-normalized baseline to avoid re-normalizing +// it every turn. +func diffValues(from, to any) JSONPatch { + var patch JSONPatch + diffWalk(from, to, "", &patch) + return patch +} + +func diffWalk(from, to any, pointer string, patch *JSONPatch) { + if jsonEqual(from, to) { + return + } + + // Both objects: recurse member by member. + fromObj, fromIsObj := from.(map[string]any) + toObj, toIsObj := to.(map[string]any) + if fromIsObj && toIsObj { + for _, key := range unionKeys(fromObj, toObj) { + child := pointer + "/" + escapeToken(key) + fv, inFrom := fromObj[key] + tv, inTo := toObj[key] + switch { + case inFrom && !inTo: + *patch = append(*patch, &JSONPatchOperation{Op: JSONPatchOpRemove, Path: child}) + case !inFrom && inTo: + *patch = append(*patch, &JSONPatchOperation{Op: JSONPatchOpAdd, Path: child, Value: cloneJSON(tv)}) + default: + diffWalk(fv, tv, child, patch) + } + } + return + } + + // Both arrays: recurse by index, then add/remove the tail difference. + fromArr, fromIsArr := from.([]any) + toArr, toIsArr := to.([]any) + if fromIsArr && toIsArr { + common := min(len(fromArr), len(toArr)) + for i := 0; i < common; i++ { + diffWalk(fromArr[i], toArr[i], pointer+"/"+strconv.Itoa(i), patch) + } + // Appended elements use the "-" end-of-array token. + for i := len(fromArr); i < len(toArr); i++ { + *patch = append(*patch, &JSONPatchOperation{Op: JSONPatchOpAdd, Path: pointer + "/-", Value: cloneJSON(toArr[i])}) + } + // Removals from the tail backwards so earlier indices stay valid. + for i := len(fromArr) - 1; i >= len(toArr); i-- { + *patch = append(*patch, &JSONPatchOperation{Op: JSONPatchOpRemove, Path: pointer + "/" + strconv.Itoa(i)}) + } + return + } + + // Type mismatch or differing primitives: replace at this location. + *patch = append(*patch, &JSONPatchOperation{Op: JSONPatchOpReplace, Path: pointer, Value: cloneJSON(to)}) +} + +// ApplyPatch applies an RFC 6902 JSON Patch to document and returns the new +// value. The input is not mutated; a normalized clone is patched and returned. +// Operating on the root pointer ("") replaces or removes the whole document. +// +// Apply is lenient to keep streaming robust: an add or replace whose parent +// container is missing initializes it as an object, and a remove or replace of +// a missing member is a no-op. A test operation is honored and returns an error +// on mismatch. Other unknown operations also return an error. +// +// Apply diverges from the JS reference applier (applyPatch in json-patch.ts) +// only on inputs [Diff] never emits: an add or replace at an out-of-range array +// index is a no-op here (JS splices/grows the array, back-filling with null), +// and a test against a missing path may pass here where JS throws. The runtime +// applies only Diff output, which is always in range and never a test, so the +// server and a JS client agree on every patch the streaming protocol produces. +func ApplyPatch(document any, patch JSONPatch) (any, error) { + return applyOps(cloneJSON(normalizeJSON(document)), patch) +} + +// applyOps applies patch to an already-normalized doc, mutating it in place +// where possible and returning the result (which may be a fresh value for +// root-level operations). The runtime/client pass a clone they own. +func applyOps(doc any, patch JSONPatch) (any, error) { + for _, op := range patch { + if op == nil { + continue + } + var err error + doc, err = applyOp(doc, op) + if err != nil { + return nil, err + } + } + return doc, nil +} + +func applyOp(doc any, op *JSONPatchOperation) (any, error) { + tokens, err := parsePointer(op.Path) + if err != nil { + return nil, err + } + switch op.Op { + case JSONPatchOpAdd: + return setPath(doc, tokens, normalizeJSON(op.Value), true), nil + case JSONPatchOpReplace: + return setPath(doc, tokens, normalizeJSON(op.Value), false), nil + case JSONPatchOpRemove: + return removePath(doc, tokens), nil + case JSONPatchOpTest: + if !jsonEqual(getPath(doc, tokens), normalizeJSON(op.Value)) { + return nil, fmt.Errorf("jsonpatch: test failed at %q", op.Path) + } + return doc, nil + case JSONPatchOpMove: + fromTokens, err := parsePointer(op.From) + if err != nil { + return nil, err + } + v := cloneJSON(getPath(doc, fromTokens)) + doc = removePath(doc, fromTokens) + return setPath(doc, tokens, v, true), nil + case JSONPatchOpCopy: + fromTokens, err := parsePointer(op.From) + if err != nil { + return nil, err + } + return setPath(doc, tokens, cloneJSON(getPath(doc, fromTokens)), true), nil + default: + return nil, fmt.Errorf("jsonpatch: unsupported op %q", op.Op) + } +} + +// setPath sets value at tokens within node, creating missing intermediate +// objects, and returns the (possibly new) node. isAdd inserts into arrays +// (vs. replacing an element) and appends on the "-" token. +func setPath(node any, tokens []string, value any, isAdd bool) any { + if len(tokens) == 0 { + return value // root add/replace + } + // Lenient: initialize a missing/null container so member sets still land. + if node == nil { + node = map[string]any{} + } + token := tokens[0] + if len(tokens) == 1 { + return setMember(node, token, value, isAdd) + } + switch n := node.(type) { + case map[string]any: + child, ok := n[token] + if !ok || !isContainer(child) { + child = map[string]any{} + } + n[token] = setPath(child, tokens[1:], value, isAdd) + return n + case []any: + idx, ok := arrayIndex(token, len(n), false) + if !ok || idx >= len(n) { + return n // lenient: nothing to descend into + } + if !isContainer(n[idx]) { + n[idx] = map[string]any{} + } + n[idx] = setPath(n[idx], tokens[1:], value, isAdd) + return n + default: + // Primitive where a container was expected: replace it with one. + return setPath(map[string]any{}, tokens, value, isAdd) + } +} + +// setMember sets the leaf member token on node, returning the (possibly new) +// node. +func setMember(node any, token string, value any, isAdd bool) any { + switch n := node.(type) { + case map[string]any: + n[token] = value + return n + case []any: + if token == "-" { + return append(n, value) + } + idx, ok := arrayIndex(token, len(n), isAdd) + if !ok { + return n + } + if isAdd { + if idx > len(n) { + return n + } + n = append(n, nil) + copy(n[idx+1:], n[idx:]) + n[idx] = value + return n + } + if idx >= 0 && idx < len(n) { + n[idx] = value + } + return n + default: + return map[string]any{token: value} + } +} + +// removePath deletes the member at tokens within node and returns the +// (possibly new) node. Missing members are a no-op. +func removePath(node any, tokens []string) any { + if len(tokens) == 0 { + return nil // remove whole document + } + if node == nil { + return nil + } + token := tokens[0] + if len(tokens) == 1 { + switch n := node.(type) { + case map[string]any: + delete(n, token) + return n + case []any: + idx, ok := arrayIndex(token, len(n), false) + if ok && idx >= 0 && idx < len(n) { + return append(n[:idx], n[idx+1:]...) + } + return n + default: + return n + } + } + switch n := node.(type) { + case map[string]any: + if child, ok := n[token]; ok { + n[token] = removePath(child, tokens[1:]) + } + return n + case []any: + idx, ok := arrayIndex(token, len(n), false) + if ok && idx >= 0 && idx < len(n) { + n[idx] = removePath(n[idx], tokens[1:]) + } + return n + default: + return n + } +} + +// getPath reads the value at tokens, returning nil for any missing segment. +func getPath(node any, tokens []string) any { + cur := node + for _, t := range tokens { + switch n := cur.(type) { + case map[string]any: + cur = n[t] + case []any: + idx, ok := arrayIndex(t, len(n), false) + if !ok || idx < 0 || idx >= len(n) { + return nil + } + cur = n[idx] + default: + return nil + } + } + return cur +} + +// arrayIndex parses an array reference token. The "-" end-of-array token +// resolves to length only when allowEnd is set (an insert/append position); +// otherwise it is rejected. Non-numeric or negative tokens are rejected. +func arrayIndex(token string, length int, allowEnd bool) (int, bool) { + if token == "-" { + if allowEnd { + return length, true + } + return 0, false + } + idx, err := strconv.Atoi(token) + if err != nil || idx < 0 { + return 0, false + } + return idx, true +} + +// --- JSON Pointer (RFC 6901) --- + +// parsePointer splits a JSON Pointer into its reference tokens. The root +// pointer "" yields an empty slice. +func parsePointer(pointer string) ([]string, error) { + if pointer == "" { + return nil, nil + } + if pointer[0] != '/' { + return nil, fmt.Errorf("jsonpatch: invalid JSON Pointer %q: must start with %q", pointer, "/") + } + parts := strings.Split(pointer[1:], "/") + for i, p := range parts { + parts[i] = unescapeToken(p) + } + return parts, nil +} + +// escapeToken escapes a single reference token per RFC 6901 ("~" → "~0", +// "/" → "~1"). The "~" replacement runs first so a literal "/" does not become +// "~01". +func escapeToken(token string) string { + return strings.ReplaceAll(strings.ReplaceAll(token, "~", "~0"), "/", "~1") +} + +// unescapeToken reverses escapeToken ("~1" → "/", "~0" → "~"), with "~1" first +// so "~01" decodes to "~1" rather than "/". +func unescapeToken(token string) string { + return strings.ReplaceAll(strings.ReplaceAll(token, "~1", "/"), "~0", "~") +} + +// --- JSON value helpers --- + +func isContainer(v any) bool { + switch v.(type) { + case map[string]any, []any: + return true + default: + return false + } +} + +// unionKeys returns the sorted union of two objects' keys, for deterministic +// diff output. +func unionKeys(a, b map[string]any) []string { + seen := make(map[string]struct{}, len(a)+len(b)) + keys := make([]string, 0, len(a)+len(b)) + for k := range a { + if _, ok := seen[k]; !ok { + seen[k] = struct{}{} + keys = append(keys, k) + } + } + for k := range b { + if _, ok := seen[k]; !ok { + seen[k] = struct{}{} + keys = append(keys, k) + } + } + sort.Strings(keys) + return keys +} + +// jsonEqual reports deep equality of two normalized JSON values. +func jsonEqual(a, b any) bool { + return reflect.DeepEqual(a, b) +} + +// normalizeJSON round-trips v through JSON so it becomes the canonical +// map[string]any / []any / float64 / string / bool / nil shape the diff and +// apply logic operate on. A value that cannot be marshaled is returned +// unchanged (best effort); JSON-shaped inputs never hit that path. +func normalizeJSON(v any) any { + if v == nil { + return nil + } + b, err := json.Marshal(v) + if err != nil { + return v + } + var out any + if err := json.Unmarshal(b, &out); err != nil { + return v + } + return out +} + +// cloneJSON deep-copies a normalized JSON value. Primitives are immutable and +// returned as-is; maps and slices are copied recursively so callers cannot +// alias each other's state. +func cloneJSON(v any) any { + switch t := v.(type) { + case map[string]any: + out := make(map[string]any, len(t)) + for k, val := range t { + out[k] = cloneJSON(val) + } + return out + case []any: + out := make([]any, len(t)) + for i, val := range t { + out[i] = cloneJSON(val) + } + return out + default: + return v + } +} diff --git a/go/ai/exp/jsonpatch_test.go b/go/ai/exp/jsonpatch_test.go new file mode 100644 index 0000000000..8ae3d5d6f0 --- /dev/null +++ b/go/ai/exp/jsonpatch_test.go @@ -0,0 +1,352 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" +) + +func TestDiff_Shapes(t *testing.T) { + tests := []struct { + name string + from any + to any + want JSONPatch + }{ + { + name: "equal yields empty patch", + from: map[string]any{"a": 1}, + to: map[string]any{"a": 1}, + want: nil, + }, + { + name: "add member", + from: map[string]any{"a": 1}, + to: map[string]any{"a": 1, "b": 2}, + want: JSONPatch{{Op: JSONPatchOpAdd, Path: "/b", Value: float64(2)}}, + }, + { + name: "remove member", + from: map[string]any{"a": 1, "b": 2}, + to: map[string]any{"a": 1}, + want: JSONPatch{{Op: JSONPatchOpRemove, Path: "/b"}}, + }, + { + name: "replace member", + from: map[string]any{"a": 1}, + to: map[string]any{"a": 2}, + want: JSONPatch{{Op: JSONPatchOpReplace, Path: "/a", Value: float64(2)}}, + }, + { + name: "keys are sorted for determinism", + from: map[string]any{}, + to: map[string]any{"b": 1, "a": 2}, + want: JSONPatch{ + {Op: JSONPatchOpAdd, Path: "/a", Value: float64(2)}, + {Op: JSONPatchOpAdd, Path: "/b", Value: float64(1)}, + }, + }, + { + name: "nested member recurses", + from: map[string]any{"o": map[string]any{"x": 1}}, + to: map[string]any{"o": map[string]any{"x": 2}}, + want: JSONPatch{{Op: JSONPatchOpReplace, Path: "/o/x", Value: float64(2)}}, + }, + { + name: "array append uses end-of-array token", + from: []any{"a"}, + to: []any{"a", "b"}, + want: JSONPatch{{Op: JSONPatchOpAdd, Path: "/-", Value: "b"}}, + }, + { + name: "array tail removal is backwards", + from: []any{"a", "b", "c"}, + to: []any{"a"}, + want: JSONPatch{ + {Op: JSONPatchOpRemove, Path: "/2"}, + {Op: JSONPatchOpRemove, Path: "/1"}, + }, + }, + { + name: "array element change replaces by index", + from: []any{"a", "b"}, + to: []any{"a", "z"}, + want: JSONPatch{{Op: JSONPatchOpReplace, Path: "/1", Value: "z"}}, + }, + { + name: "root type change collapses to whole-doc replace", + from: map[string]any{"a": 1}, + to: []any{1, 2}, + want: JSONPatch{{Op: JSONPatchOpReplace, Path: "", Value: []any{float64(1), float64(2)}}}, + }, + { + name: "escapes pointer tokens", + from: map[string]any{}, + to: map[string]any{"a/b": 1, "m~n": 2}, + want: JSONPatch{ + {Op: JSONPatchOpAdd, Path: "/a~1b", Value: float64(1)}, + {Op: JSONPatchOpAdd, Path: "/m~0n", Value: float64(2)}, + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := Diff(tc.from, tc.to) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("Diff()\n got %s\n want %s", patchString(got), patchString(tc.want)) + } + }) + } +} + +// TestDiffApply_RoundTrip is the core property: applying Diff(a, b) to a always +// yields b (normalized). +func TestDiffApply_RoundTrip(t *testing.T) { + cases := []struct { + name string + a, b any + }{ + {"both nil", nil, nil}, + {"nil to object", nil, map[string]any{"a": 1}}, + {"object to nil", map[string]any{"a": 1}, nil}, + {"add and remove and change", map[string]any{"a": 1, "b": 2}, map[string]any{"a": 9, "c": 3}}, + {"deep nesting", map[string]any{"o": map[string]any{"p": map[string]any{"q": 1}}}, map[string]any{"o": map[string]any{"p": map[string]any{"q": 2, "r": 3}}}}, + {"array grow", []any{1, 2}, []any{1, 2, 3, 4}}, + {"array shrink", []any{1, 2, 3, 4}, []any{1}}, + {"array of objects", []any{map[string]any{"k": 1}}, []any{map[string]any{"k": 2}, map[string]any{"k": 3}}}, + {"object holding array", map[string]any{"xs": []any{1, 2}}, map[string]any{"xs": []any{1, 9, 3}}}, + {"primitive change", "hello", "world"}, + {"type flip primitive to object", "x", map[string]any{"a": 1}}, + {"agentStatus style", map[string]any{"agentStatus": "step 1"}, map[string]any{"agentStatus": "step 3 of 12", "done": false}}, + {"null value", map[string]any{"a": 1, "b": nil}, map[string]any{"a": nil, "b": 2}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + patch := Diff(tc.a, tc.b) + got, err := ApplyPatch(tc.a, patch) + if err != nil { + t.Fatalf("ApplyPatch: %v", err) + } + want := normalizeJSON(tc.b) + if !reflect.DeepEqual(got, want) { + t.Errorf("round trip mismatch\n got %#v\n want %#v\n patch %s", got, want, patchString(patch)) + } + }) + } +} + +// TestDiff_NullOperandSurvivesWire guards against json omitempty dropping a +// null operand: a replace (or add) to JSON null must serialize with an explicit +// "value":null, otherwise a peer applier (e.g. the JS client) reads it as absent +// and removes the member instead of setting it to null. The in-memory round-trip +// test cannot catch this because a missing value decodes back to nil Go-side. +func TestDiff_NullOperandSurvivesWire(t *testing.T) { + patch := Diff( + map[string]any{"a": 1, "b": 2}, + map[string]any{"a": nil, "b": 2}, + ) + wire, err := json.Marshal(patch) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if !strings.Contains(string(wire), `"value":null`) { + t.Fatalf("null operand dropped from wire: %s", wire) + } + + // Decode as a peer would and apply: the member must be present and null. + var decoded JSONPatch + if err := json.Unmarshal(wire, &decoded); err != nil { + t.Fatalf("unmarshal: %v", err) + } + got, err := ApplyPatch(map[string]any{"a": 1, "b": 2}, decoded) + if err != nil { + t.Fatalf("apply: %v", err) + } + gotMap, ok := got.(map[string]any) + if !ok { + t.Fatalf("apply result is %T, want map", got) + } + if v, present := gotMap["a"]; !present || v != nil { + t.Errorf("member \"a\" = %v (present %v), want present and null", v, present) + } +} + +func TestApplyPatch_RootOps(t *testing.T) { + // Whole-document replace re-bases any prior value. + got, err := ApplyPatch(map[string]any{"old": true}, JSONPatch{ + {Op: JSONPatchOpReplace, Path: "", Value: map[string]any{"new": 1}}, + }) + if err != nil { + t.Fatalf("ApplyPatch: %v", err) + } + if want := map[string]any{"new": float64(1)}; !reflect.DeepEqual(got, want) { + t.Errorf("root replace = %#v, want %#v", got, want) + } + + // Replace at root onto a nil document initializes it. + got, err = ApplyPatch(nil, JSONPatch{{Op: JSONPatchOpReplace, Path: "", Value: "v"}}) + if err != nil { + t.Fatalf("ApplyPatch: %v", err) + } + if got != "v" { + t.Errorf("root replace onto nil = %#v, want %q", got, "v") + } +} + +func TestApplyPatch_Lenient(t *testing.T) { + // add onto a missing/nil document initializes the root container. + got, err := ApplyPatch(nil, JSONPatch{{Op: JSONPatchOpAdd, Path: "/agentStatus", Value: "x"}}) + if err != nil { + t.Fatalf("ApplyPatch: %v", err) + } + if want := map[string]any{"agentStatus": "x"}; !reflect.DeepEqual(got, want) { + t.Errorf("add onto nil = %#v, want %#v", got, want) + } + + // Missing intermediate parents are created as objects. + got, err = ApplyPatch(map[string]any{}, JSONPatch{{Op: JSONPatchOpAdd, Path: "/a/b/c", Value: 1}}) + if err != nil { + t.Fatalf("ApplyPatch: %v", err) + } + want := map[string]any{"a": map[string]any{"b": map[string]any{"c": float64(1)}}} + if !reflect.DeepEqual(got, want) { + t.Errorf("nested add = %#v, want %#v", got, want) + } + + // remove of a missing member is a no-op. + got, err = ApplyPatch(map[string]any{"a": 1}, JSONPatch{{Op: JSONPatchOpRemove, Path: "/missing"}}) + if err != nil { + t.Fatalf("ApplyPatch: %v", err) + } + if w := map[string]any{"a": float64(1)}; !reflect.DeepEqual(got, w) { + t.Errorf("remove missing = %#v, want %#v", got, w) + } + + // replace of a missing member is a no-op (not an error). + if _, err := ApplyPatch(map[string]any{"a": 1}, JSONPatch{{Op: JSONPatchOpReplace, Path: "/missing", Value: 2}}); err != nil { + t.Errorf("replace missing should be a no-op, got error: %v", err) + } +} + +func TestApplyPatch_DoesNotMutateInput(t *testing.T) { + in := map[string]any{"a": float64(1)} + if _, err := ApplyPatch(in, JSONPatch{{Op: JSONPatchOpReplace, Path: "/a", Value: 2}}); err != nil { + t.Fatalf("ApplyPatch: %v", err) + } + if in["a"] != float64(1) { + t.Errorf("input was mutated: %#v", in) + } +} + +func TestApplyPatch_TestOp(t *testing.T) { + doc := map[string]any{"a": 1} + if _, err := ApplyPatch(doc, JSONPatch{{Op: JSONPatchOpTest, Path: "/a", Value: 1}}); err != nil { + t.Errorf("matching test should pass, got: %v", err) + } + if _, err := ApplyPatch(doc, JSONPatch{{Op: JSONPatchOpTest, Path: "/a", Value: 2}}); err == nil { + t.Error("mismatching test should error") + } +} + +func TestApplyPatch_MoveCopy(t *testing.T) { + got, err := ApplyPatch(map[string]any{"a": 1}, JSONPatch{{Op: JSONPatchOpMove, From: "/a", Path: "/b"}}) + if err != nil { + t.Fatalf("move: %v", err) + } + if want := map[string]any{"b": float64(1)}; !reflect.DeepEqual(got, want) { + t.Errorf("move = %#v, want %#v", got, want) + } + + got, err = ApplyPatch(map[string]any{"a": 1}, JSONPatch{{Op: JSONPatchOpCopy, From: "/a", Path: "/b"}}) + if err != nil { + t.Fatalf("copy: %v", err) + } + if want := map[string]any{"a": float64(1), "b": float64(1)}; !reflect.DeepEqual(got, want) { + t.Errorf("copy = %#v, want %#v", got, want) + } +} + +func TestApplyPatch_ArrayInsertAndAppend(t *testing.T) { + // "-" appends. + got, err := ApplyPatch([]any{"a"}, JSONPatch{{Op: JSONPatchOpAdd, Path: "/-", Value: "b"}}) + if err != nil { + t.Fatalf("append: %v", err) + } + if want := []any{"a", "b"}; !reflect.DeepEqual(got, want) { + t.Errorf("append = %#v, want %#v", got, want) + } + + // numeric index inserts. + got, err = ApplyPatch([]any{"a", "c"}, JSONPatch{{Op: JSONPatchOpAdd, Path: "/1", Value: "b"}}) + if err != nil { + t.Fatalf("insert: %v", err) + } + if want := []any{"a", "b", "c"}; !reflect.DeepEqual(got, want) { + t.Errorf("insert = %#v, want %#v", got, want) + } +} + +func TestParsePointer(t *testing.T) { + tests := []struct { + in string + want []string + wantErr bool + }{ + {"", nil, false}, + {"/a", []string{"a"}, false}, + {"/a/b", []string{"a", "b"}, false}, + {"/a~1b", []string{"a/b"}, false}, + {"/m~0n", []string{"m~n"}, false}, + {"/~01", []string{"~1"}, false}, // ~01 decodes to ~1, not / + {"bad", nil, true}, + } + for _, tc := range tests { + got, err := parsePointer(tc.in) + if (err != nil) != tc.wantErr { + t.Errorf("parsePointer(%q) err = %v, wantErr %v", tc.in, err, tc.wantErr) + continue + } + if !tc.wantErr && !reflect.DeepEqual(got, tc.want) { + t.Errorf("parsePointer(%q) = %#v, want %#v", tc.in, got, tc.want) + } + } +} + +// patchString renders a patch compactly for test failure messages. +func patchString(p JSONPatch) string { + if len(p) == 0 { + return "[]" + } + out := "[" + for i, op := range p { + if i > 0 { + out += ", " + } + out += fmt.Sprintf("%s %s", op.Op, op.Path) + if op.From != "" { + out += " from=" + op.From + } + if op.Value != nil { + out += fmt.Sprintf(" value=%v", op.Value) + } + } + return out + "]" +} diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index bc8d3080d6..b2cd73073d 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -317,6 +317,12 @@ type Session[State any] struct { state SessionState[State] store SessionStore[State] version uint64 // incremented on every mutation; used to skip redundant snapshots + + // onCustomChange, when set by the agent runtime, is invoked after every + // UpdateCustom mutation (outside the lock) so the runtime can emit a + // customPatch chunk describing the delta. Nil for a standalone Session, + // in which case UpdateCustom is silent. + onCustomChange func() } // SessionID returns the ID of the session this conversation belongs to. @@ -392,15 +398,37 @@ func (s *Session[State]) Custom() State { return s.state.Custom } +// customJSON returns a deep, JSON-normalized copy (a map[string]any / []any / +// ... tree) of just the custom state, taken under the lock so it is safe to +// use after the lock is released. Unlike [Session.State] it does not copy the +// messages or artifacts, so the streaming patcher can diff custom on the hot +// path without re-serializing the whole conversation on every mutation. +func (s *Session[State]) customJSON() any { + s.mu.RLock() + defer s.mu.RUnlock() + return normalizeJSON(s.state.Custom) +} + // UpdateCustom atomically reads the current custom state, applies the given // function, and writes the result back. fn runs while the session's // internal lock is held: it must not call other Session methods or send // on a [Responder], or it will deadlock. +// +// When the session is driven by an agent invocation, the mutation is streamed +// to the client as an [AgentStreamChunk.CustomPatch] describing the delta (the +// runtime computes and emits it after fn returns). Agents therefore just mutate +// state; they never hand-craft patches. func (s *Session[State]) UpdateCustom(fn func(State) State) { s.mu.Lock() - defer s.mu.Unlock() s.state.Custom = fn(s.state.Custom) s.version++ + s.mu.Unlock() + // Emit the customPatch delta after releasing the lock: the hook reads + // session state (and may send on the wire), neither of which is safe to + // do while holding s.mu. + if s.onCustomChange != nil { + s.onCustomChange() + } } // Artifacts returns the current artifacts. The returned slice is a fresh diff --git a/go/core/schemas.config b/go/core/schemas.config index 8d14b3a386..3bde599227 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1437,12 +1437,72 @@ RuntimeError.details doc Details is optional structured information describing the failure. . +# ---------------------------------------------------------------------------- +# JsonPatchOp / JsonPatchOperation / JsonPatch +# +# The schema (Zod / JSON Schema) names keep the TS-conventional "Json" spelling +# shared across language runtimes; the `name` directive binds them to idiomatic +# Go "JSON" type names, just as the wire field "sessionId" becomes Go SessionID. +# ---------------------------------------------------------------------------- + +JsonPatchOp pkg ai/exp +JsonPatchOp name JSONPatchOp + +JsonPatchOp doc +JSONPatchOp is the kind of a JSON Patch operation (RFC 6902): one of +[JSONPatchOpAdd], [JSONPatchOpRemove], [JSONPatchOpReplace], [JSONPatchOpMove], +[JSONPatchOpCopy], or [JSONPatchOpTest]. +. + +JsonPatchOperation pkg ai/exp +JsonPatchOperation name JSONPatchOperation + +JsonPatchOperation doc +JSONPatchOperation is a single RFC 6902 (JSON Patch) operation. A [JSONPatch] +applies an ordered list of these to transform one JSON document into another. +. + +JsonPatchOperation.op noomitempty +JsonPatchOperation.op doc +Op is the operation to perform. +. + +JsonPatchOperation.path noomitempty +JsonPatchOperation.path doc +Path is a JSON Pointer (RFC 6901) to the target location, e.g. "/agentStatus". +The empty pointer "" refers to the whole document. It must always be present on +the wire (a whole-document replace carries path ""), so it is not omitted when +empty. +. + +JsonPatchOperation.from doc +From is a JSON Pointer to the source location; required for "move" and "copy". +. + +JsonPatchOperation.value type any +JsonPatchOperation.value noomitempty +JsonPatchOperation.value doc +Value is the operand for "add", "replace", and "test". It is not omitted when +null so an explicit null operand survives the wire (omitempty cannot tell a +null operand from an absent one, and dropping it makes a peer applier set the +member to undefined or remove it instead of null); for "remove", "move", and +"copy" it is null and ignored. +. + +JsonPatch pkg ai/exp +JsonPatch name JSONPatch + +JsonPatch doc +JSONPatch is an RFC 6902 JSON Patch: an ordered list of operations applied in +sequence. Use [Diff] to compute the patch between two values and [ApplyPatch] +to apply one to a document. +. + # ---------------------------------------------------------------------------- # AgentStreamChunk # ---------------------------------------------------------------------------- AgentStreamChunk pkg ai/exp -AgentStreamChunk typeparams [Stream any] AgentStreamChunk doc AgentStreamChunk represents a single item in the agent's output stream. @@ -1454,10 +1514,20 @@ AgentStreamChunk.modelChunk doc ModelChunk contains generation tokens from the model. . -AgentStreamChunk.status type Stream -AgentStreamChunk.status doc -Status contains user-defined structured status information. -The Stream type parameter defines the shape of this data. +AgentStreamChunk.customPatch type JSONPatch +AgentStreamChunk.customPatch doc +CustomPatch is an RFC 6902 JSON Patch describing a delta applied to the +session's custom state. The runtime emits it automatically whenever the +agent mutates custom state (e.g. via [Session.UpdateCustom]); agents do not +hand-craft patches. Pointers are rooted at the custom document (e.g. +"/agentStatus"), with no "/custom" prefix. The first patch of every turn is a +whole-document replace at the root pointer ("") that re-bases clients which +may not share the server's baseline; subsequent patches are incremental diffs +against the last sent value. The diff is computed on the client-facing custom +state (after any [WithStateTransform]), so streamed deltas honor redaction and +stay consistent with the full state in turn-end snapshots and final output. +Apply it with [ApplyPatch] to keep a local copy of custom live as the turn +streams. . AgentStreamChunk.artifact doc diff --git a/go/genkit/exp/routes.go b/go/genkit/exp/routes.go index 36c268e173..3878bc378b 100644 --- a/go/genkit/exp/routes.go +++ b/go/genkit/exp/routes.go @@ -99,7 +99,7 @@ func AllAgentRoutes(g *genkit.Genkit) []Route { // the same as the reflection API. Companion routes are omitted for // capabilities the agent lacks; a client-managed agent contributes only // its turn route. -func AgentRoutes[Stream, State any](a *aix.Agent[Stream, State]) []Route { +func AgentRoutes[State any](a *aix.Agent[State]) []Route { return buildAgentRoutes(a.Name(), a, a.GetSnapshotAction(), a.AbortSnapshotAction()) } diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index df395a4bb4..f626148a12 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -470,7 +470,7 @@ func DefineAgent[State any]( name string, source aix.AgentSource, opts ...aix.AgentOption[State], -) *aix.Agent[any, State] { +) *aix.Agent[State] { return aix.DefineAgent(g.reg, name, source, opts...) } @@ -500,14 +500,14 @@ func DefineAgent[State any]( // - [aix.WithSnapshotOn]: Create snapshots only for specific [aix.SnapshotEvent] types // - [aix.WithStateTransform]: Rewrite session state on its way out to the client // -// Type parameters: -// - Stream: Type for custom status updates sent via [aix.Responder.SendStatus] -// - State: Type for user-defined state persisted in snapshots +// The State type parameter is the shape of the conversation's custom state +// ([aix.SessionState.Custom]); mutating it via [aix.Session.UpdateCustom] +// streams an [aix.AgentStreamChunk.CustomPatch] delta to the client. // // Example: // // chatAgent := genkit.DefineCustomAgent(g, "chat", -// func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { +// func(ctx context.Context, resp aix.Responder, sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { // var lastMessage *ai.Message // err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) (*aix.TurnResult, error) { // var reason aix.AgentFinishReason @@ -536,12 +536,12 @@ func DefineAgent[State any]( // return &aix.AgentResult{Message: lastMessage}, nil // }, // ) -func DefineCustomAgent[Stream, State any]( +func DefineCustomAgent[State any]( g *Genkit, name string, - fn aix.AgentFunc[Stream, State], + fn aix.AgentFunc[State], opts ...aix.AgentOption[State], -) *aix.Agent[Stream, State] { +) *aix.Agent[State] { return aix.DefineCustomAgent(g.reg, name, fn, opts...) } diff --git a/go/samples/basic-agents/cli.go b/go/samples/basic-agents/cli.go index c8a7defa6f..6f1648b6cc 100644 --- a/go/samples/basic-agents/cli.go +++ b/go/samples/basic-agents/cli.go @@ -56,7 +56,7 @@ var errQuit = errors.New("quit") // between two screens forever: the agent list and a per-agent chat. // Returning from a chat brings the user back to the agent list. /quit // (anywhere) and Ctrl-C both unwind back here and exit cleanly. -func runCLI(ctx context.Context, agents []*aix.Agent[any, any]) error { +func runCLI(ctx context.Context, agents []*aix.Agent[any]) error { fmt.Println("Genkit Basic Agents") fmt.Println("===================") fmt.Println() @@ -101,7 +101,7 @@ func runCLI(ctx context.Context, agents []*aix.Agent[any, any]) error { // pickAgent renders the agent list and reads the user's choice. The // list is re-rendered between selections so the user can see updated // pending/terminal status after returning from a chat. -func pickAgent(ctx context.Context, inputCh <-chan string, agents []*aix.Agent[any, any], lastSession map[string]string) (int, bool) { +func pickAgent(ctx context.Context, inputCh <-chan string, agents []*aix.Agent[any], lastSession map[string]string) (int, bool) { for { fmt.Println() fmt.Println("Agents:") @@ -141,7 +141,7 @@ func pickAgent(ctx context.Context, inputCh <-chan string, agents []*aix.Agent[a // so the rest of the flow is uniform: ok=false means the user backed // out, otherwise hand the chosen snapshot (or nil for fresh) to // runChat. -func openAgent(ctx context.Context, inputCh <-chan string, a *aix.Agent[any, any], lastSessionID string) (string, error) { +func openAgent(ctx context.Context, inputCh <-chan string, a *aix.Agent[any], lastSessionID string) (string, error) { // Resolve where the last conversation left off. With no tracked // session (a first visit this run) there is nothing to resume; // otherwise the store resolves the session's latest snapshot, which @@ -189,7 +189,7 @@ func openAgent(ctx context.Context, inputCh <-chan string, a *aix.Agent[any, any // snapshot directly so the caller can skip the resume / new prompt: // the user already committed to the choice by waiting, and re-asking // would be redundant. -func handlePending(ctx context.Context, inputCh <-chan string, a *aix.Agent[any, any], pending *aix.SessionSnapshot[any]) (*aix.SessionSnapshot[any], bool) { +func handlePending(ctx context.Context, inputCh <-chan string, a *aix.Agent[any], pending *aix.SessionSnapshot[any]) (*aix.SessionSnapshot[any], bool) { for { fmt.Printf("\nThe last %s session is still running in the background (%s).\n", a.Name(), shortID(pending.SnapshotID)) fmt.Println(" 1. Wait for it to finalize") @@ -239,7 +239,7 @@ func handlePending(ctx context.Context, inputCh <-chan string, a *aix.Agent[any, // offers two paths so the demo stays focused: resume from the most // recent terminal snapshot (returns the snapshot pointer), or start // fresh (returns nil). -func pickSession(ctx context.Context, inputCh <-chan string, a *aix.Agent[any, any], latest *aix.SessionSnapshot[any]) (*aix.SessionSnapshot[any], bool) { +func pickSession(ctx context.Context, inputCh <-chan string, a *aix.Agent[any], latest *aix.SessionSnapshot[any]) (*aix.SessionSnapshot[any], bool) { if latest == nil || latest.Status != aix.SnapshotStatusSucceeded { fmt.Printf("\nStarting a new conversation with %s.\n", a.Name()) return nil, true @@ -274,7 +274,7 @@ func pickSession(ctx context.Context, inputCh <-chan string, a *aix.Agent[any, a // snapshot the user was just shown. Validating up front keeps the chat // from opening on a connection whose invocation already failed, which // would surface the error only after the user types a message. -func resumeOption(ctx context.Context, a *aix.Agent[any, any], resume *aix.SessionSnapshot[any]) aix.InvocationOption[any] { +func resumeOption(ctx context.Context, a *aix.Agent[any], resume *aix.SessionSnapshot[any]) aix.InvocationOption[any] { if resume.SessionID != "" { tip, err := a.Store().GetLatestSnapshot(ctx, resume.SessionID) if err == nil && tip != nil && tip.Status != aix.SnapshotStatusPending { @@ -293,7 +293,7 @@ func resumeOption(ctx context.Context, a *aix.Agent[any, any], resume *aix.Sessi // detaches the connection, leaving a pending snapshot for the user to // observe. It returns the session ID the chat ran under (falling back to // prevSessionID) so the caller can offer to resume it later. -func runChat(ctx context.Context, inputCh <-chan string, a *aix.Agent[any, any], resume *aix.SessionSnapshot[any], prevSessionID string) (string, error) { +func runChat(ctx context.Context, inputCh <-chan string, a *aix.Agent[any], resume *aix.SessionSnapshot[any], prevSessionID string) (string, error) { fmt.Printf("\n=== Chatting with %s ===\n", a.Name()) if resume != nil { fmt.Printf("Resumed from %s\n", shortID(resume.SnapshotID)) @@ -437,7 +437,7 @@ repl: // list: the tip of the conversation last run with it this session. Empty // when none has run yet, or the session has no resumable snapshot, so a // fresh list shows no clutter. -func summarizeLatest(ctx context.Context, a *aix.Agent[any, any], sessionID string) string { +func summarizeLatest(ctx context.Context, a *aix.Agent[any], sessionID string) string { if sessionID == "" { return "" } @@ -462,7 +462,7 @@ func summarizeLatest(ctx context.Context, a *aix.Agent[any, any], sessionID stri // The status subscription is the optional SnapshotAborter half of the // store contract. A store without it cannot stream background progress, // so we fall back to reading the snapshot once and returning it as-is. -func waitForFinalize(ctx context.Context, a *aix.Agent[any, any], snapshotID string) (*aix.SessionSnapshot[any], error) { +func waitForFinalize(ctx context.Context, a *aix.Agent[any], snapshotID string) (*aix.SessionSnapshot[any], error) { store := a.Store() aborter, ok := store.(aix.SnapshotAborter) if !ok { diff --git a/go/samples/basic-agents/main.go b/go/samples/basic-agents/main.go index e508c874db..5e9c7820a9 100644 --- a/go/samples/basic-agents/main.go +++ b/go/samples/basic-agents/main.go @@ -85,7 +85,7 @@ func main() { // and a.Store() for snapshot reads. Nothing the CLI does is tied to a // concrete store type, so swapping in a different SessionStore would // not touch a line of it. - agents := []*aix.Agent[any, any]{ + agents := []*aix.Agent[any]{ defineInlineAgent(g), definePromptAgent(g), defineCustomAgent(g), @@ -103,7 +103,7 @@ func main() { // prompt, appends the conversation history, calls the model, and updates // session state. This is the shortest path from "I want a chat agent" to // a working one. -func defineInlineAgent(g *genkit.Genkit) *aix.Agent[any, any] { +func defineInlineAgent(g *genkit.Genkit) *aix.Agent[any] { const name = "pirate" return genkit.DefineAgent(g, name, aix.FromInline( @@ -129,7 +129,7 @@ func defineInlineAgent(g *genkit.Genkit) *aix.Agent[any, any] { // FromPrompt's argument is the default input passed to the prompt's // Render on every turn; the inline-prompt variant has no per-turn input // of its own. -func definePromptAgent(g *genkit.Genkit) *aix.Agent[any, any] { +func definePromptAgent(g *genkit.Genkit) *aix.Agent[any] { const name = "chef" return genkit.DefineAgent(g, name, aix.FromPrompt(ChatPromptInput{Personality: "a Michelin-starred chef who loves explaining technique"}), @@ -148,10 +148,10 @@ func definePromptAgent(g *genkit.Genkit) *aix.Agent[any, any] { // // Even with full control over the loop, the framework still owns session // state, snapshot writes, and the detach lifecycle. -func defineCustomAgent(g *genkit.Genkit) *aix.Agent[any, any] { +func defineCustomAgent(g *genkit.Genkit) *aix.Agent[any] { const name = "coder" return genkit.DefineCustomAgent(g, name, - func(ctx context.Context, resp aix.Responder[any], sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { + func(ctx context.Context, resp aix.Responder, sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) (*aix.TurnResult, error) { for chunk, err := range genkit.GenerateStream(ctx, g, ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ From 417f6d0f49afee4a0c15abff9c13b0860b8d4cf7 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 16 Jun 2026 16:38:05 -0700 Subject: [PATCH 107/141] feat(go/exp): rework agent session and snapshot flow Reworks how server-managed agents resolve sessions and snapshots, and what they advertise about their custom state. All in the experimental exp package. - AgentInit with only a sessionId no longer fails with NOT_FOUND when the session has no snapshots. The caller is starting a brand-new conversation under an ID of its own choosing; the runtime adopts that ID for the whole session lifecycle (every snapshot carries it) instead of minting one. - GetSnapshotRequest gains a sessionId field, not mutually exclusive with snapshotId. sessionId alone fetches the session's latest snapshot; with both, the fetched snapshot must belong to that session. GetLatestSnapshot now returns the literal latest row whatever its status, so a reconnecting client can observe a pending, failed, or aborted tip. Resume-by-sessionId consequently rejects a dead-end tip rather than skipping it; to continue past one, pass an explicit snapshotId. - AgentMetadata gains stateSchema: the JSON schema of the agent's custom state, inferred from the State type like action input/output schemas. - SnapshotStatus "succeeded" is renamed to "completed" (wire value and the SnapshotStatusCompleted constant). This is a wire change, allowed here because exp APIs may change in any minor release. The Zod schema source (agent.ts) and the generated gen.go are regenerated; schemas.config carries the matching doc updates. --- genkit-tools/common/src/types/agent.ts | 67 +++-- genkit-tools/genkit-schema.json | 12 +- go/ai/exp/agent.go | 68 ++++-- go/ai/exp/agent_test.go | 323 ++++++++++++++++++++----- go/ai/exp/gen.go | 65 +++-- go/ai/exp/localstore/file.go | 18 +- go/ai/exp/localstore/file_test.go | 32 +-- go/ai/exp/localstore/inmemory.go | 15 +- go/ai/exp/localstore/inmemory_test.go | 14 +- go/ai/exp/localstore/store_test.go | 45 ++-- go/ai/exp/session.go | 84 ++++--- go/ai/exp/teststore_test.go | 5 +- go/core/schemas.config | 69 ++++-- go/genkit/servers_test.go | 14 +- go/samples/basic-agents/cli.go | 8 +- 15 files changed, 586 insertions(+), 253 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 0289cb3d6e..a7107e672b 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -49,7 +49,7 @@ export type Artifact = z.infer; * - `recovery`: snapshot was written retroactively by the failure path to * preserve the last-good state (everything through the last successful * turn) when a selective snapshot callback had skipped persisting it. - * It is a normal `succeeded` row carrying the last good turn's + * It is a normal `completed` row carrying the last good turn's * `finishReason`, resumable like any other; the callback is bypassed. */ export const SnapshotEventSchema = z.enum([ @@ -67,7 +67,7 @@ export type SnapshotEvent = z.infer; * The snapshot's state is empty until the background work finishes, at * which point it is rewritten with the cumulative final state and a * terminal status. - * - `succeeded`: the snapshot captures a settled state. + * - `completed`: the snapshot captures a settled state. * - `aborted`: the snapshot's invocation was aborted via the * `abortSnapshot` companion action while detached. * - `failed`: the invocation terminated with an error. The snapshot's `error` @@ -75,7 +75,7 @@ export type SnapshotEvent = z.infer; */ export const SnapshotStatusSchema = z.enum([ 'pending', - 'succeeded', + 'completed', 'aborted', 'failed', ]); @@ -170,18 +170,19 @@ export type AgentInput = z.infer; */ export const AgentInitSchema = z.object({ /** - * Identifies the session (conversation) to resume. Only valid when the - * agent is server-managed (a session store is configured); mutually - * exclusive with state (a client-managed conversation carries its - * identity inside `state.sessionId`). Alone it resumes the session - * from its latest snapshot: the most recently updated one that is not - * a failed/aborted dead end. A pending latest snapshot (a detached - * invocation still running) rejects the resume rather than racing the - * background work; if the session's history was forked by re-resuming - * an earlier snapshot, the most recently updated branch wins, and - * snapshotId can pick a branch explicitly. Combined with snapshotId, - * it asserts which session the snapshot belongs to, and a mismatch is - * rejected. + * Identifies the session (conversation) to resume or start. Only valid + * when the agent is server-managed (a session store is configured); + * mutually exclusive with state (a client-managed conversation carries + * its identity inside `state.sessionId`). Alone it resumes the session + * from its latest snapshot: the most recently updated row, whatever its + * status. If that row is a failed, aborted, or still-pending dead end + * the resume is rejected (pass snapshotId to continue from a specific + * earlier point); if the session's history was forked by re-resuming an + * earlier snapshot, the most recently updated branch wins. If the + * session has no snapshots yet, a brand-new conversation is started + * under this caller-chosen ID, and every snapshot it persists carries + * it. Combined with snapshotId, it asserts which session the snapshot + * belongs to, and a mismatch is rejected. */ sessionId: z.string().optional(), /** @@ -225,8 +226,9 @@ export const AgentOutputSchema = z.object({ /** * ID of the session this invocation belongs to, assigned by the * framework when the invocation starts. With server-managed state, a - * fresh invocation mints a new ID, resumed invocations inherit the - * chain's, and resuming a snapshot from before session IDs existed + * fresh invocation adopts the caller-supplied session ID (see + * AgentInit.sessionId) or mints a new one, resumed invocations inherit + * the chain's, and resuming a snapshot from before session IDs existed * mints a fresh one. With client-managed state it echoes the ID * carried inside the state object (`state.sessionId`), minting one on * the conversation's first invocation; only a session with persisted @@ -391,7 +393,7 @@ export const SessionSnapshotSchema = z.object({ updatedAt: z.string().optional(), /** What triggered this snapshot. */ event: SnapshotEventSchema, - /** Lifecycle state of this snapshot. Empty is treated as `succeeded`. */ + /** Lifecycle state of this snapshot. Empty is treated as `completed`. */ status: SnapshotStatusSchema.optional(), /** * Semantic reason the turn or invocation captured here ended (e.g. @@ -417,10 +419,27 @@ export type SessionSnapshot = z.infer; * `agent-snapshot`) when the agent has a session store configured. The * action returns the stored `SessionSnapshot`, with any configured state * transform applied to its state. + * + * At least one of `snapshotId` or `sessionId` must be set; they are not + * mutually exclusive. `snapshotId` fetches a specific snapshot; + * `sessionId` alone fetches the session's latest snapshot (via the + * store's `GetLatestSnapshot`, whatever its status). When both are set + * the fetched snapshot must belong to that session, or the request is + * rejected. */ export const GetSnapshotRequestSchema = z.object({ - /** Identifies the snapshot to fetch. */ - snapshotId: z.string(), + /** + * Identifies the snapshot to fetch. Optional when `sessionId` is given; + * when both are present the fetched snapshot must belong to that session. + */ + snapshotId: z.string().optional(), + /** + * Identifies the session whose latest snapshot to fetch. Optional when + * `snapshotId` is given. The latest snapshot is the session's most + * recently updated row regardless of status (pending, failed, or + * aborted included). + */ + sessionId: z.string().optional(), }); export type GetSnapshotRequest = z.infer; @@ -474,5 +493,13 @@ export const AgentMetadataSchema = z.object({ * configured store implements the abort lifecycle. */ abortable: z.boolean(), + /** + * JSON schema for the agent's custom session state (the `custom` field + * of `SessionState`), inferred from the agent's state type. Lets the + * Dev UI and other reflective callers render or validate state without + * the agent describing it separately. Omitted when the state type + * carries no schema to infer (e.g. an unstructured `any` state). + */ + stateSchema: z.record(z.any()).optional(), }); export type AgentMetadata = z.infer; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index c1d689caf5..f13ebcf0b9 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -95,6 +95,10 @@ }, "abortable": { "type": "boolean" + }, + "stateSchema": { + "type": "object", + "additionalProperties": {} } }, "required": [ @@ -203,11 +207,11 @@ "properties": { "snapshotId": { "type": "string" + }, + "sessionId": { + "type": "string" } }, - "required": [ - "snapshotId" - ], "additionalProperties": false }, "JsonPatchOp": { @@ -323,7 +327,7 @@ "type": "string", "enum": [ "pending", - "succeeded", + "completed", "aborted", "failed" ] diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 5b9e7c6cdb..b8f59bfd03 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -27,6 +27,7 @@ import ( "fmt" "iter" "maps" + "reflect" "runtime/debug" "sync" "sync/atomic" @@ -327,7 +328,7 @@ func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event Snapshot return s.persistSnapshotLocked(ctx, event, finishReason, ¤tState, currentVersion) } -// persistSnapshotLocked writes a succeeded snapshot row capturing state (at +// persistSnapshotLocked writes a completed snapshot row capturing state (at // the given session version), chained to the newest persisted snapshot, and // advances the lastSnapshot bookkeeping. Both the routine cadence // (maybeSnapshot) and the failure path (recoverySnapshotID) funnel through @@ -344,7 +345,7 @@ func (s *SessionRunner[State]) persistSnapshotLocked(ctx context.Context, event SessionID: sessionID, ParentID: parentID, Event: event, - Status: SnapshotStatusSucceeded, + Status: SnapshotStatusCompleted, FinishReason: finishReason, State: state, }, nil @@ -743,7 +744,8 @@ func DefineCustomAgent[State any]( // agentMetadataFor derives the [AgentMetadata] value attached to the // agent's action descriptor under the "agent" key. [AgentMetadata] // itself is generated from agent.ts; this constructor is hand-written -// because it inspects the configured store's optional capabilities. +// because it inspects the configured store's optional capabilities and +// infers the custom-state schema from the State type parameter. func agentMetadataFor[State any](store SessionStore[State]) AgentMetadata { mgmt := AgentStateManagementClient abortable := false @@ -754,9 +756,23 @@ func agentMetadataFor[State any](store SessionStore[State]) AgentMetadata { return AgentMetadata{ StateManagement: mgmt, Abortable: abortable, + StateSchema: stateSchemaFor[State](), } } +// stateSchemaFor infers the JSON schema for an agent's custom state type, +// the same way core derives an action's input/output schemas. It returns +// nil for an interface State (e.g. Agent[any]), whose zero value carries +// no type information to infer from, so [AgentMetadata.StateSchema] is +// simply omitted rather than advertising an empty object. +func stateSchemaFor[State any]() map[string]any { + var zero State + if reflect.ValueOf(&zero).Elem().Kind() == reflect.Interface { + return nil + } + return core.InferSchemaMap(zero) +} + // --- agentRuntime --- // agentRuntime owns the per-invocation wiring of an agent: @@ -803,14 +819,23 @@ func newAgentRuntime[State any]( // state goes: every persisted snapshot's state, and the state // returned to (and resent by) client-managed callers. if cfg.store != nil { - // Server-managed: the store row is canonical. Inherit the resumed - // chain's ID, overriding whatever the loaded state blob claims (a - // third-party writer could have let them drift), or mint one for a - // fresh conversation (including one resumed from a snapshot that - // predates session IDs). - if parent != nil && parent.SessionID != "" { + // Server-managed: the store row is canonical, overriding whatever the + // loaded state blob claims (a third-party writer could have let them + // drift). + switch { + case parent != nil && parent.SessionID != "": + // Resumed an existing chain: inherit its ID. session.state.SessionID = parent.SessionID - } else { + case parent == nil && in != nil && in.SessionID != "": + // No snapshot resolved for the caller-supplied session ID: the + // client is starting a brand-new conversation under an ID of its + // own choosing. Honor it for the whole session lifecycle so every + // snapshot it persists carries it (rather than minting a server ID + // and stranding the client's ID). + session.state.SessionID = in.SessionID + default: + // Fresh conversation, or one resumed from a snapshot that predates + // session IDs: mint one. session.state.SessionID = uuid.New().String() } } else if session.state.SessionID == "" { @@ -1200,7 +1225,7 @@ func (rt *agentRuntime[State]) finalizePendingSnapshot( // Captured outside the SaveSnapshot callback (which must stay pure): the // finalizer runs after fn returned, so this is stable. The abort/error // branches below own their reasons and ignore this clean-success default. - succeededReason := rt.sess.invocationReason(result) + completedReason := rt.sess.invocationReason(result) _, err := rt.cfg.store.SaveSnapshot(ctx, pending.SnapshotID, func(existing *SessionSnapshot[State]) (*SessionSnapshot[State], error) { @@ -1218,11 +1243,11 @@ func (rt *agentRuntime[State]) finalizePendingSnapshot( return &annotated, nil } - status := SnapshotStatusSucceeded + status := SnapshotStatusCompleted // The persisted finish reason records how the background work // actually ended, distinct from the detached reason the client // already saw on AgentOutput. - finishReason := succeededReason + finishReason := completedReason var snapErr *core.GenkitError switch { case abortedByUser: @@ -1323,7 +1348,12 @@ func loadSession[State any]( return nil, nil, core.NewError(core.INTERNAL, "failed to resolve latest snapshot for session %q: %v", init.SessionID, err) } if snap == nil { - return nil, nil, core.NewError(core.NOT_FOUND, "no resumable snapshot found for session %q", init.SessionID) + // No snapshot exists for this session ID yet: the caller is + // starting a brand-new conversation under an ID of its own + // choosing, not resuming one. Return a fresh session with no + // parent; newAgentRuntime stamps the chosen ID so every snapshot + // the new session persists carries it. + return s, nil, nil } if snap.SessionID != init.SessionID { return nil, nil, core.NewError(core.INTERNAL, @@ -1335,10 +1365,12 @@ func loadSession[State any]( } // resumeSessionFrom validates that snap is in a resumable status and loads -// its state into s. Shared by the snapshot-ID and session-ID init paths; -// the session-ID path can only hit the pending case (a conforming store's -// GetLatestSnapshot never resolves to failed/aborted dead ends), but the -// full switch stays as a defense against non-conforming stores. +// its state into s. Shared by the snapshot-ID and session-ID init paths: +// both reject a failed, aborted, or pending snapshot, since none can be +// continued from. The session-ID path reaches them too, because +// GetLatestSnapshot returns the literal latest row whatever its status; a +// caller wanting to continue past a dead-end tip must name an earlier good +// snapshot explicitly via SnapshotID. func resumeSessionFrom[State any](s *Session[State], snap *SessionSnapshot[State]) (*Session[State], *SessionSnapshot[State], error) { switch snap.Status { case SnapshotStatusFailed: diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 74381c35e3..393bd4b165 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -995,8 +995,8 @@ func TestAgent_FailedTurn_RecoverySnapshotBypassesCallback(t *testing.T) { if err != nil || snap == nil { t.Fatalf("GetSnapshot(%q): %v, %v", out.SnapshotID, snap, err) } - if snap.Status != SnapshotStatusSucceeded { - t.Errorf("expected recovery snapshot status %q, got %q", SnapshotStatusSucceeded, snap.Status) + if snap.Status != SnapshotStatusCompleted { + t.Errorf("expected recovery snapshot status %q, got %q", SnapshotStatusCompleted, snap.Status) } if snap.Event != SnapshotEventRecovery { t.Errorf("expected recovery snapshot event %q, got %q", SnapshotEventRecovery, snap.Event) @@ -1077,7 +1077,7 @@ func TestAgent_FailedFirstTurn_AfterResume_ReturnsParentSnapshotID(t *testing.T) func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ Event: SnapshotEventInvocationEnd, - Status: SnapshotStatusSucceeded, + Status: SnapshotStatusCompleted, State: &SessionState[testState]{ Messages: []*ai.Message{ ai.NewUserTextMessage("one"), @@ -2626,7 +2626,7 @@ func TestAgent_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { close(release) final := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusSucceeded + return s.Status == SnapshotStatusCompleted }) if final.State.Custom.Counter != 2 { t.Errorf("final counter = %d, want 2 (A + D both processed)", final.State.Custom.Counter) @@ -2717,7 +2717,7 @@ func TestAgent_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { // Release remaining turns and let finalize run. close(release) waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusSucceeded + return s.Status == SnapshotStatusCompleted }) } @@ -2758,7 +2758,7 @@ func TestAgent_Detach_RequiresStore(t *testing.T) { func TestAgent_Detach_PendingThenComplete(t *testing.T) { // Client detaches mid-flow; flow finishes naturally; pending snapshot - // flips to status=succeeded with the full session state. + // flips to status=completed with the full session state. reg := newTestRegistry(t) store := newTestInMemStore[testState]() @@ -2839,7 +2839,7 @@ func TestAgent_Detach_PendingThenComplete(t *testing.T) { close(release) finalSnap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusSucceeded + return s.Status == SnapshotStatusCompleted }) if finalSnap.State.Custom.Counter != 42 { t.Errorf("expected counter=42 in final snapshot, got %d", finalSnap.State.Custom.Counter) @@ -2914,7 +2914,7 @@ func TestAgent_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { close(release) final := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusSucceeded + return s.Status == SnapshotStatusCompleted }) names := make(map[string]bool, len(final.State.Artifacts)) @@ -3119,8 +3119,8 @@ func TestAgent_Detach_NormalCompletionStillEmitsTurnEnd(t *testing.T) { if err != nil { t.Fatalf("GetSnapshot: %v", err) } - if snap.Status != SnapshotStatusSucceeded { - t.Errorf("turn-end snapshot status = %q, want succeeded", snap.Status) + if snap.Status != SnapshotStatusCompleted { + t.Errorf("turn-end snapshot status = %q, want completed", snap.Status) } if snap.Event != SnapshotEventTurnEnd { t.Errorf("turn-end snapshot event = %q, want %q", snap.Event, SnapshotEventTurnEnd) @@ -3272,8 +3272,8 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { if resp.SnapshotID != out.SnapshotID { t.Errorf("SnapshotID mismatch: got %q want %q", resp.SnapshotID, out.SnapshotID) } - if resp.Status != SnapshotStatusSucceeded { - t.Errorf("expected status=succeeded, got %q", resp.Status) + if resp.Status != SnapshotStatusCompleted { + t.Errorf("expected status=completed, got %q", resp.Status) } if resp.State == nil { t.Fatal("expected state in response") @@ -3349,6 +3349,95 @@ func TestAgent_GetSnapshotAction_ReturnsFinishReason(t *testing.T) { } } +// TestAgent_GetSnapshotAction_BySessionID verifies the getSnapshot companion +// action's session-ID modes: fetching the session's latest snapshot (whatever +// its status, the way a reconnecting client observes a session), composing +// with a snapshot ID as an integrity assertion, and the argument-validation +// edges. +func TestAgent_GetSnapshotAction_BySessionID(t *testing.T) { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "getBySessionFlow", WithSessionStore(store)) + + ctx := context.Background() + out1, err := af.RunText(ctx, "first") + if err != nil { + t.Fatalf("RunText: %v", err) + } + out2, err := af.RunText(ctx, "second", WithSessionID[testState](out1.SessionID)) + if err != nil { + t.Fatalf("RunText resume: %v", err) + } + + action := core.ResolveActionFor[*GetSnapshotRequest, *SessionSnapshot[testState], struct{}]( + reg, api.ActionTypeAgentSnapshot, "getBySessionFlow") + if action == nil { + t.Fatal("getSnapshot action not registered") + } + + // sessionId alone resolves the session's latest snapshot (the second turn). + resp, err := action.Run(ctx, &GetSnapshotRequest{SessionID: out1.SessionID}, nil) + if err != nil { + t.Fatalf("getSnapshot by sessionId: %v", err) + } + if resp.SnapshotID != out2.SnapshotID { + t.Errorf("latest snapshot = %q, want most recent %q", resp.SnapshotID, out2.SnapshotID) + } + + // sessionId + matching snapshotId returns that exact snapshot. + resp, err = action.Run(ctx, &GetSnapshotRequest{SessionID: out1.SessionID, SnapshotID: out1.SnapshotID}, nil) + if err != nil { + t.Fatalf("getSnapshot by snapshotId+sessionId: %v", err) + } + if resp.SnapshotID != out1.SnapshotID { + t.Errorf("snapshot = %q, want %q", resp.SnapshotID, out1.SnapshotID) + } + + // A snapshotId whose session does not match the asserted sessionId is rejected. + if _, err := action.Run(ctx, &GetSnapshotRequest{SessionID: "other-session", SnapshotID: out1.SnapshotID}, nil); err == nil { + t.Fatal("expected snapshot/session mismatch to be rejected") + } else if ge := core.AsGenkitError(err); ge.Status != core.INVALID_ARGUMENT { + t.Errorf("mismatch status = %q, want INVALID_ARGUMENT (err: %v)", ge.Status, err) + } + + // Neither field set is rejected. + if _, err := action.Run(ctx, &GetSnapshotRequest{}, nil); err == nil { + t.Fatal("expected empty request to be rejected") + } else if ge := core.AsGenkitError(err); ge.Status != core.INVALID_ARGUMENT { + t.Errorf("empty-request status = %q, want INVALID_ARGUMENT (err: %v)", ge.Status, err) + } + + // An unknown session resolves to no snapshot (NOT_FOUND). + if _, err := action.Run(ctx, &GetSnapshotRequest{SessionID: "no-such-session"}, nil); err == nil { + t.Fatal("expected NOT_FOUND for unknown session") + } else if ge := core.AsGenkitError(err); ge.Status != core.NOT_FOUND { + t.Errorf("unknown-session status = %q, want NOT_FOUND (err: %v)", ge.Status, err) + } + + // The session-ID lookup returns the latest row whatever its status, so a + // reconnecting client can observe a failed/pending tip (unlike resume, + // which rejects it). + failed, err := store.SaveSnapshot(ctx, "", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{ + SessionID: out1.SessionID, + ParentID: out2.SnapshotID, + Event: SnapshotEventDetach, + Status: SnapshotStatusFailed, + FinishReason: AgentFinishReasonFailed, + }, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot failed tip: %v", err) + } + resp, err = action.Run(ctx, &GetSnapshotRequest{SessionID: out1.SessionID}, nil) + if err != nil { + t.Fatalf("getSnapshot by sessionId (failed tip): %v", err) + } + if resp.SnapshotID != failed.SnapshotID || resp.Status != SnapshotStatusFailed { + t.Errorf("latest = %q/%q, want failed tip %q/failed", resp.SnapshotID, resp.Status, failed.SnapshotID) + } +} + func TestInMemorySessionStore_GetSnapshot_NotFound(t *testing.T) { store := newTestInMemStore[testState]() @@ -3457,12 +3546,6 @@ func TestLoadSession_AgentInitValidation(t *testing.T) { store: nil, wantErr: "client-managed state", }, - { - name: "sessionId with no matching snapshots", - init: &AgentInit[testState]{SessionID: "sess-unknown"}, - store: store, - wantErr: "no resumable snapshot", - }, { name: "sessionId mismatching the loaded snapshot", init: &AgentInit[testState]{SessionID: "sess-other", SnapshotID: saved.SnapshotID}, @@ -3508,6 +3591,16 @@ func TestLoadSession_AgentInitValidation(t *testing.T) { store: store, wantSnap: true, }, + { + // A session ID with no persisted snapshots is not an error: the + // caller is starting a brand-new conversation under an ID of its + // own choosing. loadSession returns a fresh session and no parent + // snapshot; the runtime stamps the chosen ID. + name: "sessionId with no matching snapshots starts fresh", + init: &AgentInit[testState]{SessionID: "sess-unknown"}, + store: store, + wantSnap: false, + }, } for _, tc := range okCases { @@ -3647,6 +3740,59 @@ func TestAgent_AgentMetadata(t *testing.T) { } } +func TestAgent_AgentMetadata_StateSchema(t *testing.T) { + // metadata["agent"].stateSchema advertises a JSON schema for the agent's + // custom state type, inferred the same way action input/output schemas + // are: a struct state yields an object schema with its fields; an + // unstructured `any` state yields none (omitted). + readMeta := func(t *testing.T, reg api.Registry, flowName string) AgentMetadata { + t.Helper() + act := reg.LookupAction(api.NewKey(api.ActionTypeAgent, "", flowName)) + if act == nil { + t.Fatal("agent action not registered") + } + meta, ok := act.Desc().Metadata["agent"].(AgentMetadata) + if !ok { + t.Fatalf("metadata[\"agent\"] missing or wrong type: %+v", act.Desc().Metadata["agent"]) + } + return meta + } + + t.Run("struct state yields an object schema with its fields", func(t *testing.T) { + reg := newTestRegistry(t) + DefineCustomAgent(reg, "structStateFlow", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, nil + }) + meta := readMeta(t, reg, "structStateFlow") + if meta.StateSchema == nil { + t.Fatal("expected a state schema for a struct state type") + } + if got := meta.StateSchema["type"]; got != "object" { + t.Errorf("state schema type = %v, want object", got) + } + props, ok := meta.StateSchema["properties"].(map[string]any) + if !ok { + t.Fatalf("state schema properties = %T, want map", meta.StateSchema["properties"]) + } + if _, ok := props["counter"]; !ok { + t.Errorf("state schema missing the 'counter' field: %+v", props) + } + }) + + t.Run("any state yields no schema", func(t *testing.T) { + reg := newTestRegistry(t) + DefineCustomAgent[any](reg, "anyStateFlow", + func(ctx context.Context, resp Responder, sess *SessionRunner[any]) (*AgentResult, error) { + return nil, nil + }) + meta := readMeta(t, reg, "anyStateFlow") + if meta.StateSchema != nil { + t.Errorf("expected no state schema for an any state type, got %+v", meta.StateSchema) + } + }) +} + func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { // Verify the abort companion action is only registered when the // store implements SnapshotAborter. The getSnapshot action is @@ -4004,7 +4150,7 @@ func TestAgent_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { t.Fatalf("Output: %v", err) } finalSnap := waitForSnapshot(t, store, first.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusSucceeded + return s.Status == SnapshotStatusCompleted }) if finalSnap.State.Custom.Counter != 1 { t.Fatalf("expected counter=1 in finalized snapshot, got %d", finalSnap.State.Custom.Counter) @@ -4083,7 +4229,7 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ Event: SnapshotEventTurnEnd, - Status: SnapshotStatusSucceeded, + Status: SnapshotStatusCompleted, }, nil }); err != nil { t.Fatalf("SaveSnapshot: %v", err) @@ -4092,8 +4238,8 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { if err != nil { t.Fatalf("AbortSnapshot on complete: %v", err) } - if status3 != SnapshotStatusSucceeded { - t.Errorf("abort on complete returned status=%q, want succeeded", status3) + if status3 != SnapshotStatusCompleted { + t.Errorf("abort on complete returned status=%q, want completed", status3) } } @@ -4118,7 +4264,7 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { } <-fnRelease // Return cleanly without observing ctx. Without the - // subscriber/recheck, this would land status=succeeded and + // subscriber/recheck, this would land status=completed and // clobber the abort. return nil, nil }) @@ -4159,7 +4305,7 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { close(fnRelease) finalSnap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusAborted || s.Status == SnapshotStatusSucceeded + return s.Status == SnapshotStatusAborted || s.Status == SnapshotStatusCompleted }) if finalSnap.Status != SnapshotStatusAborted { t.Errorf("finalize clobbered aborted with %q", finalSnap.Status) @@ -4261,8 +4407,8 @@ func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { if err != nil { t.Fatalf("AbortSnapshot: %v", err) } - if status != SnapshotStatusSucceeded { - t.Errorf("expected status=%q (existing terminal), got %q", SnapshotStatusSucceeded, status) + if status != SnapshotStatusCompleted { + t.Errorf("expected status=%q (existing terminal), got %q", SnapshotStatusCompleted, status) } // Confirm the store snapshot was not flipped. @@ -4270,8 +4416,8 @@ func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { if err != nil { t.Fatalf("GetSnapshot: %v", err) } - if snap.Status != SnapshotStatusSucceeded { - t.Errorf("snapshot status = %q after abort-on-terminal, want succeeded", snap.Status) + if snap.Status != SnapshotStatusCompleted { + t.Errorf("snapshot status = %q after abort-on-terminal, want completed", snap.Status) } } @@ -4657,7 +4803,7 @@ func TestAgent_Detach_BackgroundWorkSurvivesActionReturn(t *testing.T) { } // Wait out the finalizer so the iteration's goroutines wind down. waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusSucceeded + return s.Status == SnapshotStatusCompleted }) } } @@ -4665,7 +4811,7 @@ func TestAgent_Detach_BackgroundWorkSurvivesActionReturn(t *testing.T) { // TestAgent_Detach_FinishReasons covers the three detach outcomes: the output // returned to the detaching client always reports "detached", while the // persisted snapshot records how the background work actually ended -// (succeeded -> last turn's reason, failed, or aborted). +// (completed -> last turn's reason, failed, or aborted). func TestAgent_Detach_FinishReasons(t *testing.T) { t.Run("complete", func(t *testing.T) { reg := newTestRegistry(t) @@ -4717,7 +4863,7 @@ func TestAgent_Detach_FinishReasons(t *testing.T) { close(release) snap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusSucceeded + return s.Status == SnapshotStatusCompleted }) if snap.FinishReason != AgentFinishReasonStop { t.Errorf("finalized snapshot.FinishReason = %q, want %q", snap.FinishReason, AgentFinishReasonStop) @@ -5048,11 +5194,11 @@ func TestPromptAgent_ForwardsInterruptedFinishReason(t *testing.T) { } } -// TestAgent_Detach_SucceededHonorsResultOverride verifies the detach finalizer +// TestAgent_Detach_CompletedHonorsResultOverride verifies the detach finalizer // applies an AgentResult.FinishReason override on clean success, matching the // synchronous path (the override does not leak into the failed/aborted cases, // which are covered by TestAgent_Detach_FinishReasons). -func TestAgent_Detach_SucceededHonorsResultOverride(t *testing.T) { +func TestAgent_Detach_CompletedHonorsResultOverride(t *testing.T) { reg := newTestRegistry(t) store := newTestInMemStore[testState]() release := make(chan struct{}) @@ -5107,7 +5253,7 @@ func TestAgent_Detach_SucceededHonorsResultOverride(t *testing.T) { close(release) snap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusSucceeded + return s.Status == SnapshotStatusCompleted }) if snap.FinishReason != AgentFinishReasonOther { t.Errorf("finalized snapshot.FinishReason = %q, want %q (AgentResult override)", snap.FinishReason, AgentFinishReasonOther) @@ -5214,14 +5360,19 @@ func TestAgent_SessionID_AssignedBeforeFirstSnapshot(t *testing.T) { t.Errorf("expected no snapshot (callback declined every write), got %q", out.SnapshotID) } - // A session with no persisted snapshots cannot be resumed by its ID; - // the init-level rejection fails the action. + // A session with no persisted snapshots is not resumable, but supplying + // its ID is not an error: it starts a brand-new conversation under that + // caller-chosen ID rather than failing. The supplied ID carries through + // to the output (see also TestAgent_ResumeFromSessionID_NewConversation). out2, err := af.RunText(ctx, "again", WithSessionID[testState](out.SessionID)) - if err == nil { - t.Fatalf("expected NOT_FOUND error for snapshot-less session, got output: %+v", out2) + if err != nil { + t.Fatalf("RunText with session ID for a snapshot-less session: %v", err) + } + if out2.FinishReason == AgentFinishReasonFailed { + t.Fatalf("second invocation failed: %+v", out2.Error) } - if ge := core.AsGenkitError(err); ge.Status != core.NOT_FOUND { - t.Fatalf("expected NOT_FOUND, got %q (err: %v)", ge.Status, err) + if out2.SessionID != out.SessionID { + t.Errorf("second invocation session ID = %q, want the caller-supplied %q", out2.SessionID, out.SessionID) } } @@ -5341,10 +5492,11 @@ func TestAgent_ResumeFromSessionID_ForkContinuesLatestBranch(t *testing.T) { } } -func TestAgent_ResumeFromSessionID_SkipsDeadEnds(t *testing.T) { - // A failed (or aborted) row is a permanent dead end: even as the - // session's newest row it never blocks session-ID init. The session - // resumes from the last good snapshot. +func TestAgent_ResumeFromSessionID_FailedTipRejected(t *testing.T) { + // GetLatestSnapshot returns the session's literal latest row, so a failed + // (or aborted) tip is no longer skipped: resuming the session by ID hits + // the dead end and is rejected. To continue past it the caller names an + // earlier good snapshot via WithSnapshotID. ctx := context.Background() reg := newTestRegistry(t) store := newTestInMemStore[testState]() @@ -5368,37 +5520,84 @@ func TestAgent_ResumeFromSessionID_SkipsDeadEnds(t *testing.T) { t.Fatalf("SaveSnapshot failed row: %v", err) } - out2, err := af.RunText(ctx, "second", WithSessionID[testState](out1.SessionID)) + // Resuming by session ID hits the failed tip and is rejected. + if _, err := af.RunText(ctx, "second", WithSessionID[testState](out1.SessionID)); err == nil { + t.Fatal("expected resume to be rejected for a failed tip, got nil") + } else if ge := core.AsGenkitError(err); ge.Status != core.FAILED_PRECONDITION { + t.Fatalf("expected FAILED_PRECONDITION, got %q (err: %v)", ge.Status, err) + } + + // Naming the last good snapshot explicitly still resumes past the dead end. + out3, err := af.RunText(ctx, "third", WithSnapshotID[testState](out1.SnapshotID)) if err != nil { - t.Fatalf("RunText session resume: %v", err) + t.Fatalf("RunText resume from good snapshot: %v", err) } - if out2.FinishReason == AgentFinishReasonFailed { - t.Fatalf("session resume failed: %+v", out2.Error) + if out3.FinishReason == AgentFinishReasonFailed { + t.Fatalf("snapshot resume failed: %+v", out3.Error) } - snap2, err := store.GetSnapshot(ctx, out2.SnapshotID) + snap3, err := store.GetSnapshot(ctx, out3.SnapshotID) if err != nil { t.Fatalf("GetSnapshot: %v", err) } - if snap2.ParentID != out1.SnapshotID { - t.Errorf("resumed snapshot ParentID = %q, want last good %q", snap2.ParentID, out1.SnapshotID) - } - if got := snap2.State.Custom.Counter; got != 2 { + if got := snap3.State.Custom.Counter; got != 2 { t.Errorf("expected counter=2 (resumed from last good state), got %d", got) } +} + +func TestAgent_ResumeFromSessionID_NewConversation(t *testing.T) { + // A caller may name a brand-new session: with no snapshot yet under that + // ID, the invocation starts a fresh conversation and adopts the ID for + // the whole session lifecycle (every snapshot carries it) rather than + // failing with NOT_FOUND or minting a server ID. + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + af := defineLastGoodTestAgent(reg, "sessionNewConvoFlow", WithSessionStore(store)) - // The dead end keeps being skipped on subsequent resumes. - out3, err := af.RunText(ctx, "third", WithSessionID[testState](out1.SessionID)) + const sessID = "client-chosen-session" + out, err := af.RunText(ctx, "hello", WithSessionID[testState](sessID)) if err != nil { - t.Fatalf("RunText second session resume: %v", err) + t.Fatalf("RunText: %v", err) } - if out3.FinishReason == AgentFinishReasonFailed { - t.Fatalf("second session resume failed: %+v", out3.Error) + if out.FinishReason == AgentFinishReasonFailed { + t.Fatalf("invocation failed: %+v", out.Error) + } + if out.SessionID != sessID { + t.Errorf("output session ID = %q, want caller-supplied %q", out.SessionID, sessID) + } + // The persisted snapshot carries the caller-chosen session ID. + snap, err := store.GetSnapshot(ctx, out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap.SessionID != sessID { + t.Errorf("snapshot session ID = %q, want %q", snap.SessionID, sessID) + } + + // A second invocation under the same ID now resumes the conversation it + // just created (that snapshot is the session's resumable tip). + out2, err := af.RunText(ctx, "again", WithSessionID[testState](sessID)) + if err != nil { + t.Fatalf("RunText resume: %v", err) + } + if out2.SessionID != sessID { + t.Errorf("resumed session ID = %q, want %q", out2.SessionID, sessID) + } + snap2, err := store.GetSnapshot(ctx, out2.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if snap2.ParentID != out.SnapshotID { + t.Errorf("resumed snapshot ParentID = %q, want first snapshot %q", snap2.ParentID, out.SnapshotID) + } + if got := snap2.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2 after resuming the new conversation, got %d", got) } } func TestAgent_ResumeFromSessionID_AfterFailureResumesRecovery(t *testing.T) { // After an invocation fails, the session's newest non-dead-end row is - // the recovery snapshot (a normal succeeded row, event=recovery) + // the recovery snapshot (a normal completed row, event=recovery) // holding the last-good state. Resuming by session ID continues from // it like any other snapshot. ctx := context.Background() @@ -5691,7 +5890,7 @@ func TestAgent_Detach_AssignsSessionID(t *testing.T) { close(release) final := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusSucceeded + return s.Status == SnapshotStatusCompleted }) if final.SessionID != out.SessionID { t.Errorf("finalized row SessionID = %q, want %q", final.SessionID, out.SessionID) @@ -5767,7 +5966,7 @@ func TestAgent_Detach_WaitsForInFlightTurnSnapshot(t *testing.T) { } final := waitForSnapshot[testState](t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { - return s.Status == SnapshotStatusSucceeded + return s.Status == SnapshotStatusCompleted }) // Find the turn-end row (the only row besides the detach row). @@ -5851,7 +6050,7 @@ func TestAgent_ResumeFromLegacySnapshot_MintsFreshSessionID(t *testing.T) { legacy := &SessionSnapshot[testState]{ SnapshotID: "legacy-1", Event: SnapshotEventInvocationEnd, - Status: SnapshotStatusSucceeded, + Status: SnapshotStatusCompleted, CreatedAt: time.Now(), UpdatedAt: time.Now(), State: &SessionState[testState]{Custom: testState{Counter: 5}}, diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index b7b70521b1..94ea1c2cab 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -97,18 +97,20 @@ const ( // // Sending no fields starts a fresh invocation with empty state. type AgentInit[State any] struct { - // SessionID identifies the session (conversation) to resume. Only valid - // when the agent is server-managed (a session store is configured); - // mutually exclusive with State (a client-managed conversation carries - // its identity inside [SessionState.SessionID]). Alone, it resumes the - // session from its latest snapshot: the most recently updated one that - // is not a failed/aborted dead end. A pending latest snapshot (a - // detached invocation still running) rejects the resume rather than - // racing the background work; if the session's history was forked by - // resuming an earlier snapshot again, the most recently updated branch - // wins, and SnapshotID can pick a branch explicitly. Combined with - // SnapshotID, it asserts which session the snapshot belongs to, and a - // mismatch is rejected. + // SessionID identifies the session (conversation) to resume or start. + // Only valid when the agent is server-managed (a session store is + // configured); mutually exclusive with State (a client-managed + // conversation carries its identity inside [SessionState.SessionID]). + // Alone, it resumes the session from its latest snapshot: the most + // recently updated row, whatever its status. If that row is a failed, + // aborted, or still-pending dead end the resume is rejected (pass + // SnapshotID to continue from a specific earlier point); if the session's + // history was forked by resuming an earlier snapshot again, the most + // recently updated branch wins. If the session has no snapshots yet, a + // brand-new conversation is started under this caller-chosen ID, and + // every snapshot it persists carries it. Combined with SnapshotID, it + // asserts which session the snapshot belongs to, and a mismatch is + // rejected. SessionID string `json:"sessionId,omitempty"` // SnapshotID loads state from a persisted snapshot. Only valid when the // agent is server-managed (a session store is configured). May be @@ -165,6 +167,12 @@ type AgentMetadata struct { Abortable bool `json:"abortable,omitempty"` // StateManagement reports who owns session state. StateManagement AgentStateManagement `json:"stateManagement,omitempty"` + // StateSchema is the JSON schema for the agent's custom session state + // (the Custom field of [SessionState]), inferred from the agent's state + // type. It lets the Dev UI and other reflective callers render or + // validate state without the agent describing it separately. Nil when the + // state type carries no schema to infer (e.g. an unstructured any state). + StateSchema map[string]any `json:"stateSchema,omitempty"` } // AgentOutput is the output when an agent invocation completes. @@ -187,7 +195,8 @@ type AgentOutput[State any] struct { Message *ai.Message `json:"message,omitempty"` // SessionID is the ID of the session this invocation belongs to, // assigned by the framework when the invocation starts. With - // server-managed state, a fresh invocation mints a new ID, resumed + // server-managed state, a fresh invocation adopts the caller-supplied + // session ID (see [AgentInit.SessionID]) or mints a new one, resumed // invocations inherit the chain's, and resuming a snapshot from before // session IDs existed mints a fresh one. With client-managed state it // echoes the ID carried inside the state object @@ -284,9 +293,23 @@ type Artifact struct { // for Dev UI and client-side reconnect flows. It returns the stored // [SessionSnapshot], with [WithStateTransform] applied to its state if // configured. +// +// At least one of SnapshotID or SessionID must be set; they are not +// mutually exclusive. SnapshotID fetches a specific snapshot; SessionID +// alone fetches the session's latest snapshot (via the store's +// [SnapshotReader.GetLatestSnapshot], whatever its status). When both are +// set, the fetched snapshot must belong to that session, or the request +// is rejected. type GetSnapshotRequest struct { - // SnapshotID identifies the snapshot to fetch. - SnapshotID string `json:"snapshotId"` + // SessionID identifies the session whose latest snapshot to fetch. + // Optional when SnapshotID is given. The latest snapshot is the session's + // most recently updated row regardless of status (pending, failed, or + // aborted included). + SessionID string `json:"sessionId,omitempty"` + // SnapshotID identifies the snapshot to fetch. Optional when SessionID is + // given; when both are present the fetched snapshot must belong to that + // session. + SnapshotID string `json:"snapshotId,omitempty"` } // JSONPatch is an RFC 6902 JSON Patch: an ordered list of operations applied in @@ -364,7 +387,7 @@ type SessionSnapshot[State any] struct { // snapshots with the cumulative final state. State *SessionState[State] `json:"state,omitempty"` // Status is the lifecycle state of this snapshot. Empty is treated as - // [SnapshotStatusSucceeded] for backwards compatibility. + // [SnapshotStatusCompleted] for backwards compatibility. Status SnapshotStatus `json:"status,omitempty"` // UpdatedAt is when the snapshot was last written. For pending snapshots // it equals CreatedAt; once the snapshot is finalized it reflects the @@ -408,7 +431,7 @@ const ( // SnapshotEventRecovery indicates the snapshot was written retroactively // by the failure path to preserve the last-good state (everything through // the last successful turn) when a selective snapshot callback had skipped - // persisting it. It is a normal [SnapshotStatusSucceeded] row carrying the + // persisting it. It is a normal [SnapshotStatusCompleted] row carrying the // last good turn's finish reason, resumable like any other; the snapshot // callback is bypassed and never sees this event. SnapshotEventRecovery SnapshotEvent = "recovery" @@ -416,13 +439,13 @@ const ( // SnapshotStatus describes the lifecycle state of a snapshot. Snapshots // written for synchronous turns or invocations are always -// [SnapshotStatusSucceeded] (an empty value is also treated as succeeded +// [SnapshotStatusCompleted] (an empty value is also treated as completed // for backwards compatibility). // // When a client sets [AgentInput.Detach], the server writes a single // snapshot with [SnapshotStatusPending] (and empty state) and returns its // ID immediately. Background processing then either rewrites that snapshot -// with the cumulative final state and [SnapshotStatusSucceeded] / +// with the cumulative final state and [SnapshotStatusCompleted] / // [SnapshotStatusFailed] when the agent finishes, or with // [SnapshotStatusAborted] if the client called abortSnapshot in the // meantime. @@ -433,8 +456,8 @@ const ( // processing the queued inputs. The snapshot will be rewritten with a // terminal status once the background work finishes. SnapshotStatusPending SnapshotStatus = "pending" - // SnapshotStatusSucceeded indicates the snapshot captures a settled state. - SnapshotStatusSucceeded SnapshotStatus = "succeeded" + // SnapshotStatusCompleted indicates the snapshot captures a settled state. + SnapshotStatusCompleted SnapshotStatus = "completed" // SnapshotStatusAborted indicates the snapshot's invocation was aborted // while detached (e.g. via the abortSnapshot companion action). SnapshotStatusAborted SnapshotStatus = "aborted" diff --git a/go/ai/exp/localstore/file.go b/go/ai/exp/localstore/file.go index e386a6a470..f310b40aa6 100644 --- a/go/ai/exp/localstore/file.go +++ b/go/ai/exp/localstore/file.go @@ -127,7 +127,7 @@ func (s *FileSessionStore[State]) SaveSnapshot( } next.UpdatedAt = now if next.Status == "" { - next.Status = exp.SnapshotStatusSucceeded + next.Status = exp.SnapshotStatusCompleted } if err := s.writeLocked(next); err != nil { @@ -139,23 +139,22 @@ func (s *FileSessionStore[State]) SaveSnapshot( return next, nil } -// snapshotHeader is the subset of snapshot fields needed to decide -// whether a row resolves a session resume. Decoding only these avoids +// snapshotHeader is the subset of snapshot fields needed to match a row to +// a session during the latest-snapshot scan. Decoding only these avoids // materializing every row's full conversation state during the scan. type snapshotHeader struct { - SessionID string `json:"sessionId"` - Status exp.SnapshotStatus `json:"status"` + SessionID string `json:"sessionId"` } // GetLatestSnapshot returns the session's most recently updated snapshot -// that is not a failed/aborted dead end, per the -// [exp.SnapshotReader.GetLatestSnapshot] contract. +// regardless of status, per the [exp.SnapshotReader.GetLatestSnapshot] +// contract. // // Recency is judged by file mtime, which for snapshots written by this // package advances with [exp.SessionSnapshot.UpdatedAt] (each save // creates a fresh temp file and renames it into place); if a file is // touched externally, mtime wins. The scan walks files newest first and -// stops at the first row that matches, so resolving the most recently +// stops at the first row for the session, so resolving the most recently // active session costs one read in the common case. Only header fields // are decoded per candidate (the winner is the only full parse), the // store lock is held per file rather than across the whole scan, and a @@ -180,8 +179,7 @@ func (s *FileSessionStore[State]) GetLatestSnapshot(_ context.Context, sessionID if err := json.Unmarshal(data, &h); err != nil { continue } - if h.SessionID != sessionID || - h.Status == exp.SnapshotStatusFailed || h.Status == exp.SnapshotStatusAborted { + if h.SessionID != sessionID { continue } var snap exp.SessionSnapshot[State] diff --git a/go/ai/exp/localstore/file_test.go b/go/ai/exp/localstore/file_test.go index 06948d8b49..ed321320fc 100644 --- a/go/ai/exp/localstore/file_test.go +++ b/go/ai/exp/localstore/file_test.go @@ -72,7 +72,7 @@ func TestFileSessionStore(t *testing.T) { t.Errorf("expected nil existing on first save, got %+v", existing) } return &exp.SessionSnapshot[testState]{ - Status: exp.SnapshotStatusSucceeded, + Status: exp.SnapshotStatusCompleted, State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, }, nil }) @@ -92,7 +92,7 @@ func TestFileSessionStore(t *testing.T) { store := newFileStore(t) saved, err := store.SaveSnapshot(context.Background(), "", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { - return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusCompleted}, nil }) if err != nil { t.Fatalf("SaveSnapshot: %v", err) @@ -107,7 +107,7 @@ func TestFileSessionStore(t *testing.T) { if _, err := store.SaveSnapshot(context.Background(), "snap-1", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { return &exp.SessionSnapshot[testState]{ - Status: exp.SnapshotStatusSucceeded, + Status: exp.SnapshotStatusCompleted, State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, }, nil }); err != nil { @@ -121,7 +121,7 @@ func TestFileSessionStore(t *testing.T) { } }) - t.Run("DefaultsEmptyStatusToSucceeded", func(t *testing.T) { + t.Run("DefaultsEmptyStatusToCompleted", func(t *testing.T) { store := newFileStore(t) saved, err := store.SaveSnapshot(context.Background(), "", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { @@ -130,8 +130,8 @@ func TestFileSessionStore(t *testing.T) { if err != nil { t.Fatalf("SaveSnapshot: %v", err) } - if saved.Status != exp.SnapshotStatusSucceeded { - t.Errorf("expected Status=succeeded by default, got %q", saved.Status) + if saved.Status != exp.SnapshotStatusCompleted { + t.Errorf("expected Status=completed by default, got %q", saved.Status) } }) @@ -139,7 +139,7 @@ func TestFileSessionStore(t *testing.T) { store := newFileStore(t) if _, err := store.SaveSnapshot(context.Background(), "snap-1", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { - return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusCompleted}, nil }); err != nil { t.Fatalf("seed: %v", err) } @@ -164,7 +164,7 @@ func TestFileSessionStore(t *testing.T) { store := newFileStore(t) saved, err := store.SaveSnapshot(context.Background(), "snap-1", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { - return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusCompleted}, nil }) if err != nil { t.Fatalf("seed: %v", err) @@ -176,7 +176,7 @@ func TestFileSessionStore(t *testing.T) { t.Fatal("expected non-nil existing on update") } return &exp.SessionSnapshot[testState]{ - Status: exp.SnapshotStatusSucceeded, + Status: exp.SnapshotStatusCompleted, State: &exp.SessionState[testState]{Custom: testState{Counter: 2}}, }, nil }) @@ -216,7 +216,7 @@ func TestFileSessionStore(t *testing.T) { store := newFileStore(t) if _, err := store.SaveSnapshot(context.Background(), "snap-1", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { - return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusCompleted}, nil }); err != nil { t.Fatalf("seed: %v", err) } @@ -224,8 +224,8 @@ func TestFileSessionStore(t *testing.T) { if err != nil { t.Fatalf("AbortSnapshot: %v", err) } - if status != exp.SnapshotStatusSucceeded { - t.Errorf("status = %q, want %q (no-op on terminal)", status, exp.SnapshotStatusSucceeded) + if status != exp.SnapshotStatusCompleted { + t.Errorf("status = %q, want %q (no-op on terminal)", status, exp.SnapshotStatusCompleted) } }) @@ -325,7 +325,7 @@ func TestFileSessionStore(t *testing.T) { if _, err := store1.SaveSnapshot(context.Background(), "snap-1", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { return &exp.SessionSnapshot[testState]{ - Status: exp.SnapshotStatusSucceeded, + Status: exp.SnapshotStatusCompleted, State: &exp.SessionState[testState]{Custom: testState{Counter: 42}}, }, nil }); err != nil { @@ -383,7 +383,7 @@ func TestFileSessionStore(t *testing.T) { } if _, err := store.SaveSnapshot(context.Background(), "snap-1", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { - return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusCompleted}, nil }); err != nil { t.Fatalf("SaveSnapshot: %v", err) } @@ -410,7 +410,7 @@ func TestFileSessionStore_FinishReasonPersistsAcrossReopen(t *testing.T) { saved, err := store.SaveSnapshot(context.Background(), "", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { return &exp.SessionSnapshot[testState]{ - Status: exp.SnapshotStatusSucceeded, + Status: exp.SnapshotStatusCompleted, FinishReason: exp.AgentFinishReasonInterrupted, State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, }, nil @@ -460,7 +460,7 @@ func TestFileSessionStore_GetLatestSnapshot_SkipsUnparseableFiles(t *testing.T) return &exp.SessionSnapshot[testState]{ SessionID: "sess-1", Event: exp.SnapshotEventTurnEnd, - Status: exp.SnapshotStatusSucceeded, + Status: exp.SnapshotStatusCompleted, }, nil }); err != nil { t.Fatalf("SaveSnapshot: %v", err) diff --git a/go/ai/exp/localstore/inmemory.go b/go/ai/exp/localstore/inmemory.go index ca73f1a8d4..ff25796469 100644 --- a/go/ai/exp/localstore/inmemory.go +++ b/go/ai/exp/localstore/inmemory.go @@ -67,11 +67,11 @@ func (s *InMemorySessionStore[State]) GetSnapshot(_ context.Context, snapshotID } // GetLatestSnapshot returns the session's most recently updated snapshot -// that is not a failed/aborted dead end, per the -// [exp.SnapshotReader.GetLatestSnapshot] contract. Ties on UpdatedAt are -// broken by SnapshotID so resolution is deterministic. The scan runs -// under the read lock, so the stored rows (which other calls mutate in -// place) never escape it; the winner is returned as a deep copy. +// regardless of status, per the [exp.SnapshotReader.GetLatestSnapshot] +// contract. Ties on UpdatedAt are broken by SnapshotID so resolution is +// deterministic. The scan runs under the read lock, so the stored rows +// (which other calls mutate in place) never escape it; the winner is +// returned as a deep copy. func (s *InMemorySessionStore[State]) GetLatestSnapshot(_ context.Context, sessionID string) (*exp.SessionSnapshot[State], error) { if sessionID == "" { return nil, errors.New("InMemorySessionStore: session ID is empty") @@ -80,8 +80,7 @@ func (s *InMemorySessionStore[State]) GetLatestSnapshot(_ context.Context, sessi defer s.mu.RUnlock() var latest *exp.SessionSnapshot[State] for _, snap := range s.snapshots { - if snap.SessionID != sessionID || - snap.Status == exp.SnapshotStatusFailed || snap.Status == exp.SnapshotStatusAborted { + if snap.SessionID != sessionID { continue } if latest == nil || snap.UpdatedAt.After(latest.UpdatedAt) || @@ -158,7 +157,7 @@ func (s *InMemorySessionStore[State]) SaveSnapshot( } next.UpdatedAt = now if next.Status == "" { - next.Status = exp.SnapshotStatusSucceeded + next.Status = exp.SnapshotStatusCompleted } copied, err := copySnapshot(next) diff --git a/go/ai/exp/localstore/inmemory_test.go b/go/ai/exp/localstore/inmemory_test.go index 07c27eafe8..32321ff273 100644 --- a/go/ai/exp/localstore/inmemory_test.go +++ b/go/ai/exp/localstore/inmemory_test.go @@ -44,7 +44,7 @@ func TestInMemorySessionStore(t *testing.T) { t.Errorf("expected nil existing on first save, got %+v", existing) } return &exp.SessionSnapshot[testState]{ - Status: exp.SnapshotStatusSucceeded, + Status: exp.SnapshotStatusCompleted, State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, }, nil }) @@ -65,7 +65,7 @@ func TestInMemorySessionStore(t *testing.T) { if _, err := store.SaveSnapshot(context.Background(), "snap-1", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { return &exp.SessionSnapshot[testState]{ - Status: exp.SnapshotStatusSucceeded, + Status: exp.SnapshotStatusCompleted, State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, }, nil }); err != nil { @@ -91,7 +91,7 @@ func TestInMemorySessionStore(t *testing.T) { if saved.SnapshotID == "" { t.Error("expected store to generate SnapshotID") } - if saved.Status != exp.SnapshotStatusSucceeded { + if saved.Status != exp.SnapshotStatusCompleted { t.Errorf("expected Status=complete by default, got %q", saved.Status) } }) @@ -100,7 +100,7 @@ func TestInMemorySessionStore(t *testing.T) { store := NewInMemorySessionStore[testState]() if _, err := store.SaveSnapshot(context.Background(), "snap-1", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { - return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusCompleted}, nil }); err != nil { t.Fatalf("seed: %v", err) } @@ -125,7 +125,7 @@ func TestInMemorySessionStore(t *testing.T) { store := NewInMemorySessionStore[testState]() saved, err := store.SaveSnapshot(context.Background(), "snap-1", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { - return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusSucceeded}, nil + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusCompleted}, nil }) if err != nil { t.Fatalf("seed: %v", err) @@ -137,7 +137,7 @@ func TestInMemorySessionStore(t *testing.T) { t.Fatal("expected non-nil existing on update") } return &exp.SessionSnapshot[testState]{ - Status: exp.SnapshotStatusSucceeded, + Status: exp.SnapshotStatusCompleted, State: &exp.SessionState[testState]{Custom: testState{Counter: 2}}, }, nil }) @@ -169,7 +169,7 @@ func TestInMemorySessionStore_SessionIDs(t *testing.T) { func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { return &exp.SessionSnapshot[testState]{ SessionID: "sess-1", - Status: exp.SnapshotStatusSucceeded, + Status: exp.SnapshotStatusCompleted, State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, }, nil }); err != nil { diff --git a/go/ai/exp/localstore/store_test.go b/go/ai/exp/localstore/store_test.go index bc8981cddf..24f76200e4 100644 --- a/go/ai/exp/localstore/store_test.go +++ b/go/ai/exp/localstore/store_test.go @@ -62,7 +62,7 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio t.Run("SessionIDKeptWhenProvided", func(t *testing.T) { store := newStore(t) - saved := saveRow(t, store, "a", "sess-keep", "", exp.SnapshotStatusSucceeded) + saved := saveRow(t, store, "a", "sess-keep", "", exp.SnapshotStatusCompleted) if saved.SessionID != "sess-keep" { t.Errorf("SessionID = %q, want provided %q", saved.SessionID, "sess-keep") } @@ -91,7 +91,7 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio SessionID: rewrite, ParentID: existing.ParentID, Event: existing.Event, - Status: exp.SnapshotStatusSucceeded, + Status: exp.SnapshotStatusCompleted, State: &exp.SessionState[testState]{Custom: testState{Counter: 2}}, }, nil }) @@ -109,8 +109,8 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio // them at invocation start and stamps every row it writes); a row // written without one is session-less even when its parent has one. store := newStore(t) - saveRow(t, store, "parent", "sess-1", "", exp.SnapshotStatusSucceeded) - child := saveRow(t, store, "child", "", "parent", exp.SnapshotStatusSucceeded) + saveRow(t, store, "parent", "sess-1", "", exp.SnapshotStatusCompleted) + child := saveRow(t, store, "child", "", "parent", exp.SnapshotStatusCompleted) if child.SessionID != "" { t.Errorf("expected session-less row, got SessionID %q", child.SessionID) } @@ -120,13 +120,13 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio // IDs deliberately sort against write order so a recency bug (or an // accidental reliance on the tie-break) cannot pass by luck. store := newStore(t) - saveRow(t, store, "z", "sess-1", "", exp.SnapshotStatusSucceeded) + saveRow(t, store, "z", "sess-1", "", exp.SnapshotStatusCompleted) tick() - saveRow(t, store, "m", "sess-1", "z", exp.SnapshotStatusSucceeded) + saveRow(t, store, "m", "sess-1", "z", exp.SnapshotStatusCompleted) tick() - saveRow(t, store, "a", "sess-1", "m", exp.SnapshotStatusSucceeded) + saveRow(t, store, "a", "sess-1", "m", exp.SnapshotStatusCompleted) tick() - saveRow(t, store, "x", "sess-other", "", exp.SnapshotStatusSucceeded) + saveRow(t, store, "x", "sess-other", "", exp.SnapshotStatusCompleted) latest, err := store.GetLatestSnapshot(ctx, "sess-1") if err != nil { @@ -147,16 +147,16 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio // row (e.g. a detach finalize landing after other branches were // written) moves it to the front. store := newStore(t) - saveRow(t, store, "root", "sess-1", "", exp.SnapshotStatusSucceeded) + saveRow(t, store, "root", "sess-1", "", exp.SnapshotStatusCompleted) tick() saveRow(t, store, "b1", "sess-1", "root", exp.SnapshotStatusPending) tick() - saveRow(t, store, "b2", "sess-1", "root", exp.SnapshotStatusSucceeded) + saveRow(t, store, "b2", "sess-1", "root", exp.SnapshotStatusCompleted) tick() if _, err := store.SaveSnapshot(ctx, "b1", func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { rewritten := *existing - rewritten.Status = exp.SnapshotStatusSucceeded + rewritten.Status = exp.SnapshotStatusCompleted rewritten.State = &exp.SessionState[testState]{Custom: testState{Counter: 2}} return &rewritten, nil }); err != nil { @@ -172,11 +172,12 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio } }) - t.Run("GetLatestSnapshotSkipsDeadEnds", func(t *testing.T) { - // Failed and aborted rows are dead ends: even when newest, they - // never hide the session's last good snapshot. + t.Run("GetLatestSnapshotReturnsLatestAnyStatus", func(t *testing.T) { + // The latest row is returned whatever its status: failed and aborted + // tips are no longer skipped. Deciding a tip is a dead end is the + // resume path's job, not the store's. Here the newest row is aborted. store := newStore(t) - saveRow(t, store, "a", "sess-1", "", exp.SnapshotStatusSucceeded) + saveRow(t, store, "a", "sess-1", "", exp.SnapshotStatusCompleted) tick() saveRow(t, store, "b", "sess-1", "a", exp.SnapshotStatusFailed) tick() @@ -186,18 +187,18 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio if err != nil { t.Fatalf("GetLatestSnapshot: %v", err) } - if latest == nil || latest.SnapshotID != "a" { - t.Errorf("latest = %+v, want last good snapshot a", latest) + if latest == nil || latest.SnapshotID != "c" || latest.Status != exp.SnapshotStatusAborted { + t.Errorf("latest = %+v, want newest row c (aborted)", latest) } }) t.Run("GetLatestSnapshotPendingReturned", func(t *testing.T) { - // A pending row is not skipped: it marks a detached invocation - // that is still running, and the runtime needs to see it to - // reject the resume instead of silently racing the background + // A pending row is returned like any other: it marks a detached + // invocation that is still running, and the runtime needs to see it + // to reject the resume instead of silently racing the background // work. store := newStore(t) - saveRow(t, store, "a", "sess-1", "", exp.SnapshotStatusSucceeded) + saveRow(t, store, "a", "sess-1", "", exp.SnapshotStatusCompleted) tick() saveRow(t, store, "b", "sess-1", "a", exp.SnapshotStatusPending) @@ -212,7 +213,7 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio t.Run("GetLatestSnapshotUnknownSession", func(t *testing.T) { store := newStore(t) - saveRow(t, store, "a", "sess-1", "", exp.SnapshotStatusSucceeded) + saveRow(t, store, "a", "sess-1", "", exp.SnapshotStatusCompleted) latest, err := store.GetLatestSnapshot(ctx, "sess-unknown") if err != nil { t.Fatalf("GetLatestSnapshot: %v", err) diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index b2cd73073d..4c9d92a50d 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -63,28 +63,27 @@ type SnapshotReader[State any] interface { // GetSnapshot retrieves a snapshot by ID. Returns nil if not found. GetSnapshot(ctx context.Context, snapshotID string) (*SessionSnapshot[State], error) - // GetLatestSnapshot resolves the snapshot a session would resume from: - // the session's most recently updated row that is not a dead end, as a - // full row (the runtime loads its state to resume). Returns nil if the - // session has no such row (unknown session ID, or every row is a dead - // end), and an error if sessionID is empty. + // GetLatestSnapshot returns the session's most recently updated + // snapshot as a full row (the runtime loads its state to resume), + // whatever its status: a pending, failed, or aborted row is returned + // like any other. Returns nil if the session has no rows (unknown + // session ID), and an error if sessionID is empty. // // "Most recently updated" means the greatest // [SessionSnapshot.UpdatedAt], falling back to CreatedAt on rows that // lack one; ties may be broken arbitrarily but deterministically (e.g. - // by SnapshotID). Failed and aborted rows are dead ends (resuming from - // them is rejected), so they are skipped and a branch that died never - // hides the session's last good snapshot. A pending row is NOT - // skipped: it marks a detached invocation that is still running, and - // surfacing it lets the agent runtime reject the resume instead of - // silently racing the background work. + // by SnapshotID). The latest row is returned unconditionally, so the + // two callers can each apply their own policy: the getSnapshot + // companion action surfaces it verbatim (a client reconnecting wants to + // see a pending/failed tip), and the session-ID resume path rejects a + // failed/aborted/pending tip rather than continuing from a dead end. // - // The contract is a status-filtered max-timestamp lookup, so stores - // can implement it as a single indexed query (e.g. WHERE sessionId = ? - // AND status NOT IN ('failed', 'aborted') ORDER BY updatedAt DESC - // LIMIT 1). ParentID is informational lineage and plays no part in - // resolution: when a session's history was forked by re-resuming an - // earlier snapshot, the most recently updated branch simply wins. + // The contract is a plain max-timestamp lookup, so stores can implement + // it as a single indexed query (e.g. WHERE sessionId = ? ORDER BY + // updatedAt DESC LIMIT 1). ParentID is informational lineage and plays + // no part in resolution: when a session's history was forked by + // re-resuming an earlier snapshot, the most recently updated branch + // simply wins. GetLatestSnapshot(ctx context.Context, sessionID string) (*SessionSnapshot[State], error) } @@ -108,7 +107,7 @@ type SnapshotWriter[State any] interface { // from the existing row on update. // - UpdatedAt: stamped to the wall clock on every commit. // - Status: if the snapshot returned by fn has Status="", it is - // defaulted to [SnapshotStatusSucceeded] (the common case for + // defaulted to [SnapshotStatusCompleted] (the common case for // synchronous turn-end and invocation-end writes). Callers // writing a pending row must set Status explicitly. // @@ -220,8 +219,9 @@ func cloneArtifacts(arts []*Artifact) []*Artifact { // registering them, when the agent has a [SessionStore] configured: // // - The agent's name under [api.ActionTypeAgentSnapshot] — getSnapshot, -// the remote counterpart to [SessionStore.GetSnapshot] for Dev UI and -// non-Go clients. Local Go callers use the store reference directly. +// the remote counterpart to [SnapshotReader.GetSnapshot] (by snapshot +// ID) and [SnapshotReader.GetLatestSnapshot] (by session ID) for Dev UI +// and non-Go clients. Local Go callers use the store reference directly. // // - The agent's name under [api.ActionTypeAgentAbort] — abortSnapshot, // created only when the store also implements [SnapshotAborter] (which @@ -246,26 +246,50 @@ func newSnapshotActions[State any]( } getSnapshotAction := core.NewAction(agentName, api.ActionTypeAgentSnapshot, nil, nil, func(ctx context.Context, req *GetSnapshotRequest) (*SessionSnapshot[State], error) { - if req == nil || req.SnapshotID == "" { - return nil, core.NewError(core.INVALID_ARGUMENT, "getSnapshot: snapshotId is required") - } - snap, err := store.GetSnapshot(ctx, req.SnapshotID) - if err != nil { - return nil, core.NewError(core.INTERNAL, "getSnapshot: %v", err) + if req == nil || (req.SnapshotID == "" && req.SessionID == "") { + return nil, core.NewError(core.INVALID_ARGUMENT, "getSnapshot: snapshotId or sessionId is required") } - if snap == nil { - return nil, core.NewError(core.NOT_FOUND, "getSnapshot: snapshot %q not found", req.SnapshotID) + + // Resolve the snapshot. A snapshot ID fetches that exact row; a + // session ID alone fetches the session's latest row (whatever + // its status). When both are present the snapshot ID picks the + // row and the session ID asserts it belongs to that session, + // mirroring AgentInit's combined-ID check. + var ( + snap *SessionSnapshot[State] + err error + ) + if req.SnapshotID != "" { + snap, err = store.GetSnapshot(ctx, req.SnapshotID) + if err != nil { + return nil, core.NewError(core.INTERNAL, "getSnapshot: %v", err) + } + if snap == nil { + return nil, core.NewError(core.NOT_FOUND, "getSnapshot: snapshot %q not found", req.SnapshotID) + } + if req.SessionID != "" && snap.SessionID != req.SessionID { + return nil, core.NewError(core.INVALID_ARGUMENT, + "getSnapshot: snapshot %q does not belong to session %q (snapshot's session: %q)", req.SnapshotID, req.SessionID, snap.SessionID) + } + } else { + snap, err = store.GetLatestSnapshot(ctx, req.SessionID) + if err != nil { + return nil, core.NewError(core.INTERNAL, "getSnapshot: %v", err) + } + if snap == nil { + return nil, core.NewError(core.NOT_FOUND, "getSnapshot: no snapshot found for session %q", req.SessionID) + } } // Return a normalized copy: the documented defaults (empty - // status means succeeded, zero UpdatedAt means CreatedAt) are + // status means completed, zero UpdatedAt means CreatedAt) are // resolved server-side so remote clients don't reimplement // them, and the state transform shapes what leaves the server. // A failed snapshot's state is its last-good state, so it is // returned like any other. resp := *snap if resp.Status == "" { - resp.Status = SnapshotStatusSucceeded + resp.Status = SnapshotStatusCompleted } if resp.UpdatedAt.IsZero() { resp.UpdatedAt = resp.CreatedAt diff --git a/go/ai/exp/teststore_test.go b/go/ai/exp/teststore_test.go index 079dd0f2ed..969a169953 100644 --- a/go/ai/exp/teststore_test.go +++ b/go/ai/exp/teststore_test.go @@ -76,8 +76,7 @@ func (s *testInMemStore[State]) GetLatestSnapshot(_ context.Context, sessionID s defer s.mu.RUnlock() var latest *SessionSnapshot[State] for _, snap := range s.snapshots { - if snap.SessionID != sessionID || - snap.Status == SnapshotStatusFailed || snap.Status == SnapshotStatusAborted { + if snap.SessionID != sessionID { continue } if latest == nil || snap.UpdatedAt.After(latest.UpdatedAt) || @@ -147,7 +146,7 @@ func (s *testInMemStore[State]) SaveSnapshot( } next.UpdatedAt = now if next.Status == "" { - next.Status = SnapshotStatusSucceeded + next.Status = SnapshotStatusCompleted } copied, err := testCopySnapshot(next) diff --git a/go/core/schemas.config b/go/core/schemas.config index 3bde599227..c2e72c6b7c 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1276,18 +1276,20 @@ Sending no fields starts a fresh invocation with empty state. . AgentInit.sessionId doc -SessionID identifies the session (conversation) to resume. Only valid -when the agent is server-managed (a session store is configured); -mutually exclusive with State (a client-managed conversation carries -its identity inside [SessionState.SessionID]). Alone, it resumes the -session from its latest snapshot: the most recently updated one that -is not a failed/aborted dead end. A pending latest snapshot (a -detached invocation still running) rejects the resume rather than -racing the background work; if the session's history was forked by -resuming an earlier snapshot again, the most recently updated branch -wins, and SnapshotID can pick a branch explicitly. Combined with -SnapshotID, it asserts which session the snapshot belongs to, and a -mismatch is rejected. +SessionID identifies the session (conversation) to resume or start. +Only valid when the agent is server-managed (a session store is +configured); mutually exclusive with State (a client-managed +conversation carries its identity inside [SessionState.SessionID]). +Alone, it resumes the session from its latest snapshot: the most +recently updated row, whatever its status. If that row is a failed, +aborted, or still-pending dead end the resume is rejected (pass +SnapshotID to continue from a specific earlier point); if the session's +history was forked by resuming an earlier snapshot again, the most +recently updated branch wins. If the session has no snapshots yet, a +brand-new conversation is started under this caller-chosen ID, and +every snapshot it persists carries it. Combined with SnapshotID, it +asserts which session the snapshot belongs to, and a mismatch is +rejected. . AgentInit.snapshotId doc @@ -1348,7 +1350,8 @@ It wraps AgentResult with framework-managed fields. AgentOutput.sessionId doc SessionID is the ID of the session this invocation belongs to, assigned by the framework when the invocation starts. With -server-managed state, a fresh invocation mints a new ID, resumed +server-managed state, a fresh invocation adopts the caller-supplied +session ID (see [AgentInit.SessionID]) or mints a new one, resumed invocations inherit the chain's, and resuming a snapshot from before session IDs existed mints a fresh one. With client-managed state it echoes the ID carried inside the state object @@ -1661,7 +1664,7 @@ Event is what triggered this snapshot. SessionSnapshot.status doc Status is the lifecycle state of this snapshot. Empty is treated as -[SnapshotStatusSucceeded] for backwards compatibility. +[SnapshotStatusCompleted] for backwards compatibility. . SessionSnapshot.finishReason doc @@ -1716,7 +1719,7 @@ SnapshotEventRecovery doc SnapshotEventRecovery indicates the snapshot was written retroactively by the failure path to preserve the last-good state (everything through the last successful turn) when a selective snapshot callback had skipped -persisting it. It is a normal [SnapshotStatusSucceeded] row carrying the +persisting it. It is a normal [SnapshotStatusCompleted] row carrying the last good turn's finish reason, resumable like any other; the snapshot callback is bypassed and never sees this event. . @@ -1731,13 +1734,13 @@ SnapshotStatus pkg ai/exp SnapshotStatus doc SnapshotStatus describes the lifecycle state of a snapshot. Snapshots written for synchronous turns or invocations are always -[SnapshotStatusSucceeded] (an empty value is also treated as succeeded +[SnapshotStatusCompleted] (an empty value is also treated as completed for backwards compatibility). When a client sets [AgentInput.Detach], the server writes a single snapshot with [SnapshotStatusPending] (and empty state) and returns its ID immediately. Background processing then either rewrites that snapshot -with the cumulative final state and [SnapshotStatusSucceeded] / +with the cumulative final state and [SnapshotStatusCompleted] / [SnapshotStatusFailed] when the agent finishes, or with [SnapshotStatusAborted] if the client called abortSnapshot in the meantime. @@ -1749,8 +1752,8 @@ processing the queued inputs. The snapshot will be rewritten with a terminal status once the background work finishes. . -SnapshotStatusSucceeded doc -SnapshotStatusSucceeded indicates the snapshot captures a settled state. +SnapshotStatusCompleted doc +SnapshotStatusCompleted indicates the snapshot captures a settled state. . SnapshotStatusAborted doc @@ -1838,11 +1841,26 @@ when the agent has a session store configured. The action is intended for Dev UI and client-side reconnect flows. It returns the stored [SessionSnapshot], with [WithStateTransform] applied to its state if configured. + +At least one of SnapshotID or SessionID must be set; they are not +mutually exclusive. SnapshotID fetches a specific snapshot; SessionID +alone fetches the session's latest snapshot (via the store's +[SnapshotReader.GetLatestSnapshot], whatever its status). When both are +set, the fetched snapshot must belong to that session, or the request +is rejected. . -GetSnapshotRequest.snapshotId noomitempty GetSnapshotRequest.snapshotId doc -SnapshotID identifies the snapshot to fetch. +SnapshotID identifies the snapshot to fetch. Optional when SessionID is +given; when both are present the fetched snapshot must belong to that +session. +. + +GetSnapshotRequest.sessionId doc +SessionID identifies the session whose latest snapshot to fetch. +Optional when SnapshotID is given. The latest snapshot is the session's +most recently updated row regardless of status (pending, failed, or +aborted included). . # AbortSnapshotRequest @@ -1899,6 +1917,15 @@ Abortable reports whether the agent's invocations can be aborted (true when the store implements [SnapshotAborter]). . +AgentMetadata.stateSchema type map[string]any +AgentMetadata.stateSchema doc +StateSchema is the JSON schema for the agent's custom session state +(the Custom field of [SessionState]), inferred from the agent's state +type. It lets the Dev UI and other reflective callers render or +validate state without the agent describing it separately. Nil when the +state type carries no schema to infer (e.g. an unstructured any state). +. + AgentStateManagement pkg ai/exp AgentStateManagement doc diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index fb6d325c38..792abe1a5a 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -1061,10 +1061,10 @@ func TestHandlerAgentRef(t *testing.T) { if snap.SnapshotID != res.SnapshotID || snap.SessionID != res.SessionID { t.Errorf("snapshot identity = %q/%q, want %q/%q", snap.SnapshotID, snap.SessionID, res.SnapshotID, res.SessionID) } - // The action normalizes the implicit empty status to "succeeded" + // The action normalizes the implicit empty status to "completed" // so remote clients don't reimplement the default. - if snap.Status != "succeeded" { - t.Errorf("status = %q, want %q", snap.Status, "succeeded") + if snap.Status != "completed" { + t.Errorf("status = %q, want %q", snap.Status, "completed") } if len(snap.State) == 0 { t.Error("snapshot must carry state") @@ -1094,8 +1094,8 @@ func TestHandlerAgentRef(t *testing.T) { t.Fatalf("abortSnapshot status = %d, body = %s", code, body) } snap := parseSnapshot(t, body) - if snap.Status != "succeeded" { - t.Errorf("status = %q, want %q (abort of a terminal snapshot is a no-op)", snap.Status, "succeeded") + if snap.Status != "completed" { + t.Errorf("status = %q, want %q (abort of a terminal snapshot is a no-op)", snap.Status, "completed") } }) @@ -1123,8 +1123,8 @@ func TestHandlerAgentRef(t *testing.T) { } snap := parseSnapshot(t, body) if snap.Status != "pending" { - if snap.Status != "succeeded" { - t.Fatalf("final status = %q, want %q; body = %s", snap.Status, "succeeded", body) + if snap.Status != "completed" { + t.Fatalf("final status = %q, want %q; body = %s", snap.Status, "completed", body) } if len(snap.State) == 0 { t.Error("finalized snapshot must carry the cumulative state") diff --git a/go/samples/basic-agents/cli.go b/go/samples/basic-agents/cli.go index 6f1648b6cc..2fe76ab9db 100644 --- a/go/samples/basic-agents/cli.go +++ b/go/samples/basic-agents/cli.go @@ -181,8 +181,8 @@ func openAgent(ctx context.Context, inputCh <-chan string, a *aix.Agent[any], la // 2. ignore it and start a fresh conversation, // 3. go back to the agent list. // -// Returns the snapshot to resume from (option 1, succeeded) or nil -// (option 2, or option 1 when the snapshot terminated non-succeeded). +// Returns the snapshot to resume from (option 1, completed) or nil +// (option 2, or option 1 when the snapshot terminated non-completed). // ok=false means the user chose 3 or the context was canceled. // // Crucially, options that imply "use this conversation" return the @@ -214,7 +214,7 @@ func handlePending(ctx context.Context, inputCh <-chan string, a *aix.Agent[any] return nil, true } fmt.Printf("Done (%s).\n", final.Status) - if final.Status != aix.SnapshotStatusSucceeded { + if final.Status != aix.SnapshotStatusCompleted { // failed / aborted snapshots aren't resumable; the // agent runtime would reject WithSnapshotID on them. // Fall through to a fresh chat instead. @@ -240,7 +240,7 @@ func handlePending(ctx context.Context, inputCh <-chan string, a *aix.Agent[any] // recent terminal snapshot (returns the snapshot pointer), or start // fresh (returns nil). func pickSession(ctx context.Context, inputCh <-chan string, a *aix.Agent[any], latest *aix.SessionSnapshot[any]) (*aix.SessionSnapshot[any], bool) { - if latest == nil || latest.Status != aix.SnapshotStatusSucceeded { + if latest == nil || latest.Status != aix.SnapshotStatusCompleted { fmt.Printf("\nStarting a new conversation with %s.\n", a.Name()) return nil, true } From 282826f7752ed7ee9261c7205e891bb9823bb362 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 16 Jun 2026 16:38:12 -0700 Subject: [PATCH 108/141] fix(py): escape reserved words in the schema typing generator schema_to_typing.py emitted JSON property names verbatim as Python attributes. JsonPatchOperation.from is a Python keyword, so the generator produced invalid Python and generate_schema_typing failed, leaving _typing.py stale (missing the JSON Patch types and AgentStreamChunk's customPatch). When a field's snake_case name is a Python keyword, suffix the Python attribute (from -> from_) and pin an explicit Pydantic alias to the original wire key, so the JSON shape is unchanged. The Field(...) form is forced for these so the alias survives even on plain scalar optionals, where a bare None default would otherwise drop it. Regenerate _typing.py: it now also picks up the current schema (the agent session-flow changes, SnapshotStatus completed, and the JSON Patch types). --- .../genkit/src/genkit/_core/_typing.py | 41 +++++++++++++++++-- py/tools/schema_to_typing/schema_to_typing.py | 15 ++++++- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index f91aab8b75..963c4352fd 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -55,6 +55,17 @@ class AgentStateManagement(StrEnum): CLIENT = 'client' +class JsonPatchOp(StrEnum): + """JsonPatchOp data type class.""" + + ADD = 'add' + REMOVE = 'remove' + REPLACE = 'replace' + MOVE = 'move' + COPY = 'copy' + TEST = 'test' + + class SnapshotEvent(StrEnum): """SnapshotEvent data type class.""" @@ -68,7 +79,7 @@ class SnapshotStatus(StrEnum): """SnapshotStatus data type class.""" PENDING = 'pending' - SUCCEEDED = 'succeeded' + COMPLETED = 'completed' ABORTED = 'aborted' FAILED = 'failed' @@ -157,6 +168,7 @@ class AgentMetadata(GenkitModel): model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) state_management: AgentStateManagement = Field(...) abortable: bool = Field(...) + state_schema: StateSchema | None = None class AgentOutput(GenkitModel): @@ -186,7 +198,7 @@ class AgentStreamChunk(GenkitModel): model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) model_chunk: ModelResponseChunk | None = None - status: Any | None = Field(default=None) + custom_patch: JsonPatch | None = None artifact: Artifact | None = None turn_end: TurnEnd | None = None @@ -204,7 +216,18 @@ class GetSnapshotRequest(GenkitModel): """Model for getsnapshotrequest data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - snapshot_id: str = Field(...) + snapshot_id: str | None = None + session_id: str | None = None + + +class JsonPatchOperation(GenkitModel): + """Model for jsonpatchoperation data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + op: JsonPatchOp = Field(...) + path: str = Field(...) + from_: str | None = Field(default=None, alias='from') + value: Any | None = Field(default=None) class SessionSnapshot(GenkitModel): @@ -948,6 +971,12 @@ class Resume(GenkitModel): metadata: Metadata | None = None +class StateSchema(GenkitModel): + """Model for stateschema data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + + class Details(GenkitModel): """Model for details data.""" @@ -1072,6 +1101,12 @@ class Part( TraceEvent = SpanStartEvent | SpanEndEvent +class JsonPatch(RootModel[list[JsonPatchOperation]]): + """Root model for jsonpatch.""" + + root: list[JsonPatchOperation] + + class EvalResponse(RootModel[list[EvalFnResponse]]): """Root model for evalresponse.""" diff --git a/py/tools/schema_to_typing/schema_to_typing.py b/py/tools/schema_to_typing/schema_to_typing.py index a71dd545d8..1bc0ed05de 100644 --- a/py/tools/schema_to_typing/schema_to_typing.py +++ b/py/tools/schema_to_typing/schema_to_typing.py @@ -7,6 +7,7 @@ from __future__ import annotations import json +import keyword import re import sys from datetime import datetime @@ -285,12 +286,24 @@ def _emit_model( for k, v in props.items(): # Use schema_ for OutputConfig.schema to avoid shadowing GenkitModel.schema snake = _camel_to_snake(k) + force_field = False if name == 'OutputConfig' and snake == 'schema': field_name = 'schema_' alias_extra = ", alias='schema'" elif snake in ('schema_', 'schema'): field_name = 'schema' if name != 'OutputConfig' else 'schema_' alias_extra = ", alias='schema'" if name == 'OutputConfig' else '' + elif keyword.iskeyword(snake): + # Field name is a Python reserved word (e.g. JSON Patch's `from`), + # which is a keyword only in Python. Suffix the Python attribute + # and pin the wire alias to the original key so the JSON shape is + # unchanged. to_camel cannot recover the original name from the + # suffixed one, so the alias must be explicit and emitted even for + # plain scalar fields (where the default would otherwise be a bare + # None that drops the alias). + field_name = snake + '_' + alias_extra = f', alias={k!r}' + force_field = True else: field_name = snake alias_extra = '' @@ -310,7 +323,7 @@ def _emit_model( else: default_val = ( f'Field(default=None{desc_extra}{alias_extra})' - if '|' in py_type_str or py_type_str == 'Any' + if '|' in py_type_str or py_type_str == 'Any' or force_field else 'None' ) lines.append(f' {field_name}: {py_type_str} | None = {default_val}') From a072bdc1c412d0284aae626e3ecfb2b7ab850ee2 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 16 Jun 2026 17:36:34 -0700 Subject: [PATCH 109/141] refactor(go/exp): snapshot at turn end only Collapse the agent snapshot machinery to a single persistence point: the end of each successful turn. A failed turn never snapshots, so the newest snapshot is always the last good turn, and the runtime no longer needs separate invocation-end or recovery writes to guarantee a resume point. Remove the supporting complexity this enables: - SnapshotCallback, SnapshotContext, WithSnapshotCallback and WithSnapshotOn (the selective cadence). - The SnapshotEvent type and the event field on SessionSnapshot, regenerated from agent.ts through genkit-schema.json and gen.go. Detach is distinguished by its pending/terminal status, not an event. - The invocation-end snapshot: the synchronous-completion output now reports the last turn-end snapshot. - The recovery snapshot: the failure path returns the last turn-end snapshot's ID (server-managed) or the last-good state inline (client-managed). - The session version counter and the turn/invocation dedup it fed. Fold maybeSnapshot and persistSnapshotLocked into snapshotTurnEnd, replace lastSnapshot/lastSnapshotVersion with a single lastSnapshotID, and reduce lastGood{State,Version,FinishReason} to a lastGoodState kept only for the client-managed failure path. Behavior change for custom agents (prompt-backed agents are unaffected): state mutated after the turn loop, and an AgentResult.FinishReason override, ride on the returned output but are no longer persisted to a snapshot. --- genkit-tools/common/src/types/agent.ts | 44 +--- genkit-tools/genkit-schema.json | 15 +- go/ai/exp/agent.go | 280 ++++++++------------- go/ai/exp/agent_test.go | 323 ++++++------------------- go/ai/exp/gen.go | 43 +--- go/ai/exp/localstore/file_test.go | 1 - go/ai/exp/localstore/store_test.go | 2 - go/ai/exp/option.go | 27 --- go/ai/exp/session.go | 33 +-- go/core/schemas.config | 58 +---- go/genkit/exp/routes_test.go | 1 - go/genkit/genkit.go | 6 +- go/genkit/servers_test.go | 2 - go/samples/basic-agents/main.go | 3 - 14 files changed, 205 insertions(+), 633 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index a7107e672b..46889518f7 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -36,30 +36,6 @@ export const ArtifactSchema = z.object({ }); export type Artifact = z.infer; -/** - * Zod schema for snapshot event. - * - * - `turnEnd`: snapshot was triggered at the end of a turn. - * - `invocationEnd`: snapshot was triggered at the end of the invocation. - * - `detach`: snapshot was created when the client detached the invocation - * and the flow continues in the background. Initially written with - * `pending` status (and empty state) and rewritten with a terminal - * status and the final cumulative state once the background work - * finishes. - * - `recovery`: snapshot was written retroactively by the failure path to - * preserve the last-good state (everything through the last successful - * turn) when a selective snapshot callback had skipped persisting it. - * It is a normal `completed` row carrying the last good turn's - * `finishReason`, resumable like any other; the callback is bypassed. - */ -export const SnapshotEventSchema = z.enum([ - 'turnEnd', - 'invocationEnd', - 'detach', - 'recovery', -]); -export type SnapshotEvent = z.infer; - /** * Zod schema for a snapshot's lifecycle status. * @@ -236,13 +212,12 @@ export const AgentOutputSchema = z.object({ */ sessionId: z.string().optional(), /** - * ID of the newest snapshot capturing this invocation: the - * invocation-end snapshot, or the latest earlier snapshot when that - * write was skipped. Empty when no store is configured or the - * invocation persisted nothing. When `finishReason` is `detached` it - * is the pending detach snapshot; when `failed`, the most recent - * snapshot capturing the last-good state: everything through the last - * successful turn (see the `recovery` snapshot event). + * ID of the most recent turn-end snapshot for this invocation. Empty + * when no store is configured or no turn committed. When `finishReason` + * is `detached` it is the pending detach snapshot; when `failed`, it is + * the last committed turn's snapshot (the resume point, holding state + * through the last successful turn and excluding the failed turn's + * partial mutations). */ snapshotId: z.string().optional(), /** @@ -284,9 +259,8 @@ export type AgentOutput = z.infer; export const TurnEndSchema = z.object({ /** * ID of the snapshot persisted at the end of this turn. Empty if no - * snapshot was written (no store configured, the callback declined, - * nothing changed since the last snapshot, or snapshots were suspended - * after detach). + * snapshot was written (no store configured, the turn failed, or + * snapshots were suspended after detach). */ snapshotId: z.string().optional(), /** @@ -391,8 +365,6 @@ export const SessionSnapshotSchema = z.object({ createdAt: z.string(), /** When the snapshot was last written (RFC 3339). Equals `createdAt` until rewritten. */ updatedAt: z.string().optional(), - /** What triggered this snapshot. */ - event: SnapshotEventSchema, /** Lifecycle state of this snapshot. Empty is treated as `completed`. */ status: SnapshotStatusSchema.optional(), /** diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index f13ebcf0b9..114bb91950 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -269,9 +269,6 @@ "updatedAt": { "type": "string" }, - "event": { - "$ref": "#/$defs/SnapshotEvent" - }, "status": { "$ref": "#/$defs/SnapshotStatus" }, @@ -287,8 +284,7 @@ }, "required": [ "snapshotId", - "createdAt", - "event" + "createdAt" ], "additionalProperties": false }, @@ -314,15 +310,6 @@ }, "additionalProperties": false }, - "SnapshotEvent": { - "type": "string", - "enum": [ - "turnEnd", - "invocationEnd", - "detach", - "recovery" - ] - }, "SnapshotStatus": { "type": "string", "enum": [ diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index b8f59bfd03..60ccd45a0d 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -58,20 +58,24 @@ type SessionRunner[State any] struct { // directly. TurnIndex int - snapshotCallback SnapshotCallback[State] onStartTurn func() onEndTurn func(ctx context.Context) collectTurnOutput func() any - // snapMu serializes snapshot persistence with the detach handler's - // suspend-and-capture. lastSnapshot and lastSnapshotVersion are - // written under it; the terminal paths that read them without it - // (handleFnDone, failedOutput) run after fn completes, with a - // happens-before edge through the fnDone channel. - snapMu sync.Mutex - snapshotsSuspended bool - lastSnapshot *SessionSnapshot[State] - lastSnapshotVersion uint64 + // snapMu serializes the turn-end snapshot write (snapshotTurnEnd) + // against the detach handler's suspend-and-capture (suspendSnapshots). + // snapshotsSuspended and lastSnapshotID are written under it; the + // terminal paths that read lastSnapshotID without it (handleFnDone, + // failedOutput) run after fn completes, with a happens-before edge + // through the fnDone channel. + snapMu sync.Mutex + snapshotsSuspended bool + // lastSnapshotID is the ID of the most recent turn-end snapshot, or of + // the snapshot the invocation resumed from until the first turn commits, + // or "" when no store is configured or nothing has been written yet. It + // is the parent of the next snapshot and the resume point the failed and + // detached outputs report. + lastSnapshotID string // lastTurnFinishReason is the finish reason reported by the most recent // turn (via the [TurnResult] its callback returned), or "" if the turn @@ -81,9 +85,9 @@ type SessionRunner[State any] struct { // to the fn goroutine (Run and its synchronous onEndTurn callback) until // fn completes, after which the terminal paths read it with a // happens-before edge through the fnDone channel, so no lock is needed. - // The same confinement applies to lastTurnFailed and the lastGood* - // fields below; the terminal paths that read them (handleFnDone and - // the detach-failure paths) all wait on fnDone first. + // The same confinement applies to lastTurnFailed and lastGoodState; the + // terminal paths that read them (handleFnDone and the detach-failure + // paths) all wait on fnDone first. lastTurnFinishReason AgentFinishReason // lastTurnFailed reports whether the most recent turn ended in error. @@ -92,39 +96,27 @@ type SessionRunner[State any] struct { // lastGoodState is a deep copy of the session state as of the most // recent successful turn (or the initial state when no turn has - // completed yet), captured regardless of whether the snapshot callback - // persisted that turn. lastGoodVersion is the session version at that - // capture and lastGoodFinishReason that turn's reported reason. The - // failure path returns (client-managed) or persists (server-managed - // recovery snapshot) this state. - lastGoodState *SessionState[State] - lastGoodVersion uint64 - lastGoodFinishReason AgentFinishReason -} - -// parentSnapshotID returns the ID of the most recent snapshot in this -// invocation (used to chain new snapshots via ParentID), or "" if no -// snapshot has been written yet. -func (s *SessionRunner[State]) parentSnapshotID() string { - if s.lastSnapshot == nil { - return "" - } - return s.lastSnapshot.SnapshotID -} - -// suspendSnapshots stops all further snapshot persistence for this -// invocation and returns the ID of the newest persisted snapshot. Taking -// snapMu makes the two steps atomic with respect to an in-flight turn-end -// write: a write already inside maybeSnapshot completes first (so the -// returned parent is current, not stale), and any later turn end observes -// the suspension and skips its write. Called by the detach handler, after -// which the queued inputs roll into a single finalize rewrite of the -// pending row. + // completed yet), kept only for client-managed agents (no store). The + // client-managed failure path returns it inline so the caller resumes + // from the last committed turn, excluding the failed turn's partial + // mutations. Nil and unused for server-managed agents, whose failure + // path returns the last turn-end snapshot instead. + lastGoodState *SessionState[State] +} + +// suspendSnapshots stops all further turn-end snapshot writes for this +// invocation and returns the ID of the newest persisted snapshot (the +// parent for the detach handler's pending row). Taking snapMu makes the +// two steps atomic with respect to an in-flight turn-end write: a write +// already inside snapshotTurnEnd completes first (so the returned parent +// is current, not stale), and any later turn end observes the suspension +// and skips its write. Called by the detach handler, after which the +// queued inputs roll into a single finalize rewrite of the pending row. func (s *SessionRunner[State]) suspendSnapshots() (parentID string) { s.snapMu.Lock() defer s.snapMu.Unlock() s.snapshotsSuspended = true - return s.parentSnapshotID() + return s.lastSnapshotID } // TurnResult is the optional return value of a [SessionRunner.Run] per-turn @@ -215,29 +207,26 @@ func (s *SessionRunner[State]) endTurn(ctx context.Context, reason AgentFinishRe s.lastTurnFailed = failed s.onEndTurn(ctx) if !failed { - s.recordLastGood() + s.captureLastGood() } s.TurnIndex++ } -// recordLastGood captures the current session state as the last-good -// recovery point. Called once at session start and after every successful -// turn, whether or not the snapshot callback persisted that turn. Runs -// after the turn-end snapshot check so that when the newest snapshot -// already captures this exact version, the deep copy is skipped; -// recoverySnapshotID then resolves to that snapshot's ID without reading -// lastGoodState. -func (s *SessionRunner[State]) recordLastGood() { - s.mu.RLock() - version := s.version - persisted := s.lastSnapshot != nil && version == s.lastSnapshotVersion - if !persisted { - state := s.copyStateLocked() - s.lastGoodState = &state +// captureLastGood deep-copies the committed session state as the +// client-managed failure fallback: the state a failed invocation returns +// inline (see failedOutput), excluding a later failed turn's partial +// mutations. Called once at session start (the initial state is the +// fallback until a turn completes) and after every successful turn. It is +// a no-op for server-managed agents, whose failure path returns the last +// turn-end snapshot instead, so they pay no per-turn copy. +func (s *SessionRunner[State]) captureLastGood() { + if s.store != nil { + return } + s.mu.RLock() + state := s.copyStateLocked() s.mu.RUnlock() - s.lastGoodVersion = version - s.lastGoodFinishReason = s.lastTurnFinishReason + s.lastGoodState = &state } // Result returns an [AgentResult] populated from the current session state: @@ -272,16 +261,23 @@ func (s *SessionRunner[State]) invocationReason(result *AgentResult) AgentFinish return s.lastTurnFinishReason } -// maybeSnapshot creates a snapshot if conditions are met (store configured, -// snapshots not suspended by detach, callback approves, state changed). -// Returns the snapshot ID or empty string. finishReason is recorded on the -// snapshot so a resumed or background task can report how the captured turn -// or invocation ended. +// snapshotTurnEnd persists a turn-end snapshot capturing the committed +// session state, chained off the previous snapshot via ParentID, and +// returns its ID. It is a no-op returning "" when no store is configured +// or snapshots have been suspended by a detach. finishReason records how +// the captured turn ended so a resumed task can report it. +// +// The turn-end snapshot is the agent's only routine persistence point: a +// failed turn never writes one (its partial state is not a resume point), +// so the newest snapshot is always the last successful turn, which is what +// the failed and detached outputs resume from. // // The body runs under snapMu so the detach handler's suspend-and-capture // (suspendSnapshots) cannot interleave with a write: it either waits for -// this write to commit or suspends before it starts. -func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent, finishReason AgentFinishReason) string { +// this write to commit or suspends before it starts. Persistence is +// best-effort: a store failure must not kill the in-flight turn, so it is +// logged and "" is returned. +func (s *SessionRunner[State]) snapshotTurnEnd(ctx context.Context, finishReason AgentFinishReason) string { if s.store == nil { return "" } @@ -293,107 +289,32 @@ func (s *SessionRunner[State]) maybeSnapshot(ctx context.Context, event Snapshot } s.mu.RLock() - currentVersion := s.version - currentState := s.copyStateLocked() + state := s.copyStateLocked() s.mu.RUnlock() - // Skip only if this snapshot would be identical to the last one: same - // state AND same finish reason. This dedups the common invocation-end - // snapshot after a single-turn Run (the turn-end snapshot already - // captured the same state and reason), but still writes when the - // invocation reports a different reason than the last turn (e.g. a - // custom agent overrode it on its AgentResult) — that snapshot is not - // redundant, it carries a new reason. - if s.lastSnapshot != nil && - currentVersion == s.lastSnapshotVersion && - finishReason == s.lastSnapshot.FinishReason { - return "" - } - - if s.snapshotCallback != nil { - var prevState *SessionState[State] - if s.lastSnapshot != nil { - prevState = s.lastSnapshot.State - } - if !s.snapshotCallback(ctx, &SnapshotContext[State]{ - State: ¤tState, - PrevState: prevState, - TurnIndex: s.TurnIndex, - Event: event, - }) { - return "" - } - } - - return s.persistSnapshotLocked(ctx, event, finishReason, ¤tState, currentVersion) -} - -// persistSnapshotLocked writes a completed snapshot row capturing state (at -// the given session version), chained to the newest persisted snapshot, and -// advances the lastSnapshot bookkeeping. Both the routine cadence -// (maybeSnapshot) and the failure path (recoverySnapshotID) funnel through -// here so the row shape and bookkeeping live in one place. Caller must hold -// snapMu. Persistence is best-effort: a store failure must not kill the -// in-flight turn, so it is logged and "" is returned. -func (s *SessionRunner[State]) persistSnapshotLocked(ctx context.Context, event SnapshotEvent, finishReason AgentFinishReason, state *SessionState[State], version uint64) string { - parentID := s.parentSnapshotID() + parentID := s.lastSnapshotID sessionID := s.SessionID() - saved, err := s.store.SaveSnapshot(ctx, "", func(_ *SessionSnapshot[State]) (*SessionSnapshot[State], error) { return &SessionSnapshot[State]{ SessionID: sessionID, ParentID: parentID, - Event: event, Status: SnapshotStatusCompleted, FinishReason: finishReason, - State: state, + State: &state, }, nil }) if err != nil { logger.FromContext(ctx).Error("agent: failed to save snapshot", "parentId", parentID, - "event", event, "err", err) return "" } - s.lastSnapshot = saved - s.lastSnapshotVersion = version + s.lastSnapshotID = saved.SnapshotID return saved.SnapshotID } -// recoverySnapshotID returns the ID of a snapshot holding the last-good -// state, writing one (event [SnapshotEventRecovery]) when the newest -// persisted snapshot is behind it. The write uses the captured -// lastGoodState, never the live state (which may hold the failed turn's -// partial mutations), and intentionally bypasses both the snapshot -// callback and the post-detach suspension, so neither a selective cadence -// nor a dying detach can lose the conversation. If the write fails, the -// newest persisted snapshot's ID is returned instead. -// -// Returns "" when no store is configured or there is nothing to recover -// (no snapshot exists and no turn ever changed state). -func (s *SessionRunner[State]) recoverySnapshotID(ctx context.Context) string { - if s.store == nil { - return "" - } - s.snapMu.Lock() - defer s.snapMu.Unlock() - // The newest snapshot already captures exactly the last-good state. - if s.lastSnapshot != nil && s.lastGoodVersion == s.lastSnapshotVersion { - return s.lastSnapshot.SnapshotID - } - if s.lastSnapshot == nil && s.lastGoodVersion == 0 { - return "" - } - - if id := s.persistSnapshotLocked(ctx, SnapshotEventRecovery, s.lastGoodFinishReason, s.lastGoodState, s.lastGoodVersion); id != "" { - return id - } - return s.parentSnapshotID() -} - // --- Responder --- // Responder is the output channel for an agent. Artifacts sent through @@ -616,7 +537,7 @@ func (a *Agent[State]) StreamBidiJSON(ctx context.Context, opts *api.BidiSession // loaded from a .prompt file). // // State is inferred from the typed agent options (e.g. -// [WithSessionStore], [WithSnapshotOn]); pass an explicit [State] only +// [WithSessionStore], [WithStateTransform]); pass an explicit [State] only // when no typed option is provided. A typed option that disagrees with // the inferred State fails at compile time. // @@ -856,10 +777,13 @@ func newAgentRuntime[State any]( } rt.sess = &SessionRunner[State]{ - Session: session, - InputCh: rt.intake.out(), - snapshotCallback: cfg.callback, - lastSnapshot: parent, + Session: session, + InputCh: rt.intake.out(), + } + if parent != nil { + // Resumed: chain the first turn's snapshot off the one we loaded, and + // make it the resume point a first-turn failure falls back to. + rt.sess.lastSnapshotID = parent.SnapshotID } rt.sess.collectTurnOutput = func() any { return rt.router.collectTurnChunks() } rt.sess.onEndTurn = rt.emitTurnEnd @@ -872,9 +796,9 @@ func newAgentRuntime[State any]( firstInTurn: true, } rt.sess.onStartTurn = rt.patcher.beginTurn - // The initial state (fresh, client-provided, or loaded from a - // snapshot) is the last-good recovery point until a turn completes. - rt.sess.recordLastGood() + // The initial state (fresh, client-provided, or loaded from a snapshot) + // is the client-managed failure fallback until a turn completes. + rt.sess.captureLastGood() return rt, nil } @@ -890,7 +814,7 @@ func newAgentRuntime[State any]( // runs, so there is no in-flight router work to wait out. // // The snapshot is skipped when the turn failed (the live state holds the -// turn's partial mutations) and when detach has landed (maybeSnapshot +// turn's partial mutations) and when detach has landed (snapshotTurnEnd // observes the suspension under snapMu; the pending row already captures // the invocation and a single finalize rewrite records the cumulative // state once the queued inputs drain). @@ -899,7 +823,7 @@ func (rt *agentRuntime[State]) emitTurnEnd(ctx context.Context) { reason := rt.sess.lastTurnFinishReason var snapshotID string if !rt.sess.lastTurnFailed { - snapshotID = rt.sess.maybeSnapshot(ctx, SnapshotEventTurnEnd, reason) + snapshotID = rt.sess.snapshotTurnEnd(ctx, reason) } rt.router.sendChunk(ctx, &AgentStreamChunk{TurnEnd: &TurnEnd{ SnapshotID: snapshotID, @@ -1027,10 +951,11 @@ func (rt *agentRuntime[State]) drainAndWait(cancelWork context.CancelFunc) fnDon } // handleFnDone is the synchronous-completion path: fn returned before any -// detach signal. Capture an invocation-end snapshot if state advanced past -// the last turn-end snapshot, then assemble the output. When fn returned -// an error, the invocation resolves gracefully as a failed output instead -// (see failedOutput). +// detach signal. The output reports the last turn-end snapshot as its +// resume point; there is no separate invocation-end write, so state a +// custom agent mutates after its turn loop rides on the returned output +// but is not persisted. When fn returned an error, the invocation resolves +// gracefully as a failed output instead (see failedOutput). // // router.close blocks on the forward goroutine exiting, and fn returning // does not imply the router is idle: fn's last accepted chunk may still be @@ -1066,22 +991,14 @@ func (rt *agentRuntime[State]) handleFnDone( return rt.failedOutput(ctx, res.err), nil } - invocationReason := rt.sess.invocationReason(res.result) - snapshotID := rt.sess.maybeSnapshot(ctx, SnapshotEventInvocationEnd, invocationReason) - if snapshotID == "" && rt.sess.lastSnapshot != nil { - // No new row was written; reuse the last snapshot so the response - // always carries an ID when a store is configured. On the dedup path - // the reused row is genuinely identical (same state and reason). If - // the snapshot callback declined the write or the save failed, the - // reused row is the last turn-end snapshot, whose reason (and state) - // may lag what this output reports. - snapshotID = rt.sess.lastSnapshot.SnapshotID - } - + // The resume point is the last turn-end snapshot (lastSnapshotID), or "" + // when no store is configured or no turn committed. A custom agent that + // overrode the invocation's finish reason on its AgentResult sees it on + // the output below, but the snapshot keeps the turn's own reason. out := &AgentOutput[State]{ SessionID: rt.session.SessionID(), - SnapshotID: snapshotID, - FinishReason: invocationReason, + SnapshotID: rt.sess.lastSnapshotID, + FinishReason: rt.sess.invocationReason(res.result), } if res.result != nil { // Deep-copy at the framework boundary so the caller cannot @@ -1111,20 +1028,25 @@ func (rt *agentRuntime[State]) outboundState(ctx context.Context, state *Session // failedOutput assembles the output for an invocation that ended in // failure: [AgentFinishReasonFailed], the error with its original status, -// and the last-good state (inline when client-managed, behind a recovery -// snapshot ID when server-managed). Message and Artifacts are left empty; -// they describe the result of a completed run. +// and the last-good resume point: the last turn-end snapshot's ID when +// server-managed, or the last-good state inline when client-managed. Both +// hold the state through the last successful turn, excluding the failed +// turn's partial mutations, because a failed turn never snapshots and never +// updates lastGoodState. When no turn committed, the server-managed ID is +// "" (or the resumed snapshot's ID) and the client-managed state is the +// initial state. Message and Artifacts are left empty; they describe the +// result of a completed run. func (rt *agentRuntime[State]) failedOutput(ctx context.Context, cause error) *AgentOutput[State] { out := &AgentOutput[State]{ + SessionID: rt.session.SessionID(), FinishReason: AgentFinishReasonFailed, Error: core.AsGenkitError(cause), } if rt.cfg.store == nil { out.State = rt.outboundState(ctx, rt.sess.lastGoodState) } else { - out.SnapshotID = rt.sess.recoverySnapshotID(ctx) + out.SnapshotID = rt.sess.lastSnapshotID } - out.SessionID = rt.session.SessionID() return out } @@ -1159,7 +1081,6 @@ func (rt *agentRuntime[State]) handleDetach( return &SessionSnapshot[State]{ SessionID: sessionID, ParentID: parentID, - Event: SnapshotEventDetach, Status: SnapshotStatusPending, }, nil }) @@ -1265,7 +1186,6 @@ func (rt *agentRuntime[State]) finalizePendingSnapshot( return &SessionSnapshot[State]{ SessionID: pending.SessionID, ParentID: pending.ParentID, - Event: SnapshotEventDetach, Status: status, FinishReason: finishReason, Error: snapErr, diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 393bd4b165..e0f7435e6f 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -561,8 +561,8 @@ func TestAgent_SendArtifact_SynchronousAndCloned(t *testing.T) { // TestAgent_TurnEndSnapshot_IncludesSameTurnArtifact verifies that a // turn-end snapshot captures artifacts sent during that turn: the Send // side effect applies before the call returns, so the snapshot taken at -// turn end cannot miss it. With snapshots restricted to turn end, the -// invocation output reuses the turn-end row, which therefore must hold +// turn end cannot miss it. Turn end is the agent's only snapshot point, so +// the invocation output reuses the turn-end row, which therefore must hold // the artifact for a later resume. func TestAgent_TurnEndSnapshot_IncludesSameTurnArtifact(t *testing.T) { ctx := context.Background() @@ -580,7 +580,6 @@ func TestAgent_TurnEndSnapshot_IncludesSameTurnArtifact(t *testing.T) { }) }, WithSessionStore[testState](store), - WithSnapshotOn[testState](SnapshotEventTurnEnd), ) out, err := af.RunText(ctx, "produce the report") @@ -597,9 +596,6 @@ func TestAgent_TurnEndSnapshot_IncludesSameTurnArtifact(t *testing.T) { if snap == nil { t.Fatalf("snapshot %q not found", out.SnapshotID) } - if snap.Event != SnapshotEventTurnEnd { - t.Errorf("snapshot event = %q, want %q", snap.Event, SnapshotEventTurnEnd) - } if snap.State == nil || len(snap.State.Artifacts) != 1 { t.Fatalf("turn-end snapshot missing the artifact sent during the turn: %+v", snap.State) } @@ -608,65 +604,6 @@ func TestAgent_TurnEndSnapshot_IncludesSameTurnArtifact(t *testing.T) { } } -func TestAgent_SnapshotCallback(t *testing.T) { - ctx := context.Background() - reg := newTestRegistry(t) - store := newTestInMemStore[testState]() - - // Only snapshot on even turns. - callbackCalls := 0 - af := DefineCustomAgent(reg, "callbackFlow", - func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - sess.AddMessages(ai.NewModelTextMessage("reply")) - sess.UpdateCustom(func(s testState) testState { - s.Counter++ - return s - }) - return nil, nil - }) - }, - WithSessionStore(store), - WithSnapshotCallback(func(ctx context.Context, sc *SnapshotContext[testState]) bool { - callbackCalls++ - return sc.TurnIndex%2 == 0 // only snapshot on even turns - }), - ) - - conn, err := af.StreamBidi(ctx) - if err != nil { - t.Fatalf("StreamBidi failed: %v", err) - } - - var snapshotIDs []string - for i := 0; i < 3; i++ { - conn.SendText(fmt.Sprintf("turn %d", i)) - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error on turn %d: %v", i, err) - } - if chunk.TurnEnd != nil { - if chunk.TurnEnd.SnapshotID != "" { - snapshotIDs = append(snapshotIDs, chunk.TurnEnd.SnapshotID) - } - break - } - } - } - conn.Close() - conn.Output() // drain - - // Turn 0 (even) → snapshot, Turn 1 (odd) → no, Turn 2 (even) → snapshot. - // That's 2 turn snapshots from the callback. - if got := len(snapshotIDs); got != 2 { - t.Errorf("expected 2 turn snapshots, got %d", got) - } - // Callback should have been called 3 times (once per turn). - if callbackCalls < 3 { - t.Errorf("expected at least 3 callback calls, got %d", callbackCalls) - } -} - func TestAgent_SendMessage(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) @@ -927,107 +864,10 @@ func TestAgent_FailedTurn_LastGoodStateIsResumable(t *testing.T) { } } -func TestAgent_FailedTurn_RecoverySnapshotBypassesCallback(t *testing.T) { - // Server-managed agent with a selective snapshot callback that only - // persists the first turn. The second (successful) turn is skipped by - // the callback, so when the third turn fails, the runtime must write a - // retroactive recovery snapshot of the last-good state — bypassing the - // callback — or the skipped turn would be lost. - ctx := context.Background() - reg := newTestRegistry(t) - store := newTestInMemStore[testState]() - - // The callback runs on the fn goroutine; the assertions below run - // after Output() returns, which happens-after fn completes, so no - // locking is needed. - var cbEvents []SnapshotEvent - af := defineLastGoodTestAgent(reg, "recoverySnapshot", - WithSessionStore[testState](store), - WithSnapshotCallback(func(_ context.Context, sc *SnapshotContext[testState]) bool { - cbEvents = append(cbEvents, sc.Event) - return sc.Event == SnapshotEventTurnEnd && sc.TurnIndex == 0 - }), - ) - - conn, err := af.StreamBidi(ctx) - if err != nil { - t.Fatalf("StreamBidi: %v", err) - } - - var turnEnds []*TurnEnd - for _, text := range []string{"one", "two"} { - if err := conn.SendText(text); err != nil { - t.Fatalf("SendText(%q): %v", text, err) - } - turnEnds = append(turnEnds, nextTurnEnd(t, conn)) - } - if turnEnds[0].SnapshotID == "" { - t.Fatal("expected turn 0 snapshot to be persisted") - } - if turnEnds[1].SnapshotID != "" { - t.Fatalf("expected turn 1 snapshot to be skipped by callback, got %q", turnEnds[1].SnapshotID) - } - - if err := conn.SendText("boom"); err != nil { - t.Fatalf("SendText(boom): %v", err) - } - out, err := conn.Output() - if err != nil { - t.Fatalf("Output: %v", err) - } - if out.FinishReason != AgentFinishReasonFailed { - t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) - } - if out.Error == nil || out.Error.Status != core.UNAVAILABLE { - t.Fatalf("expected error with status %q, got %+v", core.UNAVAILABLE, out.Error) - } - if out.State != nil { - t.Errorf("server-managed failed output must not carry inline state, got %+v", out.State) - } - if out.SnapshotID == "" { - t.Fatal("expected recovery snapshot ID on failed output") - } - if out.SnapshotID == turnEnds[0].SnapshotID { - t.Fatal("expected a fresh recovery snapshot, got the turn-0 snapshot") - } - - snap, err := store.GetSnapshot(ctx, out.SnapshotID) - if err != nil || snap == nil { - t.Fatalf("GetSnapshot(%q): %v, %v", out.SnapshotID, snap, err) - } - if snap.Status != SnapshotStatusCompleted { - t.Errorf("expected recovery snapshot status %q, got %q", SnapshotStatusCompleted, snap.Status) - } - if snap.Event != SnapshotEventRecovery { - t.Errorf("expected recovery snapshot event %q, got %q", SnapshotEventRecovery, snap.Event) - } - if snap.FinishReason != AgentFinishReasonStop { - t.Errorf("expected recovery snapshot to carry the last good turn's reason %q, got %q", - AgentFinishReasonStop, snap.FinishReason) - } - if snap.ParentID != turnEnds[0].SnapshotID { - t.Errorf("expected recovery snapshot parent %q, got %q", turnEnds[0].SnapshotID, snap.ParentID) - } - // State through the last successful turn, excluding the failed turn. - if got := len(snap.State.Messages); got != 4 { - t.Fatalf("expected 4 messages in recovery snapshot, got %d", got) - } - if got := snap.State.Custom.Counter; got != 2 { - t.Errorf("expected counter=2 in recovery snapshot, got %d", got) - } - - // The recovery write bypassed the callback: it was consulted once per - // successful turn only (the failed turn and the recovery write never - // ask). - if want := []SnapshotEvent{SnapshotEventTurnEnd, SnapshotEventTurnEnd}; !slices.Equal(cbEvents, want) { - t.Errorf("expected callback events %v, got %v", want, cbEvents) - } -} - -func TestAgent_FailedTurn_LastGoodAlreadyPersisted_NoRecoveryWrite(t *testing.T) { - // With the default always-snapshot cadence, the last-good state is - // already in the store when a turn fails: the output reuses that - // snapshot's ID and no extra row is written. +func TestAgent_FailedTurn_ServerManagedReturnsLastTurnSnapshot(t *testing.T) { + // Server-managed: every successful turn snapshots, so when a later turn + // fails the last-good state is already the newest row. The failed output + // reuses that turn's snapshot ID and no extra row is written. ctx := context.Background() reg := newTestRegistry(t) store := newTestInMemStore[testState]() @@ -1076,7 +916,6 @@ func TestAgent_FailedFirstTurn_AfterResume_ReturnsParentSnapshotID(t *testing.T) parent, err := store.SaveSnapshot(ctx, "", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ - Event: SnapshotEventInvocationEnd, Status: SnapshotStatusCompleted, State: &SessionState[testState]{ Messages: []*ai.Message{ @@ -2157,12 +1996,12 @@ func TestPromptAgent_RejectsToolResponsePart(t *testing.T) { } } -func TestAgent_SingleTurnSnapshotDedup(t *testing.T) { +func TestAgent_SingleTurnSnapshot(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := newTestInMemStore[testState]() - af := DefineCustomAgent(reg, "dedupFlow", + af := DefineCustomAgent(reg, "singleTurnFlow", func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) @@ -2176,8 +2015,9 @@ func TestAgent_SingleTurnSnapshotDedup(t *testing.T) { WithSessionStore(store), ) - // Single-turn invocation: should produce exactly 1 snapshot (turn-end), - // not 2 (turn-end + invocation-end with identical state). + // Single-turn invocation: exactly 1 snapshot (the turn-end), which the + // output reuses as its resume point. There is no second invocation-end + // write. response, err := af.RunText(ctx, "hello") if err != nil { t.Fatalf("RunText failed: %v", err) @@ -2186,22 +2026,21 @@ func TestAgent_SingleTurnSnapshotDedup(t *testing.T) { if response.SnapshotID == "" { t.Fatal("expected snapshot ID in response") } + if rows := store.snapshotCount(); rows != 1 { + t.Errorf("expected exactly 1 snapshot (turn-end only), got %d", rows) + } - // Count total snapshots in the store. snap, err := store.GetSnapshot(ctx, response.SnapshotID) if err != nil { t.Fatalf("GetSnapshot failed: %v", err) } - if snap.Event != SnapshotEventTurnEnd { - t.Errorf("expected turn-end snapshot, got %s", snap.Event) - } // The turn-end snapshot should have no parent (first and only snapshot). if snap.ParentID != "" { t.Errorf("expected no parent (single snapshot), got parent %q", snap.ParentID) } } -func TestAgent_MultiTurnSnapshotDedup(t *testing.T) { +func TestAgent_MultiTurnSnapshot(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := newTestInMemStore[testState]() @@ -2220,7 +2059,7 @@ func TestAgent_MultiTurnSnapshotDedup(t *testing.T) { WithSessionStore(store), ) - // Multi-turn: last turn-end snapshot should dedup with invocation-end. + // Multi-turn: one snapshot per turn; the output reuses the last one. conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) @@ -2248,7 +2087,7 @@ func TestAgent_MultiTurnSnapshotDedup(t *testing.T) { t.Fatalf("Output failed: %v", err) } - // Should have 3 turn-end snapshots (one per turn), no extra invocation-end. + // Should have 3 turn-end snapshots, one per turn. if got := len(snapshotIDs); got != 3 { t.Errorf("expected 3 turn-end snapshots, got %d", got) } @@ -2263,7 +2102,11 @@ func TestAgent_MultiTurnSnapshotDedup(t *testing.T) { } } -func TestAgent_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { +func TestAgent_PostRunMutationNotSnapshotted(t *testing.T) { + // Snapshots happen at turn end only. State a custom agent mutates after + // its turn loop rides on the returned output but is not persisted: the + // output's snapshot ID is the last turn-end row, which predates the + // mutation, and no extra row is written. ctx := context.Background() reg := newTestRegistry(t) store := newTestInMemStore[testState]() @@ -2272,12 +2115,16 @@ func TestAgent_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter = 1 + return s + }) return nil, nil }); err != nil { return nil, err } - // Mutate state AFTER sess.Run returns -- this should trigger - // a separate invocation-end snapshot. + // Mutate state after sess.Run returns: rides on the output but is + // not snapshotted. sess.UpdateCustom(func(s testState) testState { s.Counter = 99 return s @@ -2291,26 +2138,22 @@ func TestAgent_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { if err != nil { t.Fatalf("RunText failed: %v", err) } - if response.SnapshotID == "" { t.Fatal("expected snapshot ID in response") } + // Exactly one snapshot (the turn-end); the post-loop mutation wrote none. + if rows := store.snapshotCount(); rows != 1 { + t.Errorf("expected exactly 1 snapshot (no post-loop write), got %d", rows) + } - // The final snapshot should be an invocation-end snapshot that captured - // the post-Run mutation. snap, err := store.GetSnapshot(ctx, response.SnapshotID) if err != nil { t.Fatalf("GetSnapshot failed: %v", err) } - if snap.Event != SnapshotEventInvocationEnd { - t.Errorf("expected invocation-end snapshot, got %s", snap.Event) - } - if snap.State.Custom.Counter != 99 { - t.Errorf("expected counter=99 in final snapshot, got %d", snap.State.Custom.Counter) - } - // Should have a parent (the turn-end snapshot). - if snap.ParentID == "" { - t.Error("expected parent ID (turn-end snapshot)") + // The snapshot holds the turn-end state (counter=1), not the post-loop + // mutation (counter=99). + if snap.State.Custom.Counter != 1 { + t.Errorf("expected turn-end counter=1 in snapshot, got %d", snap.State.Custom.Counter) } } @@ -3122,9 +2965,6 @@ func TestAgent_Detach_NormalCompletionStillEmitsTurnEnd(t *testing.T) { if snap.Status != SnapshotStatusCompleted { t.Errorf("turn-end snapshot status = %q, want completed", snap.Status) } - if snap.Event != SnapshotEventTurnEnd { - t.Errorf("turn-end snapshot event = %q, want %q", snap.Event, SnapshotEventTurnEnd) - } } func TestAgent_Detach_ClientDisconnectBeforeDetachCancels(t *testing.T) { @@ -3190,7 +3030,6 @@ func TestAgent_ResumeFromErrorSnapshot_Rejected(t *testing.T) { if _, err := store.SaveSnapshot(context.Background(), erroredID, func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ - Event: SnapshotEventInvocationEnd, Status: SnapshotStatusFailed, Error: &core.GenkitError{ Status: core.INTERNAL, @@ -3421,7 +3260,6 @@ func TestAgent_GetSnapshotAction_BySessionID(t *testing.T) { return &SessionSnapshot[testState]{ SessionID: out1.SessionID, ParentID: out2.SnapshotID, - Event: SnapshotEventDetach, Status: SnapshotStatusFailed, FinishReason: AgentFinishReasonFailed, }, nil @@ -3490,7 +3328,6 @@ func TestLoadSession_AgentInitValidation(t *testing.T) { saved, err := store.SaveSnapshot(ctx, "", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ SessionID: "sess-1", - Event: SnapshotEventInvocationEnd, State: state, }, nil }) @@ -4184,7 +4021,6 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { pending, err := store.SaveSnapshot(ctx, "snap-cas", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ - Event: SnapshotEventDetach, Status: SnapshotStatusPending, }, nil }) @@ -4228,7 +4064,6 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { if _, err := store.SaveSnapshot(ctx, "snap-complete", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ - Event: SnapshotEventTurnEnd, Status: SnapshotStatusCompleted, }, nil }); err != nil { @@ -4329,7 +4164,6 @@ func TestInMemorySessionStore_OnSnapshotStatusChange(t *testing.T) { if _, err := store.SaveSnapshot(ctx, "snap-sub", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ - Event: SnapshotEventDetach, Status: SnapshotStatusPending, }, nil }); err != nil { @@ -4983,16 +4817,16 @@ func TestAgent_Detach_FinishReasons(t *testing.T) { }) } -// TestAgent_FinishReason_InvocationOverride_Persisted verifies that when a +// TestAgent_FinishReason_InvocationOverride_OutputOnly verifies that when a // custom agent overrides the invocation reason (differing from the last -// turn's), the dedup does not collapse it onto the turn-end snapshot: a -// distinct invocation-end snapshot is written carrying the override, so the -// snapshot AgentOutput points at agrees with AgentOutput.FinishReason. -func TestAgent_FinishReason_InvocationOverride_Persisted(t *testing.T) { +// turn's), the override rides on AgentOutput.FinishReason but is not +// persisted: with no invocation-end snapshot, the output points at the last +// turn-end row, which keeps that turn's own reason. +func TestAgent_FinishReason_InvocationOverride_OutputOnly(t *testing.T) { reg := newTestRegistry(t) store := newTestInMemStore[testState]() - af := DefineCustomAgent(reg, "overridePersistedFlow", + af := DefineCustomAgent(reg, "overrideOutputFlow", func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { if err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) @@ -5012,23 +4846,24 @@ func TestAgent_FinishReason_InvocationOverride_Persisted(t *testing.T) { if err != nil { t.Fatalf("RunText: %v", err) } + // The override is reported on the output... if out.FinishReason != AgentFinishReasonOther { t.Fatalf("AgentOutput.FinishReason = %q, want %q", out.FinishReason, AgentFinishReasonOther) } if out.SnapshotID == "" { t.Fatal("expected a snapshot ID") } + // ...but only the turn-end snapshot exists, and no extra row was written + // for the override. + if rows := store.snapshotCount(); rows != 1 { + t.Errorf("expected exactly 1 snapshot (turn-end only), got %d", rows) + } snap, err := store.GetSnapshot(ctx, out.SnapshotID) if err != nil { t.Fatalf("GetSnapshot: %v", err) } - if snap.FinishReason != out.FinishReason { - t.Errorf("persisted snapshot.FinishReason = %q, want %q (must agree with AgentOutput)", snap.FinishReason, out.FinishReason) - } - // The override must be a fresh invocation-end snapshot, not the turn-end - // row mutated or reused: the divergent reason busts the dedup. - if snap.Event != SnapshotEventInvocationEnd { - t.Errorf("snapshot.Event = %q, want %q (a distinct invocation-end snapshot)", snap.Event, SnapshotEventInvocationEnd) + if snap.FinishReason != AgentFinishReasonStop { + t.Errorf("snapshot.FinishReason = %q, want %q (the turn's own reason, not the invocation override)", snap.FinishReason, AgentFinishReasonStop) } } @@ -5315,9 +5150,9 @@ func TestAgent_SessionID_AssignedAndStable(t *testing.T) { func TestAgent_SessionID_AssignedBeforeFirstSnapshot(t *testing.T) { // The session ID exists from invocation start, not from the first - // snapshot write: an invocation whose callback declines every write - // still reports the session it belongs to and exposes it to the agent - // fn, with no snapshot to show for it yet. + // snapshot write: an invocation that commits no turn (and so writes no + // snapshot) still reports the session it belongs to and exposes it to + // the agent fn, with no snapshot to show for it yet. ctx := context.Background() reg := newTestRegistry(t) store := newTestInMemStore[testState]() @@ -5331,13 +5166,11 @@ func TestAgent_SessionID_AssignedBeforeFirstSnapshot(t *testing.T) { if s := SessionFromContext[testState](ctx); s != nil { ctxSawSessionID = s.SessionID() } - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - sess.AddMessages(ai.NewModelTextMessage("reply")) - return nil, nil - }) + // Return without running a turn: nothing is committed, so no + // snapshot is written, yet the session ID is already settled. + return nil, nil }, WithSessionStore(store), - WithSnapshotCallback(func(context.Context, *SnapshotContext[testState]) bool { return false }), ) out, err := af.RunText(ctx, "hi") @@ -5357,7 +5190,7 @@ func TestAgent_SessionID_AssignedBeforeFirstSnapshot(t *testing.T) { t.Errorf("context-carried session saw ID %q, output reports %q", ctxSawSessionID, out.SessionID) } if out.SnapshotID != "" { - t.Errorf("expected no snapshot (callback declined every write), got %q", out.SnapshotID) + t.Errorf("expected no snapshot (no turn committed), got %q", out.SnapshotID) } // A session with no persisted snapshots is not resumable, but supplying @@ -5512,7 +5345,6 @@ func TestAgent_ResumeFromSessionID_FailedTipRejected(t *testing.T) { return &SessionSnapshot[testState]{ SessionID: out1.SessionID, ParentID: out1.SnapshotID, - Event: SnapshotEventDetach, Status: SnapshotStatusFailed, FinishReason: AgentFinishReasonFailed, }, nil @@ -5595,22 +5427,15 @@ func TestAgent_ResumeFromSessionID_NewConversation(t *testing.T) { } } -func TestAgent_ResumeFromSessionID_AfterFailureResumesRecovery(t *testing.T) { - // After an invocation fails, the session's newest non-dead-end row is - // the recovery snapshot (a normal completed row, event=recovery) - // holding the last-good state. Resuming by session ID continues from - // it like any other snapshot. +func TestAgent_ResumeFromSessionID_AfterFailureResumesLastTurn(t *testing.T) { + // After an invocation fails, the session's newest row is the last + // successful turn's snapshot (a failed turn writes none), holding the + // last-good state. Resuming by session ID continues from it like any + // other snapshot. ctx := context.Background() reg := newTestRegistry(t) store := newTestInMemStore[testState]() - af := defineLastGoodTestAgent(reg, "sessionRecoveryFlow", - WithSessionStore[testState](store), - // Persist only the first turn so the failure path must write a - // genuine recovery-event row (rather than reusing a turn-end row). - WithSnapshotCallback(func(_ context.Context, sc *SnapshotContext[testState]) bool { - return sc.Event == SnapshotEventTurnEnd && sc.TurnIndex == 0 - }), - ) + af := defineLastGoodTestAgent(reg, "sessionRecoveryFlow", WithSessionStore[testState](store)) conn, err := af.StreamBidi(ctx) if err != nil { @@ -5628,12 +5453,14 @@ func TestAgent_ResumeFromSessionID_AfterFailureResumesRecovery(t *testing.T) { if out.FinishReason != AgentFinishReasonFailed { t.Fatalf("expected failed invocation, got %q", out.FinishReason) } - recovery, err := store.GetSnapshot(ctx, out.SnapshotID) - if err != nil || recovery == nil { - t.Fatalf("GetSnapshot(%q): %v, %v", out.SnapshotID, recovery, err) + // The failed output points at the last successful turn's snapshot (turn + // "two", counter=2), not the failed "boom" turn. + lastGood, err := store.GetSnapshot(ctx, out.SnapshotID) + if err != nil || lastGood == nil { + t.Fatalf("GetSnapshot(%q): %v, %v", out.SnapshotID, lastGood, err) } - if recovery.Event != SnapshotEventRecovery { - t.Fatalf("expected recovery snapshot, got event %q", recovery.Event) + if got := lastGood.State.Custom.Counter; got != 2 { + t.Fatalf("expected last-good snapshot counter=2, got %d", got) } out2, err := af.RunText(ctx, "three", WithSessionID[testState](out.SessionID)) @@ -5647,16 +5474,16 @@ func TestAgent_ResumeFromSessionID_AfterFailureResumesRecovery(t *testing.T) { if err != nil { t.Fatalf("GetSnapshot: %v", err) } - if snap2.ParentID != recovery.SnapshotID { - t.Errorf("resumed snapshot ParentID = %q, want recovery row %q", snap2.ParentID, recovery.SnapshotID) + if snap2.ParentID != lastGood.SnapshotID { + t.Errorf("resumed snapshot ParentID = %q, want last-good row %q", snap2.ParentID, lastGood.SnapshotID) } // Last-good state (two successful turns, counter=2) plus the resumed // turn: the failed turn's partial mutations never made it in. if got := snap2.State.Custom.Counter; got != 3 { - t.Errorf("expected counter=3 after resuming recovery state, got %d", got) + t.Errorf("expected counter=3 after resuming last-good state, got %d", got) } if got := len(snap2.State.Messages); got != 6 { - t.Errorf("expected 6 messages after resuming recovery state, got %d", got) + t.Errorf("expected 6 messages after resuming last-good state, got %d", got) } } @@ -5677,7 +5504,6 @@ func TestAgent_ResumeFromSessionID_PendingTipRejected(t *testing.T) { return &SessionSnapshot[testState]{ SessionID: out1.SessionID, ParentID: out1.SnapshotID, - Event: SnapshotEventDetach, Status: SnapshotStatusPending, }, nil }); err != nil { @@ -6049,7 +5875,6 @@ func TestAgent_ResumeFromLegacySnapshot_MintsFreshSessionID(t *testing.T) { legacy := &SessionSnapshot[testState]{ SnapshotID: "legacy-1", - Event: SnapshotEventInvocationEnd, Status: SnapshotStatusCompleted, CreatedAt: time.Now(), UpdatedAt: time.Now(), diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index 94ea1c2cab..89fab39a16 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -204,13 +204,12 @@ type AgentOutput[State any] struct { // invocation; only a session with persisted snapshots can be resumed by // this ID. SessionID string `json:"sessionId,omitempty"` - // SnapshotID is the ID of the newest snapshot capturing this invocation: - // the invocation-end snapshot, or the latest earlier snapshot when that - // write was skipped. Empty when no store is configured or the invocation - // persisted nothing. When FinishReason is [AgentFinishReasonDetached] it - // is the pending detach snapshot; when [AgentFinishReasonFailed], the most - // recent snapshot capturing the last-good state: everything through the - // last successful turn (see [SnapshotEventRecovery]). + // SnapshotID is the ID of the most recent turn-end snapshot for this + // invocation. Empty when no store is configured or no turn committed. When + // FinishReason is [AgentFinishReasonDetached] it is the pending detach + // snapshot; when [AgentFinishReasonFailed], it is the last committed turn's + // snapshot (the resume point, holding state through the last successful turn + // and excluding the failed turn's partial mutations). SnapshotID string `json:"snapshotId,omitempty"` // State contains the final conversation state. // Only populated when state is client-managed (no store configured). @@ -359,8 +358,6 @@ type SessionSnapshot[State any] struct { // Error is the structured failure information for a snapshot in // [SnapshotStatusFailed]. Nil otherwise. Error *core.GenkitError `json:"error,omitempty"` - // Event is what triggered this snapshot. - Event SnapshotEvent `json:"event"` // FinishReason is the semantic reason the turn or invocation captured here // ended (e.g. [AgentFinishReasonStop], [AgentFinishReasonInterrupted], // [AgentFinishReasonFailed], [AgentFinishReasonAborted]). It complements @@ -414,29 +411,6 @@ type SessionState[State any] struct { SessionID string `json:"sessionId,omitempty"` } -// SnapshotEvent identifies what triggered a snapshot. -type SnapshotEvent string - -const ( - // SnapshotEventTurnEnd indicates the snapshot was triggered at the end of a turn. - SnapshotEventTurnEnd SnapshotEvent = "turnEnd" - // SnapshotEventInvocationEnd indicates the snapshot was triggered at the end - // of the invocation. - SnapshotEventInvocationEnd SnapshotEvent = "invocationEnd" - // SnapshotEventDetach indicates the snapshot was created when the client - // detached the invocation and the work continues in the background. The - // snapshot is initially written with [SnapshotStatusPending] and rewritten - // with a terminal status once the background work finishes. - SnapshotEventDetach SnapshotEvent = "detach" - // SnapshotEventRecovery indicates the snapshot was written retroactively - // by the failure path to preserve the last-good state (everything through - // the last successful turn) when a selective snapshot callback had skipped - // persisting it. It is a normal [SnapshotStatusCompleted] row carrying the - // last good turn's finish reason, resumable like any other; the snapshot - // callback is bypassed and never sees this event. - SnapshotEventRecovery SnapshotEvent = "recovery" -) - // SnapshotStatus describes the lifecycle state of a snapshot. Snapshots // written for synchronous turns or invocations are always // [SnapshotStatusCompleted] (an empty value is also treated as completed @@ -482,8 +456,7 @@ type TurnEnd struct { // sends fail with [core.ErrActionCompleted]. FinishReason AgentFinishReason `json:"finishReason,omitempty"` // SnapshotID is the ID of the snapshot persisted at the end of this turn. - // Empty if no snapshot was written (no store configured, the callback - // declined, nothing changed since the last snapshot, or snapshots were - // suspended after detach). + // Empty if no snapshot was written (no store configured, the turn failed, or + // snapshots were suspended after detach). SnapshotID string `json:"snapshotId,omitempty"` } diff --git a/go/ai/exp/localstore/file_test.go b/go/ai/exp/localstore/file_test.go index ed321320fc..df5bd98789 100644 --- a/go/ai/exp/localstore/file_test.go +++ b/go/ai/exp/localstore/file_test.go @@ -459,7 +459,6 @@ func TestFileSessionStore_GetLatestSnapshot_SkipsUnparseableFiles(t *testing.T) func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { return &exp.SessionSnapshot[testState]{ SessionID: "sess-1", - Event: exp.SnapshotEventTurnEnd, Status: exp.SnapshotStatusCompleted, }, nil }); err != nil { diff --git a/go/ai/exp/localstore/store_test.go b/go/ai/exp/localstore/store_test.go index 24f76200e4..26bf83830b 100644 --- a/go/ai/exp/localstore/store_test.go +++ b/go/ai/exp/localstore/store_test.go @@ -44,7 +44,6 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio return &exp.SessionSnapshot[testState]{ SessionID: sessionID, ParentID: parentID, - Event: exp.SnapshotEventTurnEnd, Status: status, State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, }, nil @@ -90,7 +89,6 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio return &exp.SessionSnapshot[testState]{ SessionID: rewrite, ParentID: existing.ParentID, - Event: existing.Event, Status: exp.SnapshotStatusCompleted, State: &exp.SessionState[testState]{Custom: testState{Counter: 2}}, }, nil diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index f52945d12e..377443648f 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -48,7 +48,6 @@ type StateTransform[State any] = func(ctx context.Context, state *SessionState[S type agentOptions[State any] struct { store SessionStore[State] - callback SnapshotCallback[State] transform StateTransform[State] description string } @@ -60,12 +59,6 @@ func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { } opts.store = o.store } - if o.callback != nil { - if opts.callback != nil { - return errors.New("cannot set snapshot callback more than once (WithSnapshotCallback)") - } - opts.callback = o.callback - } if o.transform != nil { if opts.transform != nil { return errors.New("cannot set state transform more than once (WithStateTransform)") @@ -89,26 +82,6 @@ func WithSessionStore[State any](store SessionStore[State]) AgentOption[State] { return &agentOptions[State]{store: store} } -// WithSnapshotCallback configures when snapshots are created. -// If not provided and a store is configured, snapshots are always created. -func WithSnapshotCallback[State any](cb SnapshotCallback[State]) AgentOption[State] { - return &agentOptions[State]{callback: cb} -} - -// WithSnapshotOn configures snapshots to be created only for the specified events. -// For example, WithSnapshotOn[MyState](SnapshotEventTurnEnd) skips the -// invocation-end snapshot. -func WithSnapshotOn[State any](events ...SnapshotEvent) AgentOption[State] { - set := make(map[SnapshotEvent]struct{}, len(events)) - for _, e := range events { - set[e] = struct{}{} - } - return WithSnapshotCallback(func(_ context.Context, sc *SnapshotContext[State]) bool { - _, ok := set[sc.Event] - return ok - }) -} - // WithStateTransform registers a transform applied to session state on // its way out to a client via the getSnapshot companion action or via // [AgentOutput.State] when state is client-managed. Typical use is PII diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 4c9d92a50d..eed52e636d 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -30,22 +30,6 @@ import ( // --- Snapshot --- -// SnapshotContext provides context for snapshot decision callbacks. -type SnapshotContext[State any] struct { - // State is the current state that will be snapshotted if the callback returns true. - State *SessionState[State] - // PrevState is the state at the last snapshot, or nil if none exists. - PrevState *SessionState[State] - // TurnIndex is the turn number in the current invocation. - TurnIndex int - // Event is what triggered this snapshot check. - Event SnapshotEvent -} - -// SnapshotCallback decides whether to create a snapshot. -// If not provided and a store is configured, snapshots are always created. -type SnapshotCallback[State any] = func(ctx context.Context, sc *SnapshotContext[State]) bool - // applyTransform returns the result of applying t to state, or state // unchanged if t is nil. A nil state is returned as-is. func applyTransform[State any](ctx context.Context, t StateTransform[State], state *SessionState[State]) *SessionState[State] { @@ -108,8 +92,8 @@ type SnapshotWriter[State any] interface { // - UpdatedAt: stamped to the wall clock on every commit. // - Status: if the snapshot returned by fn has Status="", it is // defaulted to [SnapshotStatusCompleted] (the common case for - // synchronous turn-end and invocation-end writes). Callers - // writing a pending row must set Status explicitly. + // synchronous turn-end writes). Callers writing a pending row must + // set Status explicitly. // // fn receives the existing snapshot (or nil if id is empty or the // row does not exist) and returns the snapshot to commit, or @@ -337,10 +321,9 @@ func newSnapshotActions[State any]( // Session holds conversation state and provides thread-safe read/write // access to messages, custom state, and artifacts. type Session[State any] struct { - mu sync.RWMutex - state SessionState[State] - store SessionStore[State] - version uint64 // incremented on every mutation; used to skip redundant snapshots + mu sync.RWMutex + state SessionState[State] + store SessionStore[State] // onCustomChange, when set by the agent runtime, is invoked after every // UpdateCustom mutation (outside the lock) so the runtime can emit a @@ -393,7 +376,6 @@ func (s *Session[State]) AddMessages(messages ...*ai.Message) { s.mu.Lock() defer s.mu.Unlock() s.state.Messages = append(s.state.Messages, messages...) - s.version++ } // SetMessages replaces the conversation history with the given messages. @@ -401,7 +383,6 @@ func (s *Session[State]) SetMessages(messages []*ai.Message) { s.mu.Lock() defer s.mu.Unlock() s.state.Messages = messages - s.version++ } // UpdateMessages atomically reads the current messages, applies the given @@ -412,7 +393,6 @@ func (s *Session[State]) UpdateMessages(fn func([]*ai.Message) []*ai.Message) { s.mu.Lock() defer s.mu.Unlock() s.state.Messages = fn(s.state.Messages) - s.version++ } // Custom returns the current user-defined custom state. @@ -445,7 +425,6 @@ func (s *Session[State]) customJSON() any { func (s *Session[State]) UpdateCustom(fn func(State) State) { s.mu.Lock() s.state.Custom = fn(s.state.Custom) - s.version++ s.mu.Unlock() // Emit the customPatch delta after releasing the lock: the hook reads // session state (and may send on the wire), neither of which is safe to @@ -487,7 +466,6 @@ func (s *Session[State]) AddArtifacts(artifacts ...*Artifact) { s.state.Artifacts = append(s.state.Artifacts, a) } } - s.version++ } // UpdateArtifacts atomically reads the current artifacts, applies the given @@ -498,7 +476,6 @@ func (s *Session[State]) UpdateArtifacts(fn func([]*Artifact) []*Artifact) { s.mu.Lock() defer s.mu.Unlock() s.state.Artifacts = fn(s.state.Artifacts) - s.version++ } // copyStateLocked returns a deep copy of the state. Caller must hold mu (read or write). diff --git a/go/core/schemas.config b/go/core/schemas.config index c2e72c6b7c..32a836db52 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1361,13 +1361,12 @@ this ID. . AgentOutput.snapshotId doc -SnapshotID is the ID of the newest snapshot capturing this invocation: -the invocation-end snapshot, or the latest earlier snapshot when that -write was skipped. Empty when no store is configured or the invocation -persisted nothing. When FinishReason is [AgentFinishReasonDetached] it -is the pending detach snapshot; when [AgentFinishReasonFailed], the most -recent snapshot capturing the last-good state: everything through the -last successful turn (see [SnapshotEventRecovery]). +SnapshotID is the ID of the most recent turn-end snapshot for this +invocation. Empty when no store is configured or no turn committed. When +FinishReason is [AgentFinishReasonDetached] it is the pending detach +snapshot; when [AgentFinishReasonFailed], it is the last committed turn's +snapshot (the resume point, holding state through the last successful turn +and excluding the failed turn's partial mutations). . AgentOutput.state doc @@ -1559,9 +1558,8 @@ snapshot was persisted. TurnEnd.snapshotId doc SnapshotID is the ID of the snapshot persisted at the end of this turn. -Empty if no snapshot was written (no store configured, the callback -declined, nothing changed since the last snapshot, or snapshots were -suspended after detach). +Empty if no snapshot was written (no store configured, the turn failed, or +snapshots were suspended after detach). . TurnEnd.finishReason doc @@ -1657,11 +1655,6 @@ it equals CreatedAt; once the snapshot is finalized it reflects the terminal write. . -SessionSnapshot.event noomitempty -SessionSnapshot.event doc -Event is what triggered this snapshot. -. - SessionSnapshot.status doc Status is the lifecycle state of this snapshot. Empty is treated as [SnapshotStatusCompleted] for backwards compatibility. @@ -1689,41 +1682,6 @@ invocation is still processing queued inputs); populated on terminal snapshots with the cumulative final state. . -# ---------------------------------------------------------------------------- -# SnapshotEvent -# ---------------------------------------------------------------------------- - -SnapshotEvent pkg ai/exp - -SnapshotEvent doc -SnapshotEvent identifies what triggered a snapshot. -. - -SnapshotEventTurnEnd doc -SnapshotEventTurnEnd indicates the snapshot was triggered at the end of a turn. -. - -SnapshotEventInvocationEnd doc -SnapshotEventInvocationEnd indicates the snapshot was triggered at the end -of the invocation. -. - -SnapshotEventDetach doc -SnapshotEventDetach indicates the snapshot was created when the client -detached the invocation and the work continues in the background. The -snapshot is initially written with [SnapshotStatusPending] and rewritten -with a terminal status once the background work finishes. -. - -SnapshotEventRecovery doc -SnapshotEventRecovery indicates the snapshot was written retroactively -by the failure path to preserve the last-good state (everything through -the last successful turn) when a selective snapshot callback had skipped -persisting it. It is a normal [SnapshotStatusCompleted] row carrying the -last good turn's finish reason, resumable like any other; the snapshot -callback is bypassed and never sees this event. -. - # ---------------------------------------------------------------------------- # Snapshot lifecycle types # ---------------------------------------------------------------------------- diff --git a/go/genkit/exp/routes_test.go b/go/genkit/exp/routes_test.go index 7d19a53a3b..af01488e93 100644 --- a/go/genkit/exp/routes_test.go +++ b/go/genkit/exp/routes_test.go @@ -67,7 +67,6 @@ func newRouteTestGenkit(t *testing.T) *genkit.Genkit { } genkit.DefineAgent(g, "serverChat", aix.FromInline(ai.WithModelName("test/echo")), aix.WithSessionStore(store), - aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), ) genkit.DefineAgent[any](g, "clientChat", aix.FromInline(ai.WithModelName("test/echo"))) genkit.DefineFlow(g, "greet", func(ctx context.Context, name string) (string, error) { diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index f626148a12..b541a1c91a 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -427,7 +427,7 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn core.StreamingFunc[In // loaded from a .prompt file). // // The State type parameter is inferred from the typed agent options -// (e.g. [aix.WithSessionStore], [aix.WithSnapshotOn]); pass an explicit +// (e.g. [aix.WithSessionStore], [aix.WithStateTransform]); pass an explicit // [State] only when no typed option is provided. // // The returned agent is an [api.BidiAction]; pass it to [Handler] to @@ -441,8 +441,6 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn core.StreamingFunc[In // # Options // // - [aix.WithSessionStore]: Enable snapshot persistence -// - [aix.WithSnapshotCallback]: Control when snapshots are created -// - [aix.WithSnapshotOn]: Create snapshots only for specific [aix.SnapshotEvent] types // - [aix.WithStateTransform]: Rewrite session state on its way out to the client // // Example (inline prompt): @@ -496,8 +494,6 @@ func DefineAgent[State any]( // # Options // // - [aix.WithSessionStore]: Enable snapshot persistence -// - [aix.WithSnapshotCallback]: Control when snapshots are created -// - [aix.WithSnapshotOn]: Create snapshots only for specific [aix.SnapshotEvent] types // - [aix.WithStateTransform]: Rewrite session state on its way out to the client // // The State type parameter is the shape of the conversation's custom state diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index 792abe1a5a..08521a0056 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -766,7 +766,6 @@ func TestHandlerAgent(t *testing.T) { } DefineAgent(g, "agentServer", aix.FromInline(ai.WithModelName("test/echo")), aix.WithSessionStore(store), - aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), ) // Agents register under their own action type, so they surface through @@ -972,7 +971,6 @@ func TestHandlerAgentRef(t *testing.T) { } agent := DefineAgent(g, "agentRef", aix.FromInline(ai.WithModelName("test/echo")), aix.WithSessionStore(store), - aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), ) // Handlers come straight off the ref; no registry iteration involved. diff --git a/go/samples/basic-agents/main.go b/go/samples/basic-agents/main.go index 5e9c7820a9..e38e505d4d 100644 --- a/go/samples/basic-agents/main.go +++ b/go/samples/basic-agents/main.go @@ -115,7 +115,6 @@ func defineInlineAgent(g *genkit.Genkit) *aix.Agent[any] { ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), ), aix.WithSessionStore(mustStore(name)), - aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), aix.WithDescription[any]("Sarcastic pirate (inline-defined prompt)"), ) } @@ -134,7 +133,6 @@ func definePromptAgent(g *genkit.Genkit) *aix.Agent[any] { return genkit.DefineAgent(g, name, aix.FromPrompt(ChatPromptInput{Personality: "a Michelin-starred chef who loves explaining technique"}), aix.WithSessionStore(mustStore(name)), - aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), aix.WithDescription[any]("Michelin-starred chef (prompt loaded from ./prompts/chef.prompt)"), ) } @@ -183,7 +181,6 @@ func defineCustomAgent(g *genkit.Genkit) *aix.Agent[any] { return sess.Result(), nil }, aix.WithSessionStore(mustStore(name)), - aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), aix.WithDescription[any]("Concise code helper (custom per-turn loop)"), ) } From b5fc33338b2822815723cfc3c030f618a4913daa Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 16 Jun 2026 17:51:45 -0700 Subject: [PATCH 110/141] refactor(go/exp): drop redundant snapshotId from AbortSnapshotResponse The abortSnapshot companion action's response echoed back the snapshot ID the caller already supplied in the request. Drop it; the response carries only the resulting Status, which is the useful outcome of the abort. Regenerated from agent.ts through genkit-schema.json and gen.go. --- genkit-tools/common/src/types/agent.ts | 2 -- genkit-tools/genkit-schema.json | 6 ------ go/ai/exp/gen.go | 2 -- go/ai/exp/session.go | 2 +- go/core/schemas.config | 5 ----- 5 files changed, 1 insertion(+), 16 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 46889518f7..9c53079c3e 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -428,8 +428,6 @@ export type AbortSnapshotRequest = z.infer; * Zod schema for the output of the `abortSnapshot` companion action. */ export const AbortSnapshotResponseSchema = z.object({ - /** Echoes the requested snapshot ID. */ - snapshotId: z.string(), /** * Snapshot's status after the abort attempt. For a pending snapshot * this is `aborted`. For an already-terminal snapshot this is the diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 114bb91950..3f54191bbb 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -16,16 +16,10 @@ "AbortSnapshotResponse": { "type": "object", "properties": { - "snapshotId": { - "type": "string" - }, "status": { "$ref": "#/$defs/SnapshotStatus" } }, - "required": [ - "snapshotId" - ], "additionalProperties": false }, "AgentFinishReason": { diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index 89fab39a16..6634f509f2 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -32,8 +32,6 @@ type AbortSnapshotRequest struct { // AbortSnapshotResponse is the output of the abortSnapshot companion action. type AbortSnapshotResponse struct { - // SnapshotID echoes the requested snapshot ID. - SnapshotID string `json:"snapshotId"` // Status is the snapshot's status after the abort attempt. For a // pending snapshot this is [SnapshotStatusAborted]. For an // already-terminal snapshot this is the existing terminal status (the diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index eed52e636d..5915769a9f 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -311,7 +311,7 @@ func newSnapshotActions[State any]( if status == "" { return nil, core.NewError(core.NOT_FOUND, "abortSnapshot: snapshot %q not found", req.SnapshotID) } - return &AbortSnapshotResponse{SnapshotID: req.SnapshotID, Status: status}, nil + return &AbortSnapshotResponse{Status: status}, nil }) return getSnapshotAction, abortSnapshotAction } diff --git a/go/core/schemas.config b/go/core/schemas.config index 32a836db52..1cea9d070d 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1840,11 +1840,6 @@ AbortSnapshotResponse doc AbortSnapshotResponse is the output of the abortSnapshot companion action. . -AbortSnapshotResponse.snapshotId noomitempty -AbortSnapshotResponse.snapshotId doc -SnapshotID echoes the requested snapshot ID. -. - AbortSnapshotResponse.status doc Status is the snapshot's status after the abort attempt. For a pending snapshot this is [SnapshotStatusAborted]. For an From c5741b6672d02896532b8118b882a58e1afda6cc Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 16 Jun 2026 17:53:55 -0700 Subject: [PATCH 111/141] chore(py): regenerate schema typing for agent schema changes The agent schema changes (removing SnapshotEvent and the SessionSnapshot event field, and AbortSnapshotResponse.snapshotId) regenerated genkit-schema.json but not the Python types derived from it, so CI's generate_schema_typing --ci check failed on the drift. Regenerate _typing.py to match. --- py/packages/genkit/src/genkit/_core/_typing.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index 963c4352fd..d988fc55ef 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -66,15 +66,6 @@ class JsonPatchOp(StrEnum): TEST = 'test' -class SnapshotEvent(StrEnum): - """SnapshotEvent data type class.""" - - TURNEND = 'turnEnd' - INVOCATIONEND = 'invocationEnd' - DETACH = 'detach' - RECOVERY = 'recovery' - - class SnapshotStatus(StrEnum): """SnapshotStatus data type class.""" @@ -140,7 +131,6 @@ class AbortSnapshotResponse(GenkitModel): """Model for abortsnapshotresponse data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) - snapshot_id: str = Field(...) status: SnapshotStatus | None = None @@ -239,7 +229,6 @@ class SessionSnapshot(GenkitModel): parent_id: str | None = None created_at: str = Field(...) updated_at: str | None = None - event: SnapshotEvent = Field(...) status: SnapshotStatus | None = None finish_reason: AgentFinishReason | None = None error: GenkitRuntimeError | None = None From 0aea6796d7d3af1dbb166ea951af118a86586098 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 16 Jun 2026 19:59:15 -0700 Subject: [PATCH 112/141] refactor(go/exp): replace Mount helper with inline route loop Mount was a thin loop over Route.Pattern/Route.Handler. Inlining it in the sample and test makes the wiring visible (the same range-and-mount any router needs) and trims an experimental helper. Route.Pattern and Route.Handler remain the building blocks; doc comments now point at them instead of Mount. --- go/genkit/exp/doc.go | 13 ++++++----- go/genkit/exp/routes.go | 31 ++++++++------------------ go/genkit/exp/routes_test.go | 15 ++++++++----- go/samples/basic-agents-server/main.go | 14 +++++++----- 4 files changed, 34 insertions(+), 39 deletions(-) diff --git a/go/genkit/exp/doc.go b/go/genkit/exp/doc.go index 7f2e6e4169..6a447405ae 100644 --- a/go/genkit/exp/doc.go +++ b/go/genkit/exp/doc.go @@ -18,12 +18,13 @@ Package exp holds experimental genkit helpers that are still taking shape. It currently provides: - - An HTTP route layout for serving agents and flows: the [Route] value, the - [AgentRoutes] / [AllAgentRoutes] / [FlowRoutes] / [AllFlowRoutes] builders, - and [Mount]. The handlers themselves come from the stable genkit package - ([genkit.Handler]); this package only lays out which paths map to which - actions, so the routing layer can evolve without touching genkit's stable - surface. + - An HTTP route layout for serving agents and flows: the [Route] value and + the [AgentRoutes] / [AllAgentRoutes] / [FlowRoutes] / [AllFlowRoutes] + builders. Range over the routes and wire each onto an [http.ServeMux] (or + any router) with [Route.Pattern] and [Route.Handler]. The handlers + themselves come from the stable genkit package ([genkit.Handler]); this + package only lays out which paths map to which actions, so the routing + layer can evolve without touching genkit's stable surface. - A channel-based streaming flow constructor, [DefineStreamingFlow]: an alternative to the callback-based [genkit.DefineStreamingFlow] for logic diff --git a/go/genkit/exp/routes.go b/go/genkit/exp/routes.go index 3878bc378b..011b149b8a 100644 --- a/go/genkit/exp/routes.go +++ b/go/genkit/exp/routes.go @@ -17,7 +17,6 @@ package exp import ( - "log/slog" "net/http" aix "github.com/firebase/genkit/go/ai/exp" @@ -35,9 +34,10 @@ const ( // method and path to mount and the action to serve. // // [AgentRoutes], [AllAgentRoutes], [FlowRoutes], and [AllFlowRoutes] -// produce Routes; [Mount] wires them onto an [http.ServeMux]. The fields -// are exported so other routers (Gin, Chi, Echo) can mount the same -// layout: read Method and Path and serve Action with [genkit.Handler]. +// produce Routes; range over them and wire each onto an [http.ServeMux] +// with [Route.Pattern] and [Route.Handler]. The fields are exported so +// other routers (Gin, Chi, Echo) can mount the same layout: read Method +// and Path and serve Action with [genkit.Handler]. // Every route is a POST that speaks the {"data": ...} / {"result": ...} // envelope of the reflection API (the agent turn route also streams via // ?stream=true), so a single client transport reaches all of them. @@ -63,7 +63,8 @@ func (r Route) Handler(opts ...genkit.HandlerOption) http.HandlerFunc { // AllAgentRoutes returns the default serving layout for every agent // registered with g, the iterate-over-all counterpart to [AgentRoutes]. -// Pass the result to [Mount], or to a router of your choice. See +// Mount the result onto an [http.ServeMux] (range over it and call +// [Route.Handler]) or hand it to a router of your choice. See // [AgentRoutes] for the per-agent layout and the route set each agent // contributes. func AllAgentRoutes(g *genkit.Genkit) []Route { @@ -85,8 +86,8 @@ func AllAgentRoutes(g *genkit.Genkit) []Route { } // AgentRoutes returns the default serving layout for a single agent, so you -// can mount specific agents rather than every registered one. Pass the -// result to [Mount], or to a router of your choice. +// can mount specific agents rather than every registered one. Mount the +// result onto an [http.ServeMux], or onto a router of your choice. // // The route set mirrors what the agent can do: // @@ -128,7 +129,7 @@ func buildAgentRoutes(name string, run, snapshot, abort api.Action) []Route { // AllFlowRoutes returns the default serving layout for every flow // registered with g, the iterate-over-all counterpart to [FlowRoutes]. -// Pass the result to [Mount], or to a router of your choice. +// Mount the result onto an [http.ServeMux], or onto a router of your choice. func AllFlowRoutes(g *genkit.Genkit) []Route { var routes []Route for _, f := range genkit.ListFlows(g) { @@ -148,17 +149,3 @@ func FlowRoutes(f api.Action) []Route { func buildFlowRoute(f api.Action) Route { return Route{Method: http.MethodPost, Path: flowBasePath + "/" + f.Name(), Action: f} } - -// Mount registers routes on mux, building each route's handler with opts -// (e.g. [genkit.WithContextProviders] shared across all of them). It is -// the stdlib convenience over [Route.Handler]; routers other than -// [http.ServeMux] can range over the routes and mount them directly. -// -// Compose layouts by concatenating: Mount(mux, append(AllAgentRoutes(g), -// AllFlowRoutes(g)...), opts...). -func Mount(mux *http.ServeMux, routes []Route, opts ...genkit.HandlerOption) { - for _, rt := range routes { - mux.HandleFunc(rt.Pattern(), rt.Handler(opts...)) - slog.Debug("genkit/exp.Mount", "method", rt.Method, "path", rt.Path) - } -} diff --git a/go/genkit/exp/routes_test.go b/go/genkit/exp/routes_test.go index af01488e93..37adea95df 100644 --- a/go/genkit/exp/routes_test.go +++ b/go/genkit/exp/routes_test.go @@ -144,15 +144,18 @@ func TestAllFlowRoutes(t *testing.T) { } } -// TestMount exercises the full path: build the all-agents layout, mount it -// on a ServeMux, and drive the resulting endpoints. It proves every route -// speaks the same enveloped Handler transport (the turn and the getSnapshot -// companion alike) and that a client-managed agent has only its turn route. -func TestMount(t *testing.T) { +// TestRoutesServedOverServeMux exercises the full path: build the all-agents +// layout, wire it onto a ServeMux, and drive the resulting endpoints. It +// proves every route speaks the same enveloped Handler transport (the turn +// and the getSnapshot companion alike) and that a client-managed agent has +// only its turn route. +func TestRoutesServedOverServeMux(t *testing.T) { g := newRouteTestGenkit(t) mux := http.NewServeMux() - Mount(mux, AllAgentRoutes(g)) + for _, route := range AllAgentRoutes(g) { + mux.HandleFunc(route.Pattern(), route.Handler()) + } do := func(t *testing.T, method, path, body string) (int, string) { t.Helper() diff --git a/go/samples/basic-agents-server/main.go b/go/samples/basic-agents-server/main.go index c340bbb70f..e38d292c7f 100644 --- a/go/samples/basic-agents-server/main.go +++ b/go/samples/basic-agents-server/main.go @@ -153,8 +153,8 @@ func main() { ) // genkitx.AllAgentRoutes lays out a default HTTP surface for every - // registered agent, and genkitx.Mount wires it onto the mux. The layout - // follows each agent's capabilities, so server-managed and + // registered agent; range over the routes and wire each onto the mux. + // The layout follows each agent's capabilities, so server-managed and // client-managed agents can be deployed side by side from one call: // // "chat" (store-backed): @@ -166,14 +166,18 @@ func main() { // // Every route is a POST taking the standard {"data": ...} envelope and // returning {"result": ...}; the companions read the snapshotId from - // that body. HandlerOptions passed here (e.g. context providers for - // auth) apply to every route. + // that body. route.Pattern() is its "METHOD /path" and route.Handler() + // builds the genkit.Handler; pass HandlerOptions (e.g. context providers + // for auth) to Handler() to apply them per route. Any router works the + // same way (Gin, Chi, Echo): read Pattern and serve Handler. // // To serve specific agents instead of all of them, use // genkitx.AgentRoutes(agent); to expose flows, genkitx.AllFlowRoutes(g). // Mix them by concatenating the route slices. The genkitx (genkit/exp) // package holds these helpers while the routing layer is experimental. mux := http.NewServeMux() - genkitx.Mount(mux, genkitx.AllAgentRoutes(g)) + for _, route := range genkitx.AllAgentRoutes(g) { + mux.HandleFunc(route.Pattern(), route.Handler()) + } log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) } From 44098deb5961fcb980f5c54deafb4e37dd8a2785 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 16 Jun 2026 19:59:18 -0700 Subject: [PATCH 113/141] chore(go/ai): drop per-call experimental-model log The "model is experimental or unstable" Info log fired on every model call for any unstable-staged model (e.g. unknown/-latest aliases), which is noise. The deprecated-model Warn stays, since deprecation is actionable. --- go/ai/model_middleware.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/go/ai/model_middleware.go b/go/ai/model_middleware.go index ca308c3dcc..ed5aa9340c 100644 --- a/go/ai/model_middleware.go +++ b/go/ai/model_middleware.go @@ -254,13 +254,8 @@ func validateSupport(model string, opts *ModelOptions) ModelMiddleware { } } - if opts.Stage != "" { - switch opts.Stage { - case ModelStageDeprecated: - logger.FromContext(ctx).Warn("model is deprecated and may be removed in a future release", "model", model) - case ModelStageUnstable: - logger.FromContext(ctx).Info("model is experimental or unstable", "model", model) - } + if opts.Stage == ModelStageDeprecated { + logger.FromContext(ctx).Warn("model is deprecated and may be removed in a future release", "model", model) } if (opts.Supports.Constrained == "" || From 8bc383d4c96f4cfa41405d444461f671303635a5 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 16 Jun 2026 20:19:35 -0700 Subject: [PATCH 114/141] feat(go/exp): validate agent resume parts against session history A resume payload on AgentInput let a caller drive any tool: nothing checked that a restart/respond entry corresponded to a tool request the model had actually issued and interrupted on, so a client could invoke a tool the model never called, or forge the inputs of an interrupted one. ValidateResumeAgainstHistory scans every model message in the session history and requires each restart/respond entry to match an existing tool request by name + ref; restart entries must additionally carry unchanged inputs (normalized through JSON so int/float64 drift between in-memory history and the deserialized payload is not mistaken for a forgery). Violations return INVALID_ARGUMENT. The prompt-backed loop calls it before generation, so a forged resume fails the turn before the model is invoked; the helper is exported so custom agents that accept resume payloads from untrusted callers can enforce the same guarantee. Mirrors the JS validateResumeAgainstHistory. --- go/ai/exp/agent.go | 106 ++++++++++++++++++++++++-- go/ai/exp/agent_test.go | 160 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 261 insertions(+), 5 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 60ccd45a0d..27ea579a07 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -1835,6 +1835,96 @@ func validateUserMessage(m *ai.Message) error { return nil } +// ValidateResumeAgainstHistory ensures every restart and respond entry on a +// resume payload references a tool request the model actually issued, so a +// caller cannot drive a tool the model never asked for and interrupted on. +// For restart entries it additionally checks the input is unchanged from the +// original request, preventing a client from forging tool inputs on the +// interrupted call. The whole history is searched (every model message), not +// just the last turn. On a violation it returns an INVALID_ARGUMENT error. +// +// The prompt-backed agent loop ([FromPrompt]) calls this automatically. A +// custom agent ([DefineCustomAgent]) that accepts an [AgentInput.Resume] from +// untrusted callers should call it before forwarding the payload to the model: +// +// if input.Resume != nil { +// if err := ValidateResumeAgainstHistory(input.Resume, sess.Messages()); err != nil { +// return nil, err +// } +// } +func ValidateResumeAgainstHistory(resume *ToolResume, history []*ai.Message) error { + if resume == nil { + return nil + } + + // Collect every tool request from all model messages in history. + var requests []*ai.ToolRequest + for _, msg := range history { + if msg == nil || msg.Role != ai.RoleModel { + continue + } + for _, p := range msg.Content { + if p.IsToolRequest() && p.ToolRequest != nil { + requests = append(requests, p.ToolRequest) + } + } + } + find := func(name, ref string) *ai.ToolRequest { + for _, req := range requests { + if req.Name == name && req.Ref == ref { + return req + } + } + return nil + } + + // Restart entries: name + ref must exist and the input must match the + // original request exactly. IsToolRequest only checks the part kind, so + // guard the pointer too: a hand-built NewToolRequestPart(nil) is kind + // PartToolRequest with a nil ToolRequest. + for _, p := range resume.Restart { + if !p.IsToolRequest() || p.ToolRequest == nil { + continue + } + req := p.ToolRequest + match := find(req.Name, req.Ref) + if match == nil { + return core.NewError(core.INVALID_ARGUMENT, + "resume.restart references tool %q%s which was not found in session history", + req.Name, toolRefSuffix(req.Ref)) + } + if !jsonEqual(normalizeJSON(req.Input), normalizeJSON(match.Input)) { + return core.NewError(core.INVALID_ARGUMENT, + "resume.restart for tool %q%s has modified inputs that do not match the original tool request in session history; restart inputs must exactly match the interrupted tool request", + req.Name, toolRefSuffix(req.Ref)) + } + } + + // Respond entries: name + ref must match a tool request in history. + for _, p := range resume.Respond { + if !p.IsToolResponse() || p.ToolResponse == nil { + continue + } + resp := p.ToolResponse + if find(resp.Name, resp.Ref) == nil { + return core.NewError(core.INVALID_ARGUMENT, + "resume.respond references tool %q%s which was not found in session history", + resp.Name, toolRefSuffix(resp.Ref)) + } + } + + return nil +} + +// toolRefSuffix renders a " (ref: X)" clause for resume validation errors, or +// "" when the tool request carried no ref. +func toolRefSuffix(ref string) string { + if ref == "" { + return "" + } + return fmt.Sprintf(" (ref: %s)", ref) +} + // agentLoop returns the per-turn function for a prompt-backed agent. Each // turn renders the prompt, appends conversation history, calls the model // with streaming, and updates the session. @@ -1876,12 +1966,18 @@ func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) Ag } // Append conversation history after the base messages. - actionOpts.Messages = append(base, sess.Messages()...) - - // If a resume payload was provided, forward it to the - // generate call so handleResumeOption re-executes the - // interrupted tools and / or applies the responses. + history := sess.Messages() + actionOpts.Messages = append(base, history...) + + // If a resume payload was provided, validate that every + // restart / respond entry references a tool request the model + // actually issued, then forward it to the generate call so + // handleResumeOption re-executes the interrupted tools and / or + // applies the responses. if input.Resume != nil { + if err := ValidateResumeAgainstHistory(input.Resume, history); err != nil { + return nil, err + } actionOpts.Resume = &ai.GenerateActionResume{ Respond: input.Resume.Respond, Restart: input.Resume.Restart, diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index e0f7435e6f..17cca1054e 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -23,6 +23,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "testing" "time" @@ -1996,6 +1997,165 @@ func TestPromptAgent_RejectsToolResponsePart(t *testing.T) { } } +func TestValidateResumeAgainstHistory(t *testing.T) { + // History spans two model messages (each carrying a tool request) plus a + // user message that must be ignored. The whole history is searched, not + // just the last turn. + history := []*ai.Message{ + {Role: ai.RoleUser, Content: []*ai.Part{ai.NewTextPart("hi")}}, + {Role: ai.RoleModel, Content: []*ai.Part{ + ai.NewToolRequestPart(&ai.ToolRequest{Name: "first", Ref: "r1", Input: map[string]any{"a": float64(1)}}), + }}, + {Role: ai.RoleModel, Content: []*ai.Part{ + ai.NewTextPart("thinking"), + ai.NewToolRequestPart(&ai.ToolRequest{Name: "second", Ref: "r2", Input: map[string]any{"b": "x"}}), + }}, + } + + respond := func(name, ref string) []*ai.Part { + return []*ai.Part{ai.NewToolResponsePart(&ai.ToolResponse{Name: name, Ref: ref, Output: "ok"})} + } + restart := func(name, ref string, input any) []*ai.Part { + return []*ai.Part{ai.NewToolRequestPart(&ai.ToolRequest{Name: name, Ref: ref, Input: input})} + } + + tests := []struct { + name string + resume *ToolResume + wantErr string // substring the error must contain; "" means it must succeed + }{ + {name: "nil resume", resume: nil}, + {name: "empty resume", resume: &ToolResume{}}, + {name: "respond matches first model message", resume: &ToolResume{Respond: respond("first", "r1")}}, + {name: "respond matches a later model message", resume: &ToolResume{Respond: respond("second", "r2")}}, + {name: "restart matches input exactly", resume: &ToolResume{Restart: restart("first", "r1", map[string]any{"a": float64(1)})}}, + { + name: "respond references unknown tool", + resume: &ToolResume{Respond: respond("ghost", "r1")}, + wantErr: "not found in session history", + }, + { + name: "respond references known tool with wrong ref", + resume: &ToolResume{Respond: respond("first", "wrong")}, + wantErr: "not found in session history", + }, + { + name: "restart references unknown tool", + resume: &ToolResume{Restart: restart("ghost", "r1", nil)}, + wantErr: "not found in session history", + }, + { + name: "restart forges modified input", + resume: &ToolResume{Restart: restart("first", "r1", map[string]any{"a": float64(2)})}, + wantErr: "modified inputs", + }, + { + // An int 1 normalizes to the same JSON shape as the stored float64 + // 1, so a faithful restart is not mistaken for a forgery. + name: "restart input matches across json number types", + resume: &ToolResume{Restart: restart("first", "r1", map[string]any{"a": 1})}, + }, + { + // A kind-PartToolRequest part with a nil ToolRequest pointer (e.g. + // NewToolRequestPart(nil)) must be skipped, not panic. + name: "restart with nil tool request pointer is skipped", + resume: &ToolResume{Restart: []*ai.Part{ai.NewToolRequestPart(nil), nil}}, + }, + { + name: "respond with nil tool response pointer is skipped", + resume: &ToolResume{Respond: []*ai.Part{ai.NewToolResponsePart(nil), nil}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := ValidateResumeAgainstHistory(tc.resume, history) + if tc.wantErr == "" { + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.wantErr) + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("error %q does not contain %q", err.Error(), tc.wantErr) + } + if ge := core.AsGenkitError(err); ge.Status != core.INVALID_ARGUMENT { + t.Fatalf("expected status %q, got %q", core.INVALID_ARGUMENT, ge.Status) + } + }) + } +} + +// TestPromptAgent_RejectsResumeForUnrequestedTool proves the resume validation +// is wired into the prompt-backed loop: a caller cannot resume a tool the model +// never requested, and the forged turn fails before the model is re-invoked. +func TestPromptAgent_RejectsResumeForUnrequestedTool(t *testing.T) { + ctx := context.Background() + reg := registry.New() + ai.ConfigureFormats(reg) + + var modelCalls atomic.Int32 + ai.DefineModel(reg, "test/plain", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true, Tools: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + modelCalls.Add(1) + return &ai.ModelResponse{Request: req, Message: ai.NewModelTextMessage("hello")}, nil + }) + ai.DefineGenerateAction(ctx, reg) + ai.DefinePrompt(reg, "plainPrompt", ai.WithModelName("test/plain")) + + af := DefineAgent[testState](reg, "plainPrompt", FromPrompt()) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + + // Turn 1: a plain text reply, so no tool request lands in history. + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive: %v", err) + } + if chunk.TurnEnd != nil { + break + } + } + + // Turn 2: forge a resume for a tool the model never requested. + if err := conn.SendResume(&ToolResume{ + Restart: []*ai.Part{ai.NewToolRequestPart(&ai.ToolRequest{ + Name: "inventedTool", + Ref: "fake", + Input: map[string]any{"evil": true}, + })}, + }); err != nil { + t.Fatalf("SendResume: %v", err) + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Fatalf("FinishReason = %q, want %q", out.FinishReason, AgentFinishReasonFailed) + } + if out.Error == nil || out.Error.Status != core.INVALID_ARGUMENT { + t.Fatalf("expected INVALID_ARGUMENT error, got %+v", out.Error) + } + if !strings.Contains(out.Error.Message, "not found in session history") { + t.Errorf("expected not-found error, got %q", out.Error.Message) + } + // The forged turn must be rejected before the model is invoked again. + if got := modelCalls.Load(); got != 1 { + t.Errorf("model calls = %d, want 1 (resume rejected before generate)", got) + } +} + func TestAgent_SingleTurnSnapshot(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) From 46ab093d272c60134b16aa605f802ec09aaefd95 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 16 Jun 2026 23:15:18 -0700 Subject: [PATCH 115/141] test(go/exp): streamline agent test suite with shared helpers Collapse repeated boilerplate in the agent test suite without narrowing coverage: - add sendText / sendTurn / drainInBackground / defineCounterAgent helpers and apply them across the suite (send blocks, turn-end drains, detach background drainers, and the reply+counter agent body) - replace hand-rolled Receive-until-TurnEnd loops with nextTurnEnd - merge the three TestPromptAgent_Rejects* tests into one table-driven test, also asserting INVALID_ARGUMENT for the tool-part cases - add TestAgentConnection_Custom_TracksStreamedPatches, covering the previously untested client-side live custom-state tracking Test-only change; passes go test under -race and GOMAXPROCS=1 -race. --- go/ai/exp/agent_test.go | 860 ++++++++++++---------------------------- 1 file changed, 258 insertions(+), 602 deletions(-) diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 17cca1054e..a1c9756950 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -43,6 +43,40 @@ func newTestRegistry(t *testing.T) *registry.Registry { return registry.New() } +// sendText sends a user text message, failing the test if the send is +// rejected. The few sites that expect the send to race invocation completion +// (e.g. a send after a failing turn) check the error themselves instead. +func sendText[State any](t *testing.T, conn *AgentConnection[State], text string) { + t.Helper() + if err := conn.SendText(text); err != nil { + t.Fatalf("SendText(%q): %v", text, err) + } +} + +// sendTurn sends a user text message and advances the stream to that turn's +// TurnEnd, returning it. It is the send-one-turn-and-wait pattern most +// multi-turn tests share; tests that must inspect intermediate chunks should +// drive Receive directly (see nextTurnEnd). +func sendTurn[State any](t *testing.T, conn *AgentConnection[State], text string) *TurnEnd { + t.Helper() + sendText(t, conn, text) + return nextTurnEnd(t, conn) +} + +// drainInBackground consumes conn's stream in a goroutine so the agent's +// responder never blocks on a full stream buffer while the test orchestrates +// a detach or cancellation. The goroutine exits when the stream ends or +// errors. Used by tests that don't inspect the streamed chunks. +func drainInBackground[State any](conn *AgentConnection[State]) { + go func() { + for _, err := range conn.Receive() { + if err != nil { + return + } + } + }() +} + func TestAgent_BasicMultiTurn(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) @@ -71,9 +105,7 @@ func TestAgent_BasicMultiTurn(t *testing.T) { } // Turn 1. - if err := conn.SendText("hello"); err != nil { - t.Fatalf("SendText failed: %v", err) - } + sendText(t, conn, "hello") var turn1Chunks int for chunk, err := range conn.Receive() { if err != nil { @@ -89,17 +121,7 @@ func TestAgent_BasicMultiTurn(t *testing.T) { } // Turn 2. - if err := conn.SendText("world"); err != nil { - t.Fatalf("SendText failed: %v", err) - } - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } + sendTurn(t, conn, "world") conn.Close() @@ -117,45 +139,94 @@ func TestAgent_BasicMultiTurn(t *testing.T) { } } -func TestAgent_WithSessionStore(t *testing.T) { +// TestAgentConnection_Custom_TracksStreamedPatches verifies the client-side +// live custom-state tracking: as Receive yields chunks, AgentConnection applies +// each CustomPatch to an internal copy that Custom() returns. It exercises the +// per-turn whole-document replace (the first patch of a turn re-bases the +// client) followed by an incremental diff within the same turn, and that the +// tracking carries across turns. The server-side patch emission is covered by +// TestAgent_TurnSpanOutput_WithSnapshots; this is its client-side complement. +func TestAgentConnection_Custom_TracksStreamedPatches(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - store := newTestInMemStore[testState]() - af := DefineCustomAgent(reg, "snapshotFlow", + af := DefineCustomAgent(reg, "liveCustomFlow", func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - if input.Message != nil { - sess.AddMessages(ai.NewModelTextMessage("reply")) - } + // Two mutations per turn: the first emits a whole-document + // replace, the second an incremental diff. sess.UpdateCustom(func(s testState) testState { s.Counter++ return s }) + sess.UpdateCustom(func(s testState) testState { + s.Topics = append(s.Topics, input.Message.Text()) + return s + }) return nil, nil }) }, - WithSessionStore(store), ) + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + + // Before any patch arrives, Custom() is the zero value. + if got, err := conn.Custom(); err != nil || got.Counter != 0 || len(got.Topics) != 0 { + t.Errorf("Custom() before any turn = %+v (err %v), want zero value", got, err) + } + + // Turn 1: draining the turn applies the streamed patches (replace + diff). + sendTurn(t, conn, "alpha") + got, err := conn.Custom() + if err != nil { + t.Fatalf("Custom() after turn 1: %v", err) + } + if got.Counter != 1 || !slices.Equal(got.Topics, []string{"alpha"}) { + t.Errorf("tracked custom after turn 1 = %+v, want {Counter:1 Topics:[alpha]}", got) + } + + // Turn 2: its first patch is a whole-document replace that re-bases the + // client; the cumulative tracked state must still be correct. + sendTurn(t, conn, "beta") + got, err = conn.Custom() + if err != nil { + t.Fatalf("Custom() after turn 2: %v", err) + } + if got.Counter != 2 || !slices.Equal(got.Topics, []string{"alpha", "beta"}) { + t.Errorf("tracked custom after turn 2 = %+v, want {Counter:2 Topics:[alpha beta]}", got) + } + + conn.Close() + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + // The authoritative final state agrees with what the streamed patches tracked. + if out.State.Custom.Counter != 2 || !slices.Equal(out.State.Custom.Topics, []string{"alpha", "beta"}) { + t.Errorf("final state custom = %+v, want {Counter:2 Topics:[alpha beta]}", out.State.Custom) + } +} + +func TestAgent_WithSessionStore(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + af := defineCounterAgent(reg, "snapshotFlow", WithSessionStore(store)) + conn, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } - conn.SendText("turn1") + sendText(t, conn, "turn1") var snapshotIDs []string - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - if chunk.TurnEnd.SnapshotID != "" { - snapshotIDs = append(snapshotIDs, chunk.TurnEnd.SnapshotID) - } - break - } + if te := nextTurnEnd(t, conn); te.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, te.SnapshotID) } if len(snapshotIDs) != 1 { @@ -192,28 +263,14 @@ func TestAgent_ResumeFromSnapshot(t *testing.T) { reg := newTestRegistry(t) store := newTestInMemStore[testState]() - af := DefineCustomAgent(reg, "resumeFlow", - func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - if input.Message != nil { - sess.AddMessages(ai.NewModelTextMessage("reply")) - } - sess.UpdateCustom(func(s testState) testState { - s.Counter++ - return s - }) - return nil, nil - }) - }, - WithSessionStore(store), - ) + af := defineCounterAgent(reg, "resumeFlow", WithSessionStore(store)) // First invocation: create a snapshot. conn1, err := af.StreamBidi(ctx) if err != nil { t.Fatalf("StreamBidi failed: %v", err) } - conn1.SendText("first message") + sendText(t, conn1, "first message") for chunk, err := range conn1.Receive() { if err != nil { t.Fatalf("Receive error: %v", err) @@ -236,15 +293,7 @@ func TestAgent_ResumeFromSnapshot(t *testing.T) { if err != nil { t.Fatalf("StreamBidi with snapshot failed: %v", err) } - conn2.SendText("continued message") - for chunk, err := range conn2.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } + sendTurn(t, conn2, "continued message") conn2.Close() resp2, err := conn2.Output() if err != nil { @@ -281,20 +330,7 @@ func TestAgent_ClientManagedState(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineCustomAgent(reg, "clientStateFlow", - func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - if input.Message != nil { - sess.AddMessages(ai.NewModelTextMessage("reply")) - } - sess.UpdateCustom(func(s testState) testState { - s.Counter++ - return s - }) - return nil, nil - }) - }, - ) + af := defineCounterAgent(reg, "clientStateFlow") // Start with client-provided state. clientState := &SessionState[testState]{ @@ -310,15 +346,7 @@ func TestAgent_ClientManagedState(t *testing.T) { t.Fatalf("StreamBidi failed: %v", err) } - conn.SendText("new message") - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } + sendTurn(t, conn, "new message") conn.Close() response, err := conn.Output() @@ -380,7 +408,7 @@ func TestAgent_Artifacts(t *testing.T) { t.Fatalf("StreamBidi failed: %v", err) } - conn.SendText("generate code") + sendText(t, conn, "generate code") var receivedArtifacts []*Artifact for chunk, err := range conn.Receive() { if err != nil { @@ -493,14 +521,7 @@ func TestAgent_InputMessageCloned(t *testing.T) { if err := conn.SendMessage(sent); err != nil { t.Fatalf("SendMessage failed: %v", err) } - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } + nextTurnEnd(t, conn) // The turn is over, so the message is in session history. Mutating // the caller's copy must not reach it. @@ -627,14 +648,7 @@ func TestAgent_SendMessage(t *testing.T) { if err != nil { t.Fatalf("SendMessage failed: %v", err) } - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } + nextTurnEnd(t, conn) conn.Close() response, err := conn.Output() @@ -677,15 +691,7 @@ func TestAgent_SessionContext(t *testing.T) { t.Fatalf("StreamBidi failed: %v", err) } - conn.SendText("test") - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } + sendTurn(t, conn, "test") conn.Close() conn.Output() @@ -711,7 +717,7 @@ func TestAgent_ErrorInTurn(t *testing.T) { t.Fatalf("StreamBidi failed: %v", err) } - conn.SendText("trigger error") + sendText(t, conn, "trigger error") conn.Close() // A failed turn resolves the invocation gracefully rather than @@ -737,6 +743,26 @@ func TestAgent_ErrorInTurn(t *testing.T) { } } +// defineCounterAgent defines a custom agent whose every turn appends a model +// "reply" message and increments the custom counter: the minimal stateful +// turn body shared by the snapshot and state-management tests. opts pass +// through to DefineCustomAgent (e.g. WithSessionStore). +func defineCounterAgent(reg api.Registry, name string, opts ...AgentOption[testState]) *Agent[testState] { + return DefineCustomAgent(reg, name, + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil, nil + }) + }, + opts..., + ) +} + // defineLastGoodTestAgent defines a client- or server-managed echo agent // whose turn fails (with partial session mutations) when the user sends // "boom". Successful turns report [AgentFinishReasonStop]. @@ -781,9 +807,7 @@ func TestAgent_FailedTurn_ClientManagedReturnsLastGoodState(t *testing.T) { t.Fatalf("StreamBidi: %v", err) } for _, text := range []string{"one", "two", "boom"} { - if err := conn.SendText(text); err != nil { - t.Fatalf("SendText(%q): %v", text, err) - } + sendText(t, conn, text) } out, err := conn.Output() @@ -879,17 +903,12 @@ func TestAgent_FailedTurn_ServerManagedReturnsLastTurnSnapshot(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - if err := conn.SendText("one"); err != nil { - t.Fatalf("SendText: %v", err) - } - turn0 := nextTurnEnd(t, conn) + turn0 := sendTurn(t, conn, "one") if turn0.SnapshotID == "" { t.Fatal("expected turn 0 snapshot") } - if err := conn.SendText("boom"); err != nil { - t.Fatalf("SendText(boom): %v", err) - } + sendText(t, conn, "boom") out, err := conn.Output() if err != nil { t.Fatalf("Output: %v", err) @@ -974,9 +993,7 @@ func TestAgent_FailedTurn_EmitsFailedTurnEnd(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - if err := conn.SendText("hi"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "hi") turnEnd := nextTurnEnd(t, conn) close(turnEndSeen) @@ -1028,9 +1045,7 @@ func TestAgent_CustomAgentContinuesAfterFailedTurn(t *testing.T) { t.Fatalf("StreamBidi: %v", err) } for _, text := range []string{"one", "boom", "two"} { - if err := conn.SendText(text); err != nil { - t.Fatalf("SendText(%q): %v", text, err) - } + sendText(t, conn, text) } // A hang here means intake pacing after a failed turn is broken. @@ -1140,15 +1155,7 @@ func TestAgent_SetMessages(t *testing.T) { t.Fatalf("StreamBidi failed: %v", err) } - conn.SendText("original") - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } + sendTurn(t, conn, "original") conn.Close() response, err := conn.Output() @@ -1203,17 +1210,7 @@ func TestAgent_TurnSpanOutput(t *testing.T) { // Two turns. for turn := range 2 { - if err := conn.SendText(fmt.Sprintf("turn %d", turn)); err != nil { - t.Fatalf("SendText failed on turn %d: %v", turn, err) - } - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error on turn %d: %v", turn, err) - } - if chunk.TurnEnd != nil { - break - } - } + sendTurn(t, conn, fmt.Sprintf("turn %d", turn)) } conn.Close() @@ -1276,18 +1273,10 @@ func TestAgent_TurnSpanOutput_WithSnapshots(t *testing.T) { t.Fatalf("StreamBidi failed: %v", err) } - conn.SendText("hello") + sendText(t, conn, "hello") var sawSnapshot bool - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - if chunk.TurnEnd.SnapshotID != "" { - sawSnapshot = true - } - break - } + if nextTurnEnd(t, conn).SnapshotID != "" { + sawSnapshot = true } conn.Close() conn.Output() @@ -1368,9 +1357,7 @@ func TestPromptAgent_Basic(t *testing.T) { } // Turn 1. - if err := conn.SendText("hello"); err != nil { - t.Fatalf("SendText failed: %v", err) - } + sendText(t, conn, "hello") var gotChunk bool for chunk, err := range conn.Receive() { @@ -1389,17 +1376,7 @@ func TestPromptAgent_Basic(t *testing.T) { } // Turn 2. - if err := conn.SendText("world"); err != nil { - t.Fatalf("SendText failed: %v", err) - } - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } + sendTurn(t, conn, "world") conn.Close() @@ -1455,7 +1432,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { } // Turn 1. - conn.SendText("turn1") + sendText(t, conn, "turn1") var turn1Response string for chunk, err := range conn.Receive() { if err != nil { @@ -1476,7 +1453,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { } // Turn 2. - conn.SendText("turn2") + sendText(t, conn, "turn2") var turn2Response string for chunk, err := range conn.Receive() { if err != nil { @@ -1530,15 +1507,7 @@ func TestPromptAgent_SnapshotResumePreservesHistory(t *testing.T) { t.Fatalf("StreamBidi failed: %v", err) } - conn.SendText("hello") - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } + sendTurn(t, conn, "hello") conn.Close() resp, err := conn.Output() @@ -1554,15 +1523,7 @@ func TestPromptAgent_SnapshotResumePreservesHistory(t *testing.T) { t.Fatalf("StreamBidi with snapshot failed: %v", err) } - conn2.SendText("continued") - for chunk, err := range conn2.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } + sendTurn(t, conn2, "continued") conn2.Close() resp2, err := conn2.Output() @@ -1669,15 +1630,7 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { t.Fatalf("StreamBidi failed: %v", err) } - conn.SendText("go") - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } + sendTurn(t, conn, "go") conn.Close() response, err := conn.Output() @@ -1806,18 +1759,7 @@ func TestAgent_RunText_WithState(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - af := DefineCustomAgent(reg, "runStateFlow", - func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - sess.AddMessages(ai.NewModelTextMessage("reply")) - sess.UpdateCustom(func(s testState) testState { - s.Counter++ - return s - }) - return nil, nil - }) - }, - ) + af := defineCounterAgent(reg, "runStateFlow") clientState := &SessionState[testState]{ Messages: []*ai.Message{ @@ -1847,19 +1789,7 @@ func TestAgent_RunText_WithSnapshot(t *testing.T) { reg := newTestRegistry(t) store := newTestInMemStore[testState]() - af := DefineCustomAgent(reg, "runSnapshotFlow", - func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - sess.AddMessages(ai.NewModelTextMessage("reply")) - sess.UpdateCustom(func(s testState) testState { - s.Counter++ - return s - }) - return nil, nil - }) - }, - WithSessionStore(store), - ) + af := defineCounterAgent(reg, "runSnapshotFlow", WithSessionStore(store)) // First invocation via RunText. resp1, err := af.RunText(ctx, "first") @@ -1914,86 +1844,62 @@ func TestPromptAgent_RunText(t *testing.T) { } } -func TestPromptAgent_RejectsNonUserRole(t *testing.T) { +// TestPromptAgent_RejectsInvalidInputMessage verifies the prompt-backed loop +// rejects turn messages it cannot safely consume: a non-user role, or tool +// request/response parts (which belong on AgentInput.Resume, not a turn +// message). Each resolves as a failed output carrying an INVALID_ARGUMENT +// error rather than reaching the model. +func TestPromptAgent_RejectsInvalidInputMessage(t *testing.T) { ctx := context.Background() reg := setupPromptTestRegistry(t) + ai.DefinePrompt(reg, "rejectPrompt", ai.WithModelName("test/echo")) + af := DefineAgent[testState](reg, "rejectPrompt", FromPrompt()) - ai.DefinePrompt(reg, "rejectRolePrompt", ai.WithModelName("test/echo")) - af := DefineAgent[testState](reg, "rejectRolePrompt", FromPrompt()) - - out, err := af.Run(ctx, &AgentInput{ - Message: &ai.Message{ - Role: ai.RoleModel, - Content: []*ai.Part{ai.NewTextPart("hi")}, + tests := []struct { + name string + message *ai.Message + wantMsg string + }{ + { + name: "non-user role", + message: &ai.Message{Role: ai.RoleModel, Content: []*ai.Part{ai.NewTextPart("hi")}}, + wantMsg: "role", }, - }) - if err != nil { - t.Fatalf("Run: %v", err) - } - if out.FinishReason != AgentFinishReasonFailed { - t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) - } - if out.Error == nil { - t.Fatal("expected output error for non-user role, got nil") - } - if out.Error.Status != core.INVALID_ARGUMENT { - t.Errorf("expected status %q, got %q", core.INVALID_ARGUMENT, out.Error.Status) - } - if !strings.Contains(out.Error.Message, "role") { - t.Errorf("expected role-related error, got %v", out.Error) - } -} - -func TestPromptAgent_RejectsToolRequestPart(t *testing.T) { - ctx := context.Background() - reg := setupPromptTestRegistry(t) - - ai.DefinePrompt(reg, "rejectToolReqPrompt", ai.WithModelName("test/echo")) - af := DefineAgent[testState](reg, "rejectToolReqPrompt", FromPrompt()) - - out, err := af.Run(ctx, &AgentInput{ - Message: &ai.Message{ - Role: ai.RoleUser, - Content: []*ai.Part{ + { + name: "tool request part", + message: &ai.Message{Role: ai.RoleUser, Content: []*ai.Part{ ai.NewTextPart("hi"), ai.NewToolRequestPart(&ai.ToolRequest{Name: "doThing", Ref: "1"}), - }, + }}, + wantMsg: "tool request", }, - }) - if err != nil { - t.Fatalf("Run: %v", err) - } - if out.FinishReason != AgentFinishReasonFailed { - t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) - } - if out.Error == nil || !strings.Contains(out.Error.Message, "tool request") { - t.Errorf("expected tool-request error, got %+v", out.Error) - } -} - -func TestPromptAgent_RejectsToolResponsePart(t *testing.T) { - ctx := context.Background() - reg := setupPromptTestRegistry(t) - - ai.DefinePrompt(reg, "rejectToolRespPrompt", ai.WithModelName("test/echo")) - af := DefineAgent[testState](reg, "rejectToolRespPrompt", FromPrompt()) - - out, err := af.Run(ctx, &AgentInput{ - Message: &ai.Message{ - Role: ai.RoleUser, - Content: []*ai.Part{ + { + name: "tool response part", + message: &ai.Message{Role: ai.RoleUser, Content: []*ai.Part{ ai.NewToolResponsePart(&ai.ToolResponse{Name: "doThing", Ref: "1"}), - }, + }}, + wantMsg: "tool", }, - }) - if err != nil { - t.Fatalf("Run: %v", err) - } - if out.FinishReason != AgentFinishReasonFailed { - t.Errorf("expected finish reason %q, got %q", AgentFinishReasonFailed, out.FinishReason) } - if out.Error == nil || !strings.Contains(out.Error.Message, "tool") { - t.Errorf("expected tool-related error, got %+v", out.Error) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + out, err := af.Run(ctx, &AgentInput{Message: tc.message}) + if err != nil { + t.Fatalf("Run: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("FinishReason = %q, want %q", out.FinishReason, AgentFinishReasonFailed) + } + if out.Error == nil { + t.Fatal("expected output error, got nil") + } + if out.Error.Status != core.INVALID_ARGUMENT { + t.Errorf("Error.Status = %q, want %q", out.Error.Status, core.INVALID_ARGUMENT) + } + if !strings.Contains(out.Error.Message, tc.wantMsg) { + t.Errorf("Error.Message = %q, want substring %q", out.Error.Message, tc.wantMsg) + } + }) } } @@ -2114,17 +2020,7 @@ func TestPromptAgent_RejectsResumeForUnrequestedTool(t *testing.T) { } // Turn 1: a plain text reply, so no tool request lands in history. - if err := conn.SendText("hi"); err != nil { - t.Fatalf("SendText: %v", err) - } - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } + sendTurn(t, conn, "hi") // Turn 2: forge a resume for a tool the model never requested. if err := conn.SendResume(&ToolResume{ @@ -2161,19 +2057,7 @@ func TestAgent_SingleTurnSnapshot(t *testing.T) { reg := newTestRegistry(t) store := newTestInMemStore[testState]() - af := DefineCustomAgent(reg, "singleTurnFlow", - func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - sess.AddMessages(ai.NewModelTextMessage("reply")) - sess.UpdateCustom(func(s testState) testState { - s.Counter++ - return s - }) - return nil, nil - }) - }, - WithSessionStore(store), - ) + af := defineCounterAgent(reg, "singleTurnFlow", WithSessionStore(store)) // Single-turn invocation: exactly 1 snapshot (the turn-end), which the // output reuses as its resume point. There is no second invocation-end @@ -2205,19 +2089,7 @@ func TestAgent_MultiTurnSnapshot(t *testing.T) { reg := newTestRegistry(t) store := newTestInMemStore[testState]() - af := DefineCustomAgent(reg, "multiDedupFlow", - func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - sess.AddMessages(ai.NewModelTextMessage("reply")) - sess.UpdateCustom(func(s testState) testState { - s.Counter++ - return s - }) - return nil, nil - }) - }, - WithSessionStore(store), - ) + af := defineCounterAgent(reg, "multiDedupFlow", WithSessionStore(store)) // Multi-turn: one snapshot per turn; the output reuses the last one. conn, err := af.StreamBidi(ctx) @@ -2227,17 +2099,9 @@ func TestAgent_MultiTurnSnapshot(t *testing.T) { var snapshotIDs []string for i := 0; i < 3; i++ { - conn.SendText(fmt.Sprintf("turn %d", i)) - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error on turn %d: %v", i, err) - } - if chunk.TurnEnd != nil { - if chunk.TurnEnd.SnapshotID != "" { - snapshotIDs = append(snapshotIDs, chunk.TurnEnd.SnapshotID) - } - break - } + sendText(t, conn, fmt.Sprintf("turn %d", i)) + if te := nextTurnEnd(t, conn); te.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, te.SnapshotID) } } conn.Close() @@ -2339,9 +2203,7 @@ func TestAgent_FnPanicResolvesAsFailedOutput(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - if err := conn.SendText("trigger"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "trigger") // A hang here means the streaming goroutine leaked. out, err := outputWithin(t, conn, 2*time.Second) @@ -2394,9 +2256,7 @@ func TestAgent_CancelDuringStreamReleasesGoroutine(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - if err := conn.SendText("go"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "go") <-emitting cancel() @@ -2501,18 +2361,8 @@ func TestAgent_TurnEnd_CarriesSnapshotID(t *testing.T) { var observed []TurnEnd for turn := 0; turn < 3; turn++ { - if err := conn.SendText(fmt.Sprintf("turn %d", turn)); err != nil { - t.Fatalf("SendText: %v", err) - } - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive: %v", err) - } - if chunk.TurnEnd != nil { - observed = append(observed, *chunk.TurnEnd) - break - } - } + sendText(t, conn, fmt.Sprintf("turn %d", turn)) + observed = append(observed, *nextTurnEnd(t, conn)) } conn.Close() if _, err := conn.Output(); err != nil { @@ -2572,19 +2422,11 @@ func TestAgent_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { } // Drain stream chunks in the background. - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() + drainInBackground(conn) // Send A and wait for it to enter fn (so it's in-flight when detach // arrives). - if err := conn.SendText("A"); err != nil { - t.Fatalf("SendText A: %v", err) - } + sendText(t, conn, "A") select { case <-entered: case <-time.After(2 * time.Second): @@ -2593,9 +2435,7 @@ func TestAgent_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { // Send D, then Detach. The eager intake reader sees D queued and the // detach signal immediately, even though the runner is blocked on A. - if err := conn.SendText("D"); err != nil { - t.Fatalf("SendText D: %v", err) - } + sendText(t, conn, "D") if err := conn.Detach(); err != nil { t.Fatalf("Detach: %v", err) } @@ -2667,20 +2507,12 @@ func TestAgent_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { } // Background drainer. - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() + drainInBackground(conn) // Run two normal turns. for i := 0; i < 2; i++ { release <- struct{}{} // pre-load release so this turn's fn doesn't block - if err := conn.SendText(fmt.Sprintf("sync-%d", i)); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, fmt.Sprintf("sync-%d", i)) <-enter } // Brief settle so the second turn-end snapshot lands before detach. @@ -2692,15 +2524,11 @@ func TestAgent_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { // Now start a third turn but DON'T release it — the third turn is // in-flight when detach lands. - if err := conn.SendText("inflight"); err != nil { - t.Fatalf("SendText inflight: %v", err) - } + sendText(t, conn, "inflight") <-enter // third turn entered fn // Send the queued input and detach. - if err := conn.SendText("detach-msg"); err != nil { - t.Fatalf("SendText detach-msg: %v", err) - } + sendText(t, conn, "detach-msg") if err := conn.Detach(); err != nil { t.Fatalf("Detach: %v", err) } @@ -2793,17 +2621,9 @@ func TestAgent_Detach_PendingThenComplete(t *testing.T) { } // Drain chunks so the responder isn't blocked. - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() + drainInBackground(conn) - if err := conn.SendText("go"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "go") if err := conn.Detach(); err != nil { t.Fatalf("Detach: %v", err) } @@ -2890,17 +2710,9 @@ func TestAgent_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() + drainInBackground(conn) - if err := conn.SendText("go"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "go") if err := conn.Detach(); err != nil { t.Fatalf("Detach: %v", err) } @@ -2958,17 +2770,9 @@ func TestAgent_Detach_FlowErrorsBecomesError(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() + drainInBackground(conn) - if err := conn.SendText("go"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "go") if err := conn.Detach(); err != nil { t.Fatalf("Detach: %v", err) } @@ -3029,17 +2833,9 @@ func TestAgent_Detach_AbortSnapshotStopsFlow(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() + drainInBackground(conn) - if err := conn.SendText("go"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "go") if err := conn.Detach(); err != nil { t.Fatalf("Detach: %v", err) } @@ -3096,20 +2892,10 @@ func TestAgent_Detach_NormalCompletionStillEmitsTurnEnd(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - if err := conn.SendText("hi"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "hi") var turnEndID string - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive: %v", err) - } - if chunk.TurnEnd != nil { - turnEndID = chunk.TurnEnd.SnapshotID - break - } - } + turnEndID = nextTurnEnd(t, conn).SnapshotID if turnEndID == "" { t.Fatal("expected snapshot ID on TurnEnd chunk") } @@ -3158,17 +2944,9 @@ func TestAgent_Detach_ClientDisconnectBeforeDetachCancels(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() + drainInBackground(conn) - if err := conn.SendText("go"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "go") <-entered cancel() @@ -4107,19 +3885,7 @@ func TestAgent_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { reg := newTestRegistry(t) store := newTestInMemStore[testState]() - af := DefineCustomAgent(reg, "resumeDetachedFlow", - func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - sess.AddMessages(ai.NewModelTextMessage("reply")) - sess.UpdateCustom(func(s testState) testState { - s.Counter++ - return s - }) - return nil, nil - }) - }, - WithSessionStore(store), - ) + af := defineCounterAgent(reg, "resumeDetachedFlow", WithSessionStore(store)) ctx := context.Background() @@ -4129,16 +3895,8 @@ func TestAgent_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() - if err := conn.SendText("turn 1"); err != nil { - t.Fatalf("SendText: %v", err) - } + drainInBackground(conn) + sendText(t, conn, "turn 1") if err := conn.Detach(); err != nil { t.Fatalf("Detach: %v", err) } @@ -4271,17 +4029,9 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() + drainInBackground(conn) - if err := conn.SendText("go"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "go") if err := conn.Detach(); err != nil { t.Fatalf("Detach: %v", err) } @@ -4539,9 +4289,7 @@ func TestAgent_FinishReason_TurnAndInvocation(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - if err := conn.SendText("hi"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "hi") turnEnd := nextTurnEnd(t, conn) if turnEnd.FinishReason != AgentFinishReasonStop { @@ -4588,10 +4336,7 @@ func TestAgent_FinishReason_OmittedWhenNil(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - if err := conn.SendText("hi"); err != nil { - t.Fatalf("SendText: %v", err) - } - turnEnd := nextTurnEnd(t, conn) + turnEnd := sendTurn(t, conn, "hi") if turnEnd.FinishReason != "" { t.Errorf("TurnEnd.FinishReason = %q, want empty", turnEnd.FinishReason) } @@ -4629,10 +4374,7 @@ func TestAgent_FinishReason_InvocationOverride(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - if err := conn.SendText("hi"); err != nil { - t.Fatalf("SendText: %v", err) - } - turnEnd := nextTurnEnd(t, conn) + turnEnd := sendTurn(t, conn, "hi") if turnEnd.FinishReason != AgentFinishReasonStop { t.Errorf("TurnEnd.FinishReason = %q, want %q (per-turn, unaffected by override)", turnEnd.FinishReason, AgentFinishReasonStop) } @@ -4671,9 +4413,7 @@ func TestAgent_FinishReason_MultiTurnDistinct(t *testing.T) { var got []AgentFinishReason for i := 0; i < len(reasons); i++ { - if err := conn.SendText("turn"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "turn") got = append(got, nextTurnEnd(t, conn).FinishReason) } for i, want := range reasons { @@ -4715,10 +4455,7 @@ func TestPromptAgent_ForwardsFinishReason(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - if err := conn.SendText("hi"); err != nil { - t.Fatalf("SendText: %v", err) - } - turnEnd := nextTurnEnd(t, conn) + turnEnd := sendTurn(t, conn, "hi") if turnEnd.FinishReason != AgentFinishReasonLength { t.Errorf("TurnEnd.FinishReason = %q, want %q", turnEnd.FinishReason, AgentFinishReasonLength) } @@ -4769,13 +4506,7 @@ func TestAgent_Detach_BackgroundWorkSurvivesActionReturn(t *testing.T) { if err != nil { t.Fatalf("iteration %d: StreamBidi: %v", i, err) } - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() + drainInBackground(conn) if err := conn.SendText("go"); err != nil { t.Fatalf("iteration %d: SendText: %v", i, err) } @@ -4832,16 +4563,8 @@ func TestAgent_Detach_FinishReasons(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() - if err := conn.SendText("go"); err != nil { - t.Fatalf("SendText: %v", err) - } + drainInBackground(conn) + sendText(t, conn, "go") if err := conn.Detach(); err != nil { t.Fatalf("Detach: %v", err) } @@ -4888,16 +4611,8 @@ func TestAgent_Detach_FinishReasons(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() - if err := conn.SendText("go"); err != nil { - t.Fatalf("SendText: %v", err) - } + drainInBackground(conn) + sendText(t, conn, "go") if err := conn.Detach(); err != nil { t.Fatalf("Detach: %v", err) } @@ -4943,16 +4658,8 @@ func TestAgent_Detach_FinishReasons(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() - if err := conn.SendText("go"); err != nil { - t.Fatalf("SendText: %v", err) - } + drainInBackground(conn) + sendText(t, conn, "go") if err := conn.Detach(); err != nil { t.Fatalf("Detach: %v", err) } @@ -5050,10 +4757,7 @@ func TestAgent_FinishReason_MultiTurnDistinct_Persisted(t *testing.T) { } var ids []string for i := 0; i < len(reasons); i++ { - if err := conn.SendText("turn"); err != nil { - t.Fatalf("SendText: %v", err) - } - te := nextTurnEnd(t, conn) + te := sendTurn(t, conn, "turn") if te.FinishReason != reasons[i] { t.Errorf("turn %d TurnEnd.FinishReason = %q, want %q", i, te.FinishReason, reasons[i]) } @@ -5099,9 +4803,7 @@ func TestAgent_FinishReason_OmittedPersisted(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - if err := conn.SendText("hi"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "hi") snapID := nextTurnEnd(t, conn).SnapshotID if _, err := conn.Output(); err != nil { t.Fatalf("Output: %v", err) @@ -5153,9 +4855,7 @@ func TestPromptAgent_ForwardsInterruptedFinishReason(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - if err := conn.SendText("do it"); err != nil { - t.Fatalf("SendText: %v", err) - } + sendText(t, conn, "do it") var ( turnEnd *TurnEnd gotToolChunk bool @@ -5223,16 +4923,8 @@ func TestAgent_Detach_CompletedHonorsResultOverride(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() - if err := conn.SendText("go"); err != nil { - t.Fatalf("SendText: %v", err) - } + drainInBackground(conn) + sendText(t, conn, "go") if err := conn.Detach(); err != nil { t.Fatalf("Detach: %v", err) } @@ -5272,10 +4964,7 @@ func TestAgent_SessionID_AssignedAndStable(t *testing.T) { var snapshotIDs []string for _, text := range []string{"turn one", "turn two"} { - if err := conn.SendText(text); err != nil { - t.Fatalf("SendText: %v", err) - } - te := nextTurnEnd(t, conn) + te := sendTurn(t, conn, text) if te.SnapshotID != "" { snapshotIDs = append(snapshotIDs, te.SnapshotID) } @@ -5602,9 +5291,7 @@ func TestAgent_ResumeFromSessionID_AfterFailureResumesLastTurn(t *testing.T) { t.Fatalf("StreamBidi: %v", err) } for _, text := range []string{"one", "two", "boom"} { - if err := conn.SendText(text); err != nil { - t.Fatalf("SendText(%q): %v", text, err) - } + sendText(t, conn, text) } out, err := conn.Output() if err != nil { @@ -5844,16 +5531,8 @@ func TestAgent_Detach_AssignsSessionID(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() - if err := conn.SendText("go"); err != nil { - t.Fatalf("SendText: %v", err) - } + drainInBackground(conn) + sendText(t, conn, "go") if err := conn.Detach(); err != nil { t.Fatalf("Detach: %v", err) } @@ -5925,16 +5604,8 @@ func TestAgent_Detach_WaitsForInFlightTurnSnapshot(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - go func() { - for _, err := range conn.Receive() { - if err != nil { - return - } - } - }() - if err := conn.SendText("one"); err != nil { - t.Fatalf("SendText: %v", err) - } + drainInBackground(conn) + sendText(t, conn, "one") // Wait until the turn-end snapshot write is in flight, then detach // while it is still blocked inside the store. <-store.entered @@ -5998,10 +5669,7 @@ func TestAgent_FailedTurn_OutputCarriesSessionID(t *testing.T) { if err != nil { t.Fatalf("StreamBidi: %v", err) } - if err := conn.SendText("turn one"); err != nil { - t.Fatalf("SendText: %v", err) - } - nextTurnEnd(t, conn) + sendTurn(t, conn, "turn one") if err := conn.SendText("boom"); err != nil && !errors.Is(err, core.ErrActionCompleted) { t.Fatalf("SendText: %v", err) } @@ -6222,17 +5890,7 @@ func TestAgent_SendNilInput_Rejected(t *testing.T) { } // The connection must remain usable after the rejected input. - if err := conn.SendText("hello"); err != nil { - t.Fatalf("SendText failed: %v", err) - } - for chunk, err := range conn.Receive() { - if err != nil { - t.Fatalf("Receive error: %v", err) - } - if chunk.TurnEnd != nil { - break - } - } + sendTurn(t, conn, "hello") conn.Close() response, err := conn.Output() @@ -6324,9 +5982,7 @@ func TestAgent_ClientCancelMidStream(t *testing.T) { if err != nil { t.Fatalf("StreamBidi failed: %v", err) } - if err := conn.SendText("hello"); err != nil { - t.Fatalf("SendText failed: %v", err) - } + sendText(t, conn, "hello") // Close the input side so sess.Run ends cleanly and fn returns // nil once its sends are accepted. conn.Close() From 21924e9552573a564f2ff3cbd5b3278032a6a65c Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 16 Jun 2026 23:20:23 -0700 Subject: [PATCH 116/141] refactor(go/exp): unexport SessionRunner.inputCh and turnIndex These were exported for "advanced use cases" but are runtime internals we don't want in the public surface. Unexport both fields (InputCh -> inputCh, TurnIndex -> turnIndex) and update internal references. The two finish-reason tests that read sess.TurnIndex now track their own turn counter, matching what a real custom agent must do once the field is no longer reachable. Safe under the exp package's documented instability; no samples or other packages referenced these fields. --- go/ai/exp/agent.go | 23 +++++++++-------------- go/ai/exp/agent_test.go | 10 ++++++++-- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 27ea579a07..0e587d452f 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -47,16 +47,11 @@ import ( type SessionRunner[State any] struct { *Session[State] - // InputCh is the channel that delivers per-turn inputs from the client. - // It is consumed automatically by [SessionRunner.Run], but is exposed - // for advanced use cases that need direct access to the input stream - // (e.g., custom turn loops or fan-out patterns). - InputCh <-chan *AgentInput - // TurnIndex is the zero-based index of the current conversation turn. - // It is incremented automatically by [SessionRunner.Run], but is exposed - // for advanced use cases that need to track or manipulate turn ordering - // directly. - TurnIndex int + // inputCh delivers per-turn inputs from the client; consumed by Run. + inputCh <-chan *AgentInput + // turnIndex is the zero-based index of the current conversation turn, + // incremented by Run after each turn completes. + turnIndex int onStartTurn func() onEndTurn func(ctx context.Context) @@ -151,7 +146,7 @@ type TurnResult struct { // failed [AgentOutput] carrying the error and the last-good state rather // than failing the action. func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Context, input *AgentInput) (*TurnResult, error)) error { - for input := range s.InputCh { + for input := range s.inputCh { // Deep-copy at the framework boundary: an in-process caller // retains the pointers it sent (message, resume parts) and may // mutate them after Send returns, so everything past this point @@ -164,7 +159,7 @@ func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont s.onStartTurn() } spanMeta := &tracing.SpanMetadata{ - Name: fmt.Sprintf("agent/turn/%d", s.TurnIndex), + Name: fmt.Sprintf("agent/turn/%d", s.turnIndex), Type: "flowStep", Subtype: "flowStep", } @@ -209,7 +204,7 @@ func (s *SessionRunner[State]) endTurn(ctx context.Context, reason AgentFinishRe if !failed { s.captureLastGood() } - s.TurnIndex++ + s.turnIndex++ } // captureLastGood deep-copies the committed session state as the @@ -778,7 +773,7 @@ func newAgentRuntime[State any]( rt.sess = &SessionRunner[State]{ Session: session, - InputCh: rt.intake.out(), + inputCh: rt.intake.out(), } if parent != nil { // Resumed: chain the first turn's snapshot off the one we loaded, and diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index a1c9756950..a0060948ce 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -4397,11 +4397,14 @@ func TestAgent_FinishReason_MultiTurnDistinct(t *testing.T) { // Turn 0 reports "stop"; turn 1 reports "interrupted". reasons := []AgentFinishReason{AgentFinishReasonStop, AgentFinishReasonInterrupted} + turn := 0 af := DefineCustomAgent(reg, "multiReasonFlow", func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) - return &TurnResult{FinishReason: reasons[sess.TurnIndex]}, nil + r := reasons[turn] + turn++ + return &TurnResult{FinishReason: r}, nil }) }, ) @@ -4741,11 +4744,14 @@ func TestAgent_FinishReason_MultiTurnDistinct_Persisted(t *testing.T) { store := newTestInMemStore[testState]() reasons := []AgentFinishReason{AgentFinishReasonStop, AgentFinishReasonInterrupted} + turn := 0 af := DefineCustomAgent(reg, "multiReasonPersistedFlow", func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.AddMessages(ai.NewModelTextMessage("ok")) - return &TurnResult{FinishReason: reasons[sess.TurnIndex]}, nil + r := reasons[turn] + turn++ + return &TurnResult{FinishReason: r}, nil }) }, WithSessionStore(store), From 9a0d09bea0424777f024fb04a1bcc8b520c23a5b Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 17 Jun 2026 12:07:09 -0700 Subject: [PATCH 117/141] fix(go/ai): emit parallel tool responses in request order Tool calls execute concurrently, so handleToolRequests received their results from resultChan in completion order and appended them in that order. With multiple tool requests in one model message the recorded tool message was nondeterministic (e.g. [ref2, ref1] for requests [ref1, ref2]). Collect responses keyed by each request's position in the model message and emit them in request order, so the tool-loop history is deterministic and every response pairs with its request positionally. --- go/ai/generate.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index c4d6128eae..d515cc143a 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -1016,7 +1016,10 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, }(i, part) } - var toolResps []*Part + // Tools run concurrently, so resultChan delivers responses in completion + // order. Collect them keyed by the request's position in the model message + // so they can be re-emitted in request order below. + toolRespByIndex := make(map[int]*Part, toolCount) hasInterrupts := false for range toolCount { res := <-resultChan @@ -1038,13 +1041,23 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, Content: res.value.Content, }) newToolResp.Metadata = res.value.Metadata - toolResps = append(toolResps, newToolResp) + toolRespByIndex[res.index] = newToolResp } if hasInterrupts { return nil, revisedMsg, nil } + // Emit tool responses in the order their requests appear in the model + // message, not the order the goroutines happened to finish, so the recorded + // tool message is deterministic across runs. + toolResps := make([]*Part, 0, len(toolRespByIndex)) + for i := range revisedMsg.Content { + if part, ok := toolRespByIndex[i]; ok { + toolResps = append(toolResps, part) + } + } + toolMsg.Content = toolResps if cb != nil { From b63375a7c24b914aa5c80864ae67d2b615c7a630 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 18 Jun 2026 08:49:00 -0700 Subject: [PATCH 118/141] docs(go/exp): tighten godocs for the agent API Audit and trim the public godocs across the experimental agent API so they read well in the Go package manager. Cut implementation-iteration notes and internal-mechanics reassurance (JS-applier parity, snapshot choreography, omitempty rationale, reflection-API round-tripping), dial back the unruly field and option comments, and keep the actionable contract detail (store-interface contracts, mutual-exclusivity rules, resume-validation semantics) stated concisely. Generated-type comments in go/ai/exp/gen.go are sourced from doc blocks in go/core/schemas.config; those were edited and the file regenerated. The TS schema and genkit-schema.json are untouched. Documentation only, no behavior change. --- go/ai/exp/agent.go | 107 ++++++++++--------------- go/ai/exp/gen.go | 130 +++++++++++-------------------- go/ai/exp/jsonpatch.go | 9 +-- go/ai/exp/localstore/file.go | 38 +++------ go/ai/exp/localstore/inmemory.go | 11 +-- go/ai/exp/option.go | 91 +++++++++------------- go/ai/exp/session.go | 85 +++++++------------- go/core/schemas.config | 130 +++++++++++-------------------- 8 files changed, 216 insertions(+), 385 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 0e587d452f..6ff830fa67 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -312,19 +312,15 @@ func (s *SessionRunner[State]) snapshotTurnEnd(ctx context.Context, finishReason // --- Responder --- -// Responder is the output channel for an agent. Artifacts sent through -// it are added to the session synchronously: by the time a Send method -// returns, the chunk's session-level side effects have been applied, so -// a state read ([SessionRunner.Result], [Session.Artifacts]) or a -// turn-end snapshot that follows the call observes them. Only the wire -// forward to the client is asynchronous. +// Responder is an agent's output channel to the client. Its Send methods are +// fire-and-forget: they return no error, and the agent function should stop +// producing once its own context is cancelled. // -// All Send methods are ctx-aware: if the agent's work context is -// cancelled (typically client disconnect, abort during detach, or fn -// completion), Send returns promptly with the chunk dropped from the -// wire; the session-level side effects still apply. Send itself remains -// fire-and-forget and returns no error; the user fn is expected to -// observe cancellation through its own ctx check and stop producing. +// A Send applies its session-level side effects synchronously, so a state read +// ([SessionRunner.Result], [Session.Artifacts]) or turn-end snapshot taken +// afterward observes them. Only the forward to the client is asynchronous, and +// it is dropped once the work context is cancelled (client disconnect, abort, +// or agent completion); the side effects still apply. type Responder struct { in chan<- *AgentStreamChunk ctx context.Context @@ -340,14 +336,9 @@ func (r Responder) SendModelChunk(chunk *ai.ModelResponseChunk) { r.send(&AgentStreamChunk{ModelChunk: chunk}) } -// SendArtifact sends an artifact to the stream and adds it to the session. -// If an artifact with the same name already exists in the session, it is -// replaced. The artifact is in the session by the time SendArtifact -// returns, and the session stores a deep copy captured at the call, so -// later mutations of the caller's artifact do not affect session state. -// The session-level side effect happens whether or not detach has landed; -// only the wire forward to the client is suppressed post-detach, when -// there is no longer a client to receive it. +// SendArtifact streams an artifact to the client and adds it to the session, +// replacing any existing artifact with the same name. The session keeps a deep +// copy, so later mutations of artifact do not affect session state. func (r Responder) SendArtifact(artifact *Artifact) { r.send(&AgentStreamChunk{Artifact: artifact}) } @@ -383,11 +374,10 @@ type AgentFunc[State any] = func(ctx context.Context, resp Responder, sess *Sess // Agent is a bidirectional streaming agent with automatic snapshot management. // -// Agent implements [api.BidiAction], so generic transports accept it -// directly (e.g. pass it to genkit.Handler to serve it over HTTP, one turn -// per request). The [Agent.Run], [Agent.RunText], and [Agent.StreamBidi] -// methods are typed conveniences over the same underlying action; both -// surfaces run the identical per-invocation runtime. +// Agent implements [api.BidiAction], so generic transports accept it directly +// (e.g. pass it to genkit.Handler to serve it over HTTP, one turn per request). +// [Agent.Run], [Agent.RunText], and [Agent.StreamBidi] are typed conveniences +// over the same underlying action. // // Server-managed agents (those with a [SessionStore] configured) also // register companion actions for the snapshot lifecycle, available via @@ -573,11 +563,10 @@ func DefineAgent[State any]( // registered conditionally, or moved between registries). For the common // case, [DefineCustomAgent] creates and registers in one step. // -// There is no NewAgent counterpart for prompt-backed agents: a prompt is -// bound to the registry it renders and generates against, so a -// prompt-backed agent cannot be built before it has one. To get -// prompt-like behavior without registration, write a custom agent that -// renders and generates with your own [genkit.Genkit] inside fn. +// There is no NewAgent for prompt-backed agents: a prompt is bound to the +// registry it renders against, so it cannot be built before one exists. For +// prompt-like behavior without registration, render and generate with your own +// [genkit.Genkit] inside a custom fn. func NewCustomAgent[State any]( name string, fn AgentFunc[State], @@ -2117,15 +2106,14 @@ func (a *Agent[State]) resolveOptions(opts []InvocationOption[State]) (*AgentIni // --- AgentConnection --- -// AgentConnection wraps BidiConnection with agent-specific Send helpers -// (SendMessage / SendText / SendResume / Detach) and an Output that -// always waits for finalization (so detached invocations see the -// pending snapshot ID rather than a context-cancellation error). +// AgentConnection is an active agent invocation with bidirectional streaming, +// adding agent-specific Send helpers (SendMessage, SendText, SendResume, +// Detach) over the core connection. // -// It also tracks the conversation's custom state live: as [AgentConnection.Receive] -// yields chunks, it applies each chunk's [AgentStreamChunk.CustomPatch] to an -// internal copy, exposed by [AgentConnection.Custom], so callers observe custom -// state as it streams without applying patches themselves. +// It also tracks custom state live: as [AgentConnection.Receive] yields chunks, +// it applies each chunk's [AgentStreamChunk.CustomPatch] to an internal copy, +// exposed by [AgentConnection.Custom], so callers see custom state as it +// streams without applying patches themselves. type AgentConnection[State any] struct { conn *core.BidiConnection[*AgentInput, *AgentOutput[State], *AgentStreamChunk] @@ -2173,12 +2161,11 @@ func (c *AgentConnection[State]) SendResume(resume *ToolResume) error { // finalized with the cumulative final state once the queued inputs // are processed. // -// Streamed chunks emitted after detach are not forwarded over the wire -// (the connection is gone), but their session-level side effects still -// apply: artifacts sent via [Responder.SendArtifact] land in the -// session and end up in the final snapshot's state. +// Chunks emitted after detach are not forwarded over the wire, but their +// session-level side effects still apply: an artifact sent via +// [Responder.SendArtifact] still lands in the final snapshot's state. // -// To send a final input as part of the same wire message, use +// To send a final input in the same wire message, call // Send(&AgentInput{Detach: true, Message: ...}) directly. func (c *AgentConnection[State]) Detach() error { return c.conn.Send(&AgentInput{Detach: true}) @@ -2245,29 +2232,21 @@ func (c *AgentConnection[State]) Custom() (State, error) { return out, nil } -// Output finalizes the connection and returns the agent's result. +// Output finalizes the connection and returns the agent's result. It closes +// the input side, drains any chunks not consumed via Receive, and blocks until +// the agent finalizes. It is idempotent: later calls return the same value, and +// the returned pointer is shared, so treat it as read-only. // -// Output is the single "I'm done" call: it implicitly closes the input -// side, drains any chunks the caller did not consume via Receive, and -// blocks until the agent finalizes. Calling Close first is allowed but -// redundant. Output is idempotent: subsequent calls return the same -// (*AgentOutput, error); the returned pointer is shared across calls, -// so treat it as read-only. +// In-band failures resolve rather than error. A failed turn returns an +// [AgentOutput] with [AgentFinishReasonFailed], the error on [AgentOutput.Error], +// and the last-good state on [AgentOutput.State] (client-managed) or behind +// [AgentOutput.SnapshotID] (server-managed), so a failure costs only the failed +// turn, not the session. A detached invocation resolves with the pending +// snapshot ID. A non-nil error means the invocation never started (a rejected +// init payload) or could not run to a result (e.g. its context was cancelled). // -// In-band failures resolve rather than error: a failed turn returns an -// [AgentOutput] with [AgentFinishReasonFailed], the error on -// [AgentOutput.Error] (original status intact), and the last-good state -// on [AgentOutput.State] (client-managed) or behind -// [AgentOutput.SnapshotID] (server-managed), so a failure costs the -// caller only the failed turn, never the session. A detached invocation -// resolves with the pending snapshot ID rather than a cancellation -// error. A non-nil error here means the invocation never started (a -// rejected init payload) or could not run to a result (e.g. the -// connection's context was cancelled). -// -// Do not call Output concurrently with a goroutine iterating Receive; -// both consume from the same stream and chunks would be split between -// them. Finish Receive first, then call Output. +// Do not call Output concurrently with a goroutine iterating Receive; both +// consume the stream and would split chunks between them. Finish Receive first. func (c *AgentConnection[State]) Output() (*AgentOutput[State], error) { _ = c.conn.Close() // The core connection applies backpressure and its Output does not diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index 6634f509f2..4953518d4c 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -95,45 +95,33 @@ const ( // // Sending no fields starts a fresh invocation with empty state. type AgentInit[State any] struct { - // SessionID identifies the session (conversation) to resume or start. - // Only valid when the agent is server-managed (a session store is - // configured); mutually exclusive with State (a client-managed - // conversation carries its identity inside [SessionState.SessionID]). - // Alone, it resumes the session from its latest snapshot: the most - // recently updated row, whatever its status. If that row is a failed, - // aborted, or still-pending dead end the resume is rejected (pass - // SnapshotID to continue from a specific earlier point); if the session's - // history was forked by resuming an earlier snapshot again, the most - // recently updated branch wins. If the session has no snapshots yet, a - // brand-new conversation is started under this caller-chosen ID, and - // every snapshot it persists carries it. Combined with SnapshotID, it - // asserts which session the snapshot belongs to, and a mismatch is - // rejected. + // SessionID identifies the session (conversation) to resume or start. Only + // valid when the agent is server-managed (a session store is configured); + // mutually exclusive with State. Alone, it resumes the session's latest + // snapshot, rejected if that snapshot is a failed, aborted, or pending dead + // end. If the session has no snapshots yet, a fresh conversation starts under + // this caller-chosen ID. Combined with SnapshotID, it asserts the snapshot + // belongs to that session. SessionID string `json:"sessionId,omitempty"` // SnapshotID loads state from a persisted snapshot. Only valid when the // agent is server-managed (a session store is configured). May be // combined with SessionID to validate that the snapshot belongs to that // session. Mutually exclusive with State. SnapshotID string `json:"snapshotId,omitempty"` - // State provides direct state for the invocation. Only valid when the - // agent is client-managed (no session store). The conversation's - // identity rides inside it ([SessionState.SessionID]): the framework - // mints one on the conversation's first invocation and echoes it on the - // output state, so resending the state object keeps the identity without - // tracking a separate field. Mutually exclusive with SessionID and + // State provides direct state for the invocation. Only valid when the agent + // is client-managed (no session store). The conversation's identity rides + // inside it ([SessionState.SessionID]). Mutually exclusive with SessionID and // SnapshotID. State *SessionState[State] `json:"state,omitempty"` } // AgentInput is the input sent to an agent during a conversation turn. type AgentInput struct { - // Detach signals that the client wishes to disconnect after this input is - // accepted. The server writes a single pending snapshot (with empty - // state), returns [AgentOutput] with that snapshot ID, and continues - // processing any already-buffered inputs in a background context. The - // pending snapshot is finalized with the cumulative final state once all - // queued inputs are processed (or the invocation is aborted via the - // abortSnapshot companion action). + // Detach signals the client will disconnect after this input is accepted. The + // server writes a pending snapshot, returns [AgentOutput] with its ID, and + // keeps processing any already-buffered inputs in the background. The pending + // snapshot is finalized with the cumulative final state once the queued inputs + // are processed (or the invocation is aborted). Detach bool `json:"detach,omitempty"` // Message is the user's input for this turn. Message *ai.Message `json:"message,omitempty"` @@ -155,10 +143,9 @@ type ToolResume struct { } // AgentMetadata is the value placed under metadata["agent"] on an agent's -// action descriptor. It exposes capability information so the Dev UI and -// other reflective callers can render the right surface (e.g. hide the -// Abort button when the configured store doesn't support it) without -// round-tripping through the reflection API. +// action descriptor. It exposes capability information so the Dev UI and other +// reflective callers can render the right surface (e.g. hide the Abort button +// when the store doesn't support it). type AgentMetadata struct { // Abortable reports whether the agent's invocations can be aborted // (true when the store implements [SnapshotAborter]). @@ -191,16 +178,10 @@ type AgentOutput[State any] struct { FinishReason AgentFinishReason `json:"finishReason,omitempty"` // Message is the last model response message from the conversation. Message *ai.Message `json:"message,omitempty"` - // SessionID is the ID of the session this invocation belongs to, - // assigned by the framework when the invocation starts. With - // server-managed state, a fresh invocation adopts the caller-supplied - // session ID (see [AgentInit.SessionID]) or mints a new one, resumed - // invocations inherit the chain's, and resuming a snapshot from before - // session IDs existed mints a fresh one. With client-managed state it - // echoes the ID carried inside the state object - // ([SessionState.SessionID]), minting one on the conversation's first - // invocation; only a session with persisted snapshots can be resumed by - // this ID. + // SessionID is the ID of the session this invocation belongs to, assigned by + // the framework when the invocation starts and stable across resumes. Pass it + // to [WithSessionID] to resume a server-managed session; with client-managed + // state it also rides inside [AgentOutput.State] ([SessionState.SessionID]). SessionID string `json:"sessionId,omitempty"` // SnapshotID is the ID of the most recent turn-end snapshot for this // invocation. Empty when no store is configured or no turn committed. When @@ -251,18 +232,13 @@ const ( type AgentStreamChunk struct { // Artifact contains a newly produced artifact. Artifact *Artifact `json:"artifact,omitempty"` - // CustomPatch is an RFC 6902 JSON Patch describing a delta applied to the - // session's custom state. The runtime emits it automatically whenever the - // agent mutates custom state (e.g. via [Session.UpdateCustom]); agents do not - // hand-craft patches. Pointers are rooted at the custom document (e.g. - // "/agentStatus"), with no "/custom" prefix. The first patch of every turn is a - // whole-document replace at the root pointer ("") that re-bases clients which - // may not share the server's baseline; subsequent patches are incremental diffs - // against the last sent value. The diff is computed on the client-facing custom - // state (after any [WithStateTransform]), so streamed deltas honor redaction and - // stay consistent with the full state in turn-end snapshots and final output. - // Apply it with [ApplyPatch] to keep a local copy of custom live as the turn - // streams. + // CustomPatch is an RFC 6902 JSON Patch describing a delta to the session's + // custom state, emitted automatically whenever the agent mutates it (e.g. via + // [Session.UpdateCustom]). Pointers are rooted at the custom document (e.g. + // "/agentStatus"), with no "/custom" prefix. The first patch of each turn is a + // whole-document replace at the root pointer ("") to re-base the client; later + // patches are incremental diffs. Apply it with [ApplyPatch] to keep a local + // copy of custom state live as the turn streams. CustomPatch JSONPatch `json:"customPatch,omitempty"` // ModelChunk contains generation tokens from the model. ModelChunk *ai.ModelResponseChunk `json:"modelChunk,omitempty"` @@ -284,19 +260,15 @@ type Artifact struct { Parts []*ai.Part `json:"parts"` } -// GetSnapshotRequest is the input for an agent's getSnapshot companion -// action, registered under the agent's name (action type agent-snapshot) -// when the agent has a session store configured. The action is intended -// for Dev UI and client-side reconnect flows. It returns the stored -// [SessionSnapshot], with [WithStateTransform] applied to its state if -// configured. +// GetSnapshotRequest is the input for an agent's getSnapshot companion action, +// available when the agent has a session store configured. Intended for Dev UI +// and client reconnect flows, it returns the stored [SessionSnapshot] with +// [WithStateTransform] applied to its state if configured. // -// At least one of SnapshotID or SessionID must be set; they are not -// mutually exclusive. SnapshotID fetches a specific snapshot; SessionID -// alone fetches the session's latest snapshot (via the store's -// [SnapshotReader.GetLatestSnapshot], whatever its status). When both are -// set, the fetched snapshot must belong to that session, or the request -// is rejected. +// At least one of SnapshotID or SessionID must be set. SnapshotID fetches a +// specific snapshot; SessionID alone fetches the session's latest snapshot +// (whatever its status). When both are set, the fetched snapshot must belong +// to that session, or the request is rejected. type GetSnapshotRequest struct { // SessionID identifies the session whose latest snapshot to fetch. // Optional when SnapshotID is given. The latest snapshot is the session's @@ -336,15 +308,10 @@ type JSONPatchOperation struct { // Op is the operation to perform. Op JSONPatchOp `json:"op"` // Path is a JSON Pointer (RFC 6901) to the target location, e.g. "/agentStatus". - // The empty pointer "" refers to the whole document. It must always be present on - // the wire (a whole-document replace carries path ""), so it is not omitted when - // empty. + // The empty pointer "" refers to the whole document. Path string `json:"path"` - // Value is the operand for "add", "replace", and "test". It is not omitted when - // null so an explicit null operand survives the wire (omitempty cannot tell a - // null operand from an absent one, and dropping it makes a peer applier set the - // member to undefined or remove it instead of null); for "remove", "move", and - // "copy" it is null and ignored. + // Value is the operand for "add", "replace", and "test"; an explicit null is a + // valid operand. Ignored for "remove", "move", and "copy". Value any `json:"value"` } @@ -367,12 +334,10 @@ type SessionSnapshot[State any] struct { // informational lineage (for debugging and UI history trees) and plays // no part in resolving a session's latest snapshot. ParentID string `json:"parentId,omitempty"` - // SessionID is the ID of the session this snapshot belongs to. Assigned - // by the agent framework when the conversation's first invocation starts - // and stamped on every later snapshot in the chain, including across - // resumed invocations. Stores preserve it across rewrites; rows written - // without one (data from before session IDs existed) belong to no - // session. + // SessionID is the ID of the session this snapshot belongs to. Assigned by the + // framework on the conversation's first invocation and stamped on every later + // snapshot in the chain, across resumed invocations; stores preserve it across + // rewrites. SessionID string `json:"sessionId,omitempty"` // SnapshotID is the unique identifier for this snapshot (UUID). SnapshotID string `json:"snapshotId"` @@ -401,10 +366,9 @@ type SessionState[State any] struct { // Does NOT include prompt-rendered messages — those are rendered fresh each turn. Messages []*ai.Message `json:"messages,omitempty"` // SessionID is the ID of the session (conversation) this state belongs to. - // Framework-owned: assigned when the conversation's first invocation - // starts and re-stamped on outbound state, so client-managed callers can - // round-trip the state object opaquely without tracking a separate - // identifier. For server-managed agents the snapshot row's + // Framework-owned: assigned on the conversation's first invocation and + // re-stamped on outbound state, so client-managed callers can round-trip the + // state object opaquely. For server-managed agents the snapshot row's // [SessionSnapshot.SessionID] is canonical and this field mirrors it. SessionID string `json:"sessionId,omitempty"` } diff --git a/go/ai/exp/jsonpatch.go b/go/ai/exp/jsonpatch.go index 4806cb27a8..f0aa46c1aa 100644 --- a/go/ai/exp/jsonpatch.go +++ b/go/ai/exp/jsonpatch.go @@ -118,17 +118,10 @@ func diffWalk(from, to any, pointer string, patch *JSONPatch) { // value. The input is not mutated; a normalized clone is patched and returned. // Operating on the root pointer ("") replaces or removes the whole document. // -// Apply is lenient to keep streaming robust: an add or replace whose parent +// It is lenient to keep streaming robust: an add or replace whose parent // container is missing initializes it as an object, and a remove or replace of // a missing member is a no-op. A test operation is honored and returns an error // on mismatch. Other unknown operations also return an error. -// -// Apply diverges from the JS reference applier (applyPatch in json-patch.ts) -// only on inputs [Diff] never emits: an add or replace at an out-of-range array -// index is a no-op here (JS splices/grows the array, back-filling with null), -// and a test against a missing path may pass here where JS throws. The runtime -// applies only Diff output, which is always in range and never a test, so the -// server and a JS client agree on every patch the streaming protocol produces. func ApplyPatch(document any, patch JSONPatch) (any, error) { return applyOps(cloneJSON(normalizeJSON(document)), patch) } diff --git a/go/ai/exp/localstore/file.go b/go/ai/exp/localstore/file.go index f310b40aa6..88f67615b6 100644 --- a/go/ai/exp/localstore/file.go +++ b/go/ai/exp/localstore/file.go @@ -36,18 +36,11 @@ import ( // on the local filesystem. Each snapshot is written to its own file named // ".json" in the configured directory. // -// The store is safe for concurrent use within a single process. It does NOT -// coordinate writes with other processes that may share the same directory: -// the only synchronization is the per-instance mutex. If multiple processes -// write to the same directory the last successful rename wins; readers may -// also observe a brief window during which a snapshot is still being written -// by another process (the rename itself is atomic, but cross-process -// linearization is not guaranteed). -// -// [FileSessionStore.OnSnapshotStatusChange] uses in-process channels and only -// reflects status transitions caused by calls on this store instance. -// External writes to the directory and writes from other processes are not -// observed. +// The store is safe for concurrent use within a single process, but does NOT +// coordinate with other processes sharing the directory: the last successful +// rename wins, and a reader may briefly observe a snapshot another process is +// still writing. [FileSessionStore.OnSnapshotStatusChange] likewise reflects +// only status changes made through this instance. type FileSessionStore[State any] struct { // mu serializes the read-modify-write paths and the subscriber bookkeeping. // File I/O happens under the lock; this matches the simplicity of @@ -84,10 +77,9 @@ func (s *FileSessionStore[State]) GetSnapshot(_ context.Context, snapshotID stri return s.readLocked(snapshotID) } -// SaveSnapshot atomically reads, applies fn, and persists. See the -// [exp.SnapshotWriter] interface for the full contract; this implementation -// satisfies it by holding s.mu for the entire read-modify-write so fn is -// called exactly once per SaveSnapshot call. +// SaveSnapshot atomically reads, applies fn, and persists. See +// [exp.SnapshotWriter] for the full contract; this implementation calls fn +// exactly once per call. func (s *FileSessionStore[State]) SaveSnapshot( _ context.Context, id string, @@ -150,16 +142,10 @@ type snapshotHeader struct { // regardless of status, per the [exp.SnapshotReader.GetLatestSnapshot] // contract. // -// Recency is judged by file mtime, which for snapshots written by this -// package advances with [exp.SessionSnapshot.UpdatedAt] (each save -// creates a fresh temp file and renames it into place); if a file is -// touched externally, mtime wins. The scan walks files newest first and -// stops at the first row for the session, so resolving the most recently -// active session costs one read in the common case. Only header fields -// are decoded per candidate (the winner is the only full parse), the -// store lock is held per file rather than across the whole scan, and a -// file that vanishes mid-scan or fails to parse is skipped so one -// corrupted row cannot poison every session in the directory. +// Recency is judged by file mtime, which for snapshots written by this package +// advances with [exp.SessionSnapshot.UpdatedAt]; if a file is touched +// externally, mtime wins. A file that fails to parse or vanishes mid-scan is +// skipped, so one corrupted row cannot hide every other session. func (s *FileSessionStore[State]) GetLatestSnapshot(_ context.Context, sessionID string) (*exp.SessionSnapshot[State], error) { if sessionID == "" { return nil, errors.New("FileSessionStore: session ID is empty") diff --git a/go/ai/exp/localstore/inmemory.go b/go/ai/exp/localstore/inmemory.go index ff25796469..de56dd9bf3 100644 --- a/go/ai/exp/localstore/inmemory.go +++ b/go/ai/exp/localstore/inmemory.go @@ -69,9 +69,7 @@ func (s *InMemorySessionStore[State]) GetSnapshot(_ context.Context, snapshotID // GetLatestSnapshot returns the session's most recently updated snapshot // regardless of status, per the [exp.SnapshotReader.GetLatestSnapshot] // contract. Ties on UpdatedAt are broken by SnapshotID so resolution is -// deterministic. The scan runs under the read lock, so the stored rows -// (which other calls mutate in place) never escape it; the winner is -// returned as a deep copy. +// deterministic. The returned snapshot is a deep copy. func (s *InMemorySessionStore[State]) GetLatestSnapshot(_ context.Context, sessionID string) (*exp.SessionSnapshot[State], error) { if sessionID == "" { return nil, errors.New("InMemorySessionStore: session ID is empty") @@ -112,10 +110,9 @@ func (s *InMemorySessionStore[State]) AbortSnapshot(_ context.Context, snapshotI return snap.Status, nil } -// SaveSnapshot atomically reads, applies fn, and persists. See the -// [exp.SnapshotWriter] interface for the full contract; this implementation -// satisfies it by holding s.mu for the entire read-modify-write so fn is -// called exactly once per SaveSnapshot call. +// SaveSnapshot atomically reads, applies fn, and persists. See +// [exp.SnapshotWriter] for the full contract; this implementation calls fn +// exactly once per call. func (s *InMemorySessionStore[State]) SaveSnapshot( _ context.Context, id string, diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index 377443648f..617f856221 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -30,20 +30,16 @@ type AgentOption[State any] interface { applyAgent(*agentOptions[State]) error } -// StateTransform rewrites session state on its way out to a client. It -// is applied to the State returned by the getSnapshot companion action -// and to [AgentOutput.State] when state is client-managed (no store). -// It is not applied to state persisted in the store or to state passed -// to the user agent function. +// StateTransform rewrites session state on its way out to a client: it is +// applied to the state returned by the getSnapshot companion action and to +// [AgentOutput.State] for client-managed agents, but not to state persisted in +// the store or passed to the agent function. Typical uses are PII redaction and +// stripping secrets. // -// ctx is the request or invocation context: cancellation, deadlines, -// and context-scoped values (e.g. the caller's identity for RBAC-aware -// redaction) flow through here. -// -// state is a fresh deep copy made for this call: the transform owns it -// and may mutate in place, return a new pointer, or return nil to omit -// state from the response entirely. Do not retain the pointer past the -// call; the framework drops its reference after the transform returns. +// state is a fresh deep copy the transform owns: it may mutate in place, return +// a new pointer, or return nil to omit state from the response. ctx is the +// request or invocation context, carrying deadlines and values such as the +// caller's identity for RBAC-aware redaction. type StateTransform[State any] = func(ctx context.Context, state *SessionState[State]) *SessionState[State] type agentOptions[State any] struct { @@ -82,20 +78,15 @@ func WithSessionStore[State any](store SessionStore[State]) AgentOption[State] { return &agentOptions[State]{store: store} } -// WithStateTransform registers a transform applied to session state on -// its way out to a client via the getSnapshot companion action or via -// [AgentOutput.State] when state is client-managed. Typical use is PII -// redaction or stripping secrets. The transform is not applied to state -// persisted in the store or to state passed to the user agent function. +// WithStateTransform registers a [StateTransform] applied to session state on +// its way out to a client (via the getSnapshot companion action or +// [AgentOutput.State]). Typical uses are PII redaction and stripping secrets. func WithStateTransform[State any](transform StateTransform[State]) AgentOption[State] { return &agentOptions[State]{transform: transform} } -// WithDescription sets a human-readable description of the agent. It is -// stored on the agent action's descriptor (read back via [Agent.Desc] and -// surfaced in the Dev UI's action listing), the same place every other -// primitive carries its description, so reflective tooling can render it -// without a separate field. +// WithDescription sets a human-readable description of the agent, stored on its +// action descriptor (read back via [Agent.Desc] and shown in the Dev UI). func WithDescription[State any](description string) AgentOption[State] { return &agentOptions[State]{description: description} } @@ -149,15 +140,13 @@ func (o *invocationOptions[State]) applyInvocation(opts *invocationOptions[State return nil } -// WithState sets the initial state for the invocation. -// Use this for client-managed state where the client sends state directly. -// The conversation's identity rides inside the state object -// ([SessionState.SessionID]): the framework mints one on the -// conversation's first invocation and echoes it on the output state, so -// resending the state keeps the identity without tracking a separate -// field. The framework deep-copies the state when the invocation starts, -// so the caller keeps ownership of the object it passed and may reuse it -// freely. Mutually exclusive with [WithSessionID] and [WithSnapshotID]. +// WithState sets the initial state for the invocation, for client-managed +// agents where the client sends state directly. The conversation's identity +// rides inside the state ([SessionState.SessionID]); the framework mints it on +// the first invocation and echoes it on the output, so resending the state +// keeps the identity. The state is deep-copied at the start, so the caller may +// reuse the object freely. Mutually exclusive with [WithSessionID] and +// [WithSnapshotID]. func WithState[State any](state *SessionState[State]) InvocationOption[State] { return &invocationOptions[State]{state: state} } @@ -171,32 +160,22 @@ func WithSnapshotID[State any](id string) InvocationOption[State] { return &invocationOptions[State]{snapshotID: id} } -// WithSessionID resumes the session (conversation) with the given ID -// from its latest snapshot: the most recently updated one that is not a -// failed/aborted dead end (see [SnapshotReader.GetLatestSnapshot]). Use -// this when the caller tracks the conversation rather than individual -// snapshots; the session ID is assigned when the conversation's first -// invocation starts (see [AgentOutput.SessionID]) and stays stable -// across resumed invocations. -// -// Only valid when the agent is server-managed (a session store is -// configured) and therefore mutually exclusive with [WithState]: a -// client-managed conversation carries its identity inside the state -// object ([SessionState.SessionID]) instead. Combined with -// [WithSnapshotID], the snapshot picks the exact resume point and the -// session ID is validated against it, so an invocation never silently -// continues a conversation other than the one named. +// WithSessionID resumes the conversation with the given ID from its latest +// snapshot (see [SnapshotReader.GetLatestSnapshot]). Use it when the caller +// tracks the conversation rather than individual snapshots; the ID is assigned +// on the conversation's first invocation (see [AgentOutput.SessionID]) and +// stays stable across resumes. // -// A pending latest snapshot means a detached invocation is still -// running; the resume is rejected so it cannot race the background -// work. Wait for the snapshot to finalize, or abort it. If the -// session's history was forked (an earlier snapshot was resumed again, -// or two invocations resumed the session concurrently), the most -// recently updated branch wins; use [WithSnapshotID] to continue a -// specific branch instead. +// Valid only for server-managed agents, and so mutually exclusive with +// [WithState] (a client-managed conversation carries its identity inside the +// state). Combined with [WithSnapshotID], the snapshot picks the resume point +// and the session ID is validated against it. // -// Passing an empty ID is an error rather than a no-op, so an unset -// [AgentOutput.SessionID] cannot silently start a fresh conversation. +// The resume is rejected if the latest snapshot is a failed, aborted, or +// pending dead end; a pending tip means a detached invocation is still running, +// so wait for it to finalize or abort it. If history was forked, the most +// recently updated branch wins; use [WithSnapshotID] for a specific branch. An +// empty ID is an error, not a no-op. func WithSessionID[State any](id string) InvocationOption[State] { return &invocationOptions[State]{sessionID: id, sessionIDSet: true} } diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 5915769a9f..44bd98783c 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -48,26 +48,18 @@ type SnapshotReader[State any] interface { GetSnapshot(ctx context.Context, snapshotID string) (*SessionSnapshot[State], error) // GetLatestSnapshot returns the session's most recently updated - // snapshot as a full row (the runtime loads its state to resume), - // whatever its status: a pending, failed, or aborted row is returned - // like any other. Returns nil if the session has no rows (unknown - // session ID), and an error if sessionID is empty. + // snapshot, whatever its status: a pending, failed, or aborted row is + // returned like any other, and the caller applies its own policy. + // Returns nil if the session has no rows, and an error if sessionID is + // empty. // - // "Most recently updated" means the greatest - // [SessionSnapshot.UpdatedAt], falling back to CreatedAt on rows that - // lack one; ties may be broken arbitrarily but deterministically (e.g. - // by SnapshotID). The latest row is returned unconditionally, so the - // two callers can each apply their own policy: the getSnapshot - // companion action surfaces it verbatim (a client reconnecting wants to - // see a pending/failed tip), and the session-ID resume path rejects a - // failed/aborted/pending tip rather than continuing from a dead end. - // - // The contract is a plain max-timestamp lookup, so stores can implement - // it as a single indexed query (e.g. WHERE sessionId = ? ORDER BY - // updatedAt DESC LIMIT 1). ParentID is informational lineage and plays - // no part in resolution: when a session's history was forked by - // re-resuming an earlier snapshot, the most recently updated branch - // simply wins. + // "Most recently updated" means the greatest [SessionSnapshot.UpdatedAt], + // falling back to CreatedAt on rows that lack one; break ties + // deterministically (e.g. by SnapshotID). This is a plain max-timestamp + // lookup, implementable as a single indexed query (e.g. WHERE sessionId = ? + // ORDER BY updatedAt DESC LIMIT 1). ParentID is informational lineage and + // plays no part in resolution: when history forks, the most recently + // updated branch wins. GetLatestSnapshot(ctx context.Context, sessionID string) (*SessionSnapshot[State], error) } @@ -84,9 +76,7 @@ type SnapshotWriter[State any] interface { // - SessionID: the ID of the session (chain of snapshots) the row // belongs to: preserved from the existing row on update (a row's // session never changes once set), otherwise taken from fn's row - // as-is. Stores never mint or infer session IDs; the agent - // runtime assigns one when an invocation starts and stamps it on - // every row it writes. + // as-is. Stores never mint or infer session IDs. // - CreatedAt: stamped to the wall clock on first write; preserved // from the existing row on update. // - UpdatedAt: stamped to the wall clock on every commit. @@ -113,23 +103,12 @@ type SnapshotWriter[State any] interface { ) (*SessionSnapshot[State], error) } -// SnapshotAborter is the optional capability layered on [SessionStore] -// that lets an agent's invocations be aborted. It bundles the two -// methods that must be implemented together for the abort lifecycle to -// function: -// -// - [SnapshotAborter.AbortSnapshot] flips a pending snapshot's status -// to aborted (typically called by the abortSnapshot companion -// action or directly by a Go caller holding the store). -// -// - [SnapshotAborter.OnSnapshotStatusChange] lets the agent runtime -// observe the flip without polling, so it can promptly cancel the -// work context. -// -// They are bundled because neither is useful alone: flipping status -// with no observer means the running fn never learns it was aborted; -// observing without a way to trigger the flip means no abort can -// happen. +// SnapshotAborter is the optional capability layered on [SessionStore] that +// lets an agent's invocations be aborted. The two methods work together: +// [SnapshotAborter.AbortSnapshot] flips a pending snapshot's status to aborted, +// and [SnapshotAborter.OnSnapshotStatusChange] lets the agent runtime observe +// the flip without polling so it can promptly cancel the work context. A store +// must implement both or neither. type SnapshotAborter interface { // AbortSnapshot atomically transitions a snapshot from // [SnapshotStatusPending] to [SnapshotStatusAborted] and returns the @@ -139,9 +118,8 @@ type SnapshotAborter interface { // callers can distinguish "not found" from a real error. // // Implementations must perform the read-and-write atomically (e.g., a - // transaction or a compare-and-swap). The agent's abortSnapshot - // action and finalizer rely on this to avoid a pending row being - // clobbered by a racing terminal write. + // transaction or a compare-and-swap) so a racing terminal write cannot + // clobber the pending row. AbortSnapshot(ctx context.Context, snapshotID string) (SnapshotStatus, error) // OnSnapshotStatusChange returns a channel that yields the snapshot's @@ -150,9 +128,8 @@ type SnapshotAborter interface { // cancelled. If the snapshot does not exist when the subscription is // established, the channel is closed without yielding a value. // - // Implementations may push changes from a transaction log, a CDC - // feed, or fall back to polling internally; the contract just spares - // callers the choice. + // Implementations may push changes from a transaction log or CDC feed, + // or poll internally. OnSnapshotStatusChange(ctx context.Context, snapshotID string) <-chan SnapshotStatus } @@ -332,19 +309,11 @@ type Session[State any] struct { onCustomChange func() } -// SessionID returns the ID of the session this conversation belongs to. -// The agent runtime settles it before the agent function runs: a fresh -// invocation mints a new ID, server-managed resumes inherit the chain's -// (a snapshot from before session IDs existed gets a fresh one), and a -// client-managed invocation keeps the ID carried inside the state object -// it was given ([SessionState.SessionID]), minting one if absent. -// -// The ID is stable for the lifetime of the invocation; it lives in -// [SessionState.SessionID], so it is stamped on every snapshot the -// invocation persists and rides inside the state returned to -// client-managed callers. It is safe to use as a key for external -// resources tied to the conversation, including from code that -// retrieves the session via [SessionFromContext]. +// SessionID returns the ID of the session this conversation belongs to. The +// agent runtime settles it before the agent function runs and keeps it stable +// for the invocation's lifetime, stamping it on every snapshot persisted. It is +// safe to use as a key for external resources tied to the conversation, +// including from code that retrieves the session via [SessionFromContext]. func (s *Session[State]) SessionID() string { // Written once at construction, before fn runs and before the session // is shared, then never mutated; safe to read without holding mu. diff --git a/go/core/schemas.config b/go/core/schemas.config index 1cea9d070d..d6045ac8da 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1210,13 +1210,11 @@ AgentInput is the input sent to an agent during a conversation turn. . AgentInput.detach doc -Detach signals that the client wishes to disconnect after this input is -accepted. The server writes a single pending snapshot (with empty -state), returns [AgentOutput] with that snapshot ID, and continues -processing any already-buffered inputs in a background context. The -pending snapshot is finalized with the cumulative final state once all -queued inputs are processed (or the invocation is aborted via the -abortSnapshot companion action). +Detach signals the client will disconnect after this input is accepted. The +server writes a pending snapshot, returns [AgentOutput] with its ID, and +keeps processing any already-buffered inputs in the background. The pending +snapshot is finalized with the cumulative final state once the queued inputs +are processed (or the invocation is aborted). . AgentInput.message type *ai.Message @@ -1276,20 +1274,13 @@ Sending no fields starts a fresh invocation with empty state. . AgentInit.sessionId doc -SessionID identifies the session (conversation) to resume or start. -Only valid when the agent is server-managed (a session store is -configured); mutually exclusive with State (a client-managed -conversation carries its identity inside [SessionState.SessionID]). -Alone, it resumes the session from its latest snapshot: the most -recently updated row, whatever its status. If that row is a failed, -aborted, or still-pending dead end the resume is rejected (pass -SnapshotID to continue from a specific earlier point); if the session's -history was forked by resuming an earlier snapshot again, the most -recently updated branch wins. If the session has no snapshots yet, a -brand-new conversation is started under this caller-chosen ID, and -every snapshot it persists carries it. Combined with SnapshotID, it -asserts which session the snapshot belongs to, and a mismatch is -rejected. +SessionID identifies the session (conversation) to resume or start. Only +valid when the agent is server-managed (a session store is configured); +mutually exclusive with State. Alone, it resumes the session's latest +snapshot, rejected if that snapshot is a failed, aborted, or pending dead +end. If the session has no snapshots yet, a fresh conversation starts under +this caller-chosen ID. Combined with SnapshotID, it asserts the snapshot +belongs to that session. . AgentInit.snapshotId doc @@ -1300,12 +1291,9 @@ session. Mutually exclusive with State. . AgentInit.state doc -State provides direct state for the invocation. Only valid when the -agent is client-managed (no session store). The conversation's -identity rides inside it ([SessionState.SessionID]): the framework -mints one on the conversation's first invocation and echoes it on the -output state, so resending the state object keeps the identity without -tracking a separate field. Mutually exclusive with SessionID and +State provides direct state for the invocation. Only valid when the agent +is client-managed (no session store). The conversation's identity rides +inside it ([SessionState.SessionID]). Mutually exclusive with SessionID and SnapshotID. . @@ -1348,16 +1336,10 @@ It wraps AgentResult with framework-managed fields. . AgentOutput.sessionId doc -SessionID is the ID of the session this invocation belongs to, -assigned by the framework when the invocation starts. With -server-managed state, a fresh invocation adopts the caller-supplied -session ID (see [AgentInit.SessionID]) or mints a new one, resumed -invocations inherit the chain's, and resuming a snapshot from before -session IDs existed mints a fresh one. With client-managed state it -echoes the ID carried inside the state object -([SessionState.SessionID]), minting one on the conversation's first -invocation; only a session with persisted snapshots can be resumed by -this ID. +SessionID is the ID of the session this invocation belongs to, assigned by +the framework when the invocation starts and stable across resumes. Pass it +to [WithSessionID] to resume a server-managed session; with client-managed +state it also rides inside [AgentOutput.State] ([SessionState.SessionID]). . AgentOutput.snapshotId doc @@ -1472,9 +1454,7 @@ Op is the operation to perform. JsonPatchOperation.path noomitempty JsonPatchOperation.path doc Path is a JSON Pointer (RFC 6901) to the target location, e.g. "/agentStatus". -The empty pointer "" refers to the whole document. It must always be present on -the wire (a whole-document replace carries path ""), so it is not omitted when -empty. +The empty pointer "" refers to the whole document. . JsonPatchOperation.from doc @@ -1484,11 +1464,8 @@ From is a JSON Pointer to the source location; required for "move" and "copy". JsonPatchOperation.value type any JsonPatchOperation.value noomitempty JsonPatchOperation.value doc -Value is the operand for "add", "replace", and "test". It is not omitted when -null so an explicit null operand survives the wire (omitempty cannot tell a -null operand from an absent one, and dropping it makes a peer applier set the -member to undefined or remove it instead of null); for "remove", "move", and -"copy" it is null and ignored. +Value is the operand for "add", "replace", and "test"; an explicit null is a +valid operand. Ignored for "remove", "move", and "copy". . JsonPatch pkg ai/exp @@ -1518,18 +1495,13 @@ ModelChunk contains generation tokens from the model. AgentStreamChunk.customPatch type JSONPatch AgentStreamChunk.customPatch doc -CustomPatch is an RFC 6902 JSON Patch describing a delta applied to the -session's custom state. The runtime emits it automatically whenever the -agent mutates custom state (e.g. via [Session.UpdateCustom]); agents do not -hand-craft patches. Pointers are rooted at the custom document (e.g. -"/agentStatus"), with no "/custom" prefix. The first patch of every turn is a -whole-document replace at the root pointer ("") that re-bases clients which -may not share the server's baseline; subsequent patches are incremental diffs -against the last sent value. The diff is computed on the client-facing custom -state (after any [WithStateTransform]), so streamed deltas honor redaction and -stay consistent with the full state in turn-end snapshots and final output. -Apply it with [ApplyPatch] to keep a local copy of custom live as the turn -streams. +CustomPatch is an RFC 6902 JSON Patch describing a delta to the session's +custom state, emitted automatically whenever the agent mutates it (e.g. via +[Session.UpdateCustom]). Pointers are rooted at the custom document (e.g. +"/agentStatus"), with no "/custom" prefix. The first patch of each turn is a +whole-document replace at the root pointer ("") to re-base the client; later +patches are incremental diffs. Apply it with [ApplyPatch] to keep a local +copy of custom state live as the turn streams. . AgentStreamChunk.artifact doc @@ -1588,10 +1560,9 @@ and server. It contains only the data needed for conversation continuity. SessionState.sessionId doc SessionID is the ID of the session (conversation) this state belongs to. -Framework-owned: assigned when the conversation's first invocation -starts and re-stamped on outbound state, so client-managed callers can -round-trip the state object opaquely without tracking a separate -identifier. For server-managed agents the snapshot row's +Framework-owned: assigned on the conversation's first invocation and +re-stamped on outbound state, so client-managed callers can round-trip the +state object opaquely. For server-managed agents the snapshot row's [SessionSnapshot.SessionID] is canonical and this field mirrors it. . @@ -1628,12 +1599,10 @@ SnapshotID is the unique identifier for this snapshot (UUID). . SessionSnapshot.sessionId doc -SessionID is the ID of the session this snapshot belongs to. Assigned -by the agent framework when the conversation's first invocation starts -and stamped on every later snapshot in the chain, including across -resumed invocations. Stores preserve it across rewrites; rows written -without one (data from before session IDs existed) belong to no -session. +SessionID is the ID of the session this snapshot belongs to. Assigned by the +framework on the conversation's first invocation and stamped on every later +snapshot in the chain, across resumed invocations; stores preserve it across +rewrites. . SessionSnapshot.parentId doc @@ -1793,19 +1762,15 @@ error. GetSnapshotRequest pkg ai/exp GetSnapshotRequest doc -GetSnapshotRequest is the input for an agent's getSnapshot companion -action, registered under the agent's name (action type agent-snapshot) -when the agent has a session store configured. The action is intended -for Dev UI and client-side reconnect flows. It returns the stored -[SessionSnapshot], with [WithStateTransform] applied to its state if -configured. +GetSnapshotRequest is the input for an agent's getSnapshot companion action, +available when the agent has a session store configured. Intended for Dev UI +and client reconnect flows, it returns the stored [SessionSnapshot] with +[WithStateTransform] applied to its state if configured. -At least one of SnapshotID or SessionID must be set; they are not -mutually exclusive. SnapshotID fetches a specific snapshot; SessionID -alone fetches the session's latest snapshot (via the store's -[SnapshotReader.GetLatestSnapshot], whatever its status). When both are -set, the fetched snapshot must belong to that session, or the request -is rejected. +At least one of SnapshotID or SessionID must be set. SnapshotID fetches a +specific snapshot; SessionID alone fetches the session's latest snapshot +(whatever its status). When both are set, the fetched snapshot must belong +to that session, or the request is rejected. . GetSnapshotRequest.snapshotId doc @@ -1855,10 +1820,9 @@ AgentMetadata pkg ai/exp AgentMetadata doc AgentMetadata is the value placed under metadata["agent"] on an agent's -action descriptor. It exposes capability information so the Dev UI and -other reflective callers can render the right surface (e.g. hide the -Abort button when the configured store doesn't support it) without -round-tripping through the reflection API. +action descriptor. It exposes capability information so the Dev UI and other +reflective callers can render the right surface (e.g. hide the Abort button +when the store doesn't support it). . AgentMetadata.stateManagement doc From e2341f3ea1c36cd4f14cd2bb2a350befd0f52f4b Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 18 Jun 2026 15:50:20 -0700 Subject: [PATCH 119/141] refactor(go): drop core/x/session for the ai/exp session Remove the legacy generic typed-state session package and its dependents. It is superseded by the agent-oriented, snapshot-backed session in go/ai/exp, so the standalone state container, its Firestore store, and the cart sample no longer have a place. Preserve {{@state}} prompt injection against the new session. go/ai cannot import go/ai/exp (that package imports go/ai), so the type-erased state hook moves down to internal/base: exp.NewSessionContext publishes the session's custom state through base.WithPromptState, and prompt rendering reads it via base.PromptStateFromContext. The getter is evaluated lazily so templates see the latest custom state. The agent runtime already attaches the session to context, so agent prompts get {{@state}} with no extra wiring. Removed: - go/core/x/session (package and tests) - go/plugins/firebase/x/session_store.go (Firestore impl of the old Store[S]) and its test; the shared firestoreOptions stay for the stream manager, minus the now-orphaned applySessionStore - go/samples/session and its README section A Firestore backend for the new snapshot store interface is follow-up work; Firebase session persistence is unavailable until then. --- go/README.md | 32 - go/ai/exp/session.go | 9 +- go/ai/exp/session_test.go | 78 ++ go/ai/prompt.go | 3 +- go/ai/prompt_test.go | 63 +- go/core/x/session/session.go | 358 --------- go/core/x/session/session_test.go | 782 -------------------- go/internal/base/prompt_state.go | 44 ++ go/plugins/firebase/x/option.go | 5 - go/plugins/firebase/x/session_store.go | 184 ----- go/plugins/firebase/x/session_store_test.go | 439 ----------- go/samples/session/main.go | 129 ---- 12 files changed, 155 insertions(+), 1971 deletions(-) create mode 100644 go/ai/exp/session_test.go delete mode 100644 go/core/x/session/session.go delete mode 100644 go/core/x/session/session_test.go create mode 100644 go/internal/base/prompt_state.go delete mode 100644 go/plugins/firebase/x/session_store.go delete mode 100644 go/plugins/firebase/x/session_store_test.go delete mode 100644 go/samples/session/main.go diff --git a/go/README.md b/go/README.md index 598ecf3aad..7884360937 100644 --- a/go/README.md +++ b/go/README.md @@ -676,37 +676,6 @@ Clients receive a stream ID in the `X-Genkit-Stream-Id` header and can reconnect [See full example](samples/durable-streaming) -### Sessions - -Maintain typed state across multiple requests and throughout generation including tools: - -```go -import "github.com/firebase/genkit/go/core/x/session" - -type CartState struct { - Items []string `json:"items"` -} - -store := session.NewInMemoryStore[CartState]() - -genkit.DefineFlow(g, "manageCart", func(ctx context.Context, input string) (string, error) { - sess, err := session.Load(ctx, store, "session-id") - if err != nil { - sess, _ = session.New(ctx, - session.WithID[CartState]("session-id"), - session.WithStore(store), - session.WithInitialState(CartState{}), - ) - } - ctx = session.NewContext(ctx, sess) - - // Tools can now access session state via session.FromContext[CartState](ctx) - return genkit.GenerateText(ctx, g, ai.WithPrompt(input), ai.WithTools(myTools...)) -}) -``` - -[See full example](samples/session) - --- ## Samples @@ -724,7 +693,6 @@ Explore working examples to see Genkit in action: | [basic-middleware/skills](samples/basic-middleware/skills) | On-demand loadable `SKILL.md` personas | | [prompts-embed](samples/prompts-embed) | Embed prompts in your binary | | [durable-streaming](samples/durable-streaming) | Reconnectable streams with replay | -| [session](samples/session) | Stateful flows with typed session data | --- diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 44bd98783c..35e5a248bb 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -465,8 +465,15 @@ func (s *Session[State]) copyStateLocked() SessionState[State] { var sessionCtxKey = base.NewContextKey[any]() // NewSessionContext returns a new context with the session attached. +// +// It also publishes a type-erased view of the session's custom state so prompt +// rendering can inject it into templates as {{@state}}. go/ai cannot import this +// package (this package imports go/ai), so the custom state is exposed through a +// getter in internal/base, evaluated at render time so templates see the latest +// values. func NewSessionContext[State any](ctx context.Context, s *Session[State]) context.Context { - return sessionCtxKey.NewContext(ctx, s) + ctx = sessionCtxKey.NewContext(ctx, s) + return base.WithPromptState(ctx, func() any { return s.customJSON() }) } // SessionFromContext retrieves the current session from context. diff --git a/go/ai/exp/session_test.go b/go/ai/exp/session_test.go new file mode 100644 index 0000000000..92853905d2 --- /dev/null +++ b/go/ai/exp/session_test.go @@ -0,0 +1,78 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "context" + "reflect" + "testing" + + "github.com/firebase/genkit/go/internal/base" +) + +// TestNewSessionContextPublishesPromptState verifies that attaching a session to +// a context also exposes its custom state through internal/base, which is how +// ai.prompt injects {{@state}} into templates without importing this package. +func TestNewSessionContextPublishesPromptState(t *testing.T) { + s := &Session[map[string]any]{ + state: SessionState[map[string]any]{ + Custom: map[string]any{ + "name": "Alice", + "preferences": map[string]any{"theme": "dark"}, + }, + }, + } + + ctx := NewSessionContext(context.Background(), s) + + got := base.PromptStateFromContext(ctx) + want := map[string]any{ + "name": "Alice", + "preferences": map[string]any{"theme": "dark"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("PromptStateFromContext() = %#v, want %#v", got, want) + } +} + +// TestPromptStateReflectsLatestCustom verifies the published state getter is +// evaluated lazily, so a template rendered later sees custom-state mutations +// made after the context was built. +func TestPromptStateReflectsLatestCustom(t *testing.T) { + s := &Session[map[string]any]{ + state: SessionState[map[string]any]{Custom: map[string]any{"n": float64(1)}}, + } + ctx := NewSessionContext(context.Background(), s) + + s.UpdateCustom(func(map[string]any) map[string]any { + return map[string]any{"n": float64(2)} + }) + + got := base.PromptStateFromContext(ctx) + want := map[string]any{"n": float64(2)} + if !reflect.DeepEqual(got, want) { + t.Errorf("PromptStateFromContext() = %#v, want %#v", got, want) + } +} + +// TestPromptStateNilWithoutSession verifies that no state is published when no +// session is attached to the context. +func TestPromptStateNilWithoutSession(t *testing.T) { + if got := base.PromptStateFromContext(context.Background()); got != nil { + t.Errorf("PromptStateFromContext() = %#v, want nil", got) + } +} diff --git a/go/ai/prompt.go b/go/ai/prompt.go index a73cb4a7e7..61b41d605e 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -32,7 +32,6 @@ import ( "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/core/logger" - "github.com/firebase/genkit/go/core/x/session" "github.com/firebase/genkit/go/internal/base" "github.com/google/dotprompt/go/dotprompt" "github.com/invopop/jsonschema" @@ -631,7 +630,7 @@ func renderDotpromptToMessages(ctx context.Context, promptFn dotprompt.PromptFun maps.Copy(templateContext, actionCtx) // Inject session state if available (accessible via {{@state.field}} in templates) - if state := session.StateFromContext(ctx); state != nil { + if state := base.PromptStateFromContext(ctx); state != nil { templateContext["state"] = state } diff --git a/go/ai/prompt_test.go b/go/ai/prompt_test.go index 625885d1b8..9fb10bc1c6 100644 --- a/go/ai/prompt_test.go +++ b/go/ai/prompt_test.go @@ -26,7 +26,6 @@ import ( "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" - "github.com/firebase/genkit/go/core/x/session" "github.com/firebase/genkit/go/internal/base" "github.com/firebase/genkit/go/internal/registry" "github.com/google/go-cmp/cmp" @@ -2329,19 +2328,16 @@ func TestPromptExecuteStream(t *testing.T) { }) } -// TestSessionStateInjection tests that session state is automatically injected -// into prompt templates and accessible via {{@state.field}} syntax. +// TestSessionStateInjection tests that state attached to the context via +// base.WithPromptState is automatically injected into prompt templates and +// accessible via {{@state.field}} syntax. Sessions attach their custom state +// this way (see exp.NewSessionContext), decoupled so go/ai needs no dependency +// on the session package. func TestSessionStateInjection(t *testing.T) { r := registry.New() ConfigureFormats(r) - // Define a test state type - type UserState struct { - Name string `json:"name"` - Preferences map[string]string `json:"preferences"` - } - - t.Run("session state accessible in prompt template", func(t *testing.T) { + t.Run("state accessible in prompt template", func(t *testing.T) { var capturedPrompt string testModel := DefineModel(r, "test/sessionStateModel", &ModelOptions{ @@ -2360,33 +2356,27 @@ func TestSessionStateInjection(t *testing.T) { WithPrompt("Hello {{@state.name}}, your theme is {{@state.preferences.theme}}"), ) - // Create a session with state - ctx := context.Background() - sess, err := session.New(ctx, session.WithInitialState(UserState{ - Name: "Alice", - Preferences: map[string]string{"theme": "dark"}, - })) - if err != nil { - t.Fatalf("Failed to create session: %v", err) - } - - // Attach session to context - ctx = session.NewContext(ctx, sess) + // Attach state to the context the way a session does. + ctx := base.WithPromptState(context.Background(), func() any { + return map[string]any{ + "name": "Alice", + "preferences": map[string]any{"theme": "dark"}, + } + }) - // Execute prompt with session in context - _, err = p.Execute(ctx) - if err != nil { + // Execute prompt with state in context + if _, err := p.Execute(ctx); err != nil { t.Fatalf("Execute failed: %v", err) } - // Verify the session state was injected into the template + // Verify the state was injected into the template expected := "Hello Alice, your theme is dark" if capturedPrompt != expected { t.Errorf("Expected prompt %q, got %q", expected, capturedPrompt) } }) - t.Run("prompt works without session in context", func(t *testing.T) { + t.Run("prompt works without state in context", func(t *testing.T) { var capturedPrompt string testModel := DefineModel(r, "test/noSessionModel", &ModelOptions{ @@ -2421,7 +2411,7 @@ func TestSessionStateInjection(t *testing.T) { } }) - t.Run("session state and input variables can be used together", func(t *testing.T) { + t.Run("state and input variables can be used together", func(t *testing.T) { var capturedPrompt string testModel := DefineModel(r, "test/mixedModel", &ModelOptions{ @@ -2443,18 +2433,13 @@ func TestSessionStateInjection(t *testing.T) { }{}), ) - // Create session - ctx := context.Background() - sess, err := session.New(ctx, session.WithInitialState(UserState{ - Name: "Charlie", - })) - if err != nil { - t.Fatalf("Failed to create session: %v", err) - } - ctx = session.NewContext(ctx, sess) + // Attach state to the context the way a session does. + ctx := base.WithPromptState(context.Background(), func() any { + return map[string]any{"name": "Charlie"} + }) - // Execute with both session and input - _, err = p.Execute(ctx, WithInput(map[string]any{"question": "What is the weather?"})) + // Execute with both state and input + _, err := p.Execute(ctx, WithInput(map[string]any{"question": "What is the weather?"})) if err != nil { t.Fatalf("Execute failed: %v", err) } diff --git a/go/core/x/session/session.go b/go/core/x/session/session.go deleted file mode 100644 index 01db9e8562..0000000000 --- a/go/core/x/session/session.go +++ /dev/null @@ -1,358 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 - -// Package session provides experimental session management APIs for Genkit. -// -// A session encapsulates a stateful execution environment with strongly-typed -// state that can be persisted across requests. Sessions are useful for maintaining -// user preferences, conversation context, or any application state that needs -// to survive between interactions. -// -// APIs in this package are under active development and may change in any -// minor version release. Use with caution in production environments. -// -// When these APIs stabilize, they will be moved to the core package -// and these exports will be deprecated. -package session - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "sync" - - "github.com/google/uuid" -) - -// Session represents a stateful environment with typed state. -// The type parameter S defines the shape of the session state and must be -// JSON-serializable for persistence. -type Session[S any] struct { - id string - state S - store Store[S] - mu sync.RWMutex -} - -// Data is the serializable session state persisted by Store. -type Data[S any] struct { - ID string `json:"id"` - State S `json:"state,omitempty"` -} - -// Store persists session data to a backend (database, file, memory, etc). -// Implementations must be safe for concurrent use. -type Store[S any] interface { - // Get retrieves session data by ID. Returns nil if not found. - Get(ctx context.Context, sessionID string) (*Data[S], error) - // Save persists session data, creating or updating as needed. - Save(ctx context.Context, sessionID string, data *Data[S]) error -} - -// options holds configuration for creating a Session. -type options[S any] struct { - ID string - InitialState S - Store Store[S] - hasID bool - hasState bool - hasStore bool -} - -// Option configures a Session during creation. -type Option[S any] interface { - apply(*options[S]) error -} - -// apply implements Option for options, enabling composition. -func (o *options[S]) apply(opts *options[S]) error { - if o.hasID { - if opts.hasID { - return errors.New("cannot set ID more than once (WithID)") - } - opts.ID = o.ID - opts.hasID = true - } - - if o.hasState { - if opts.hasState { - return errors.New("cannot set initial state more than once (WithInitialState)") - } - opts.InitialState = o.InitialState - opts.hasState = true - } - - if o.hasStore { - if opts.hasStore { - return errors.New("cannot set store more than once (WithStore)") - } - opts.Store = o.Store - opts.hasStore = true - } - - return nil -} - -// WithID sets a custom session ID. If not provided, a UUID is generated. -func WithID[S any](id string) Option[S] { - return &options[S]{ID: id, hasID: true} -} - -// WithInitialState sets the initial state for a new session. -func WithInitialState[S any](state S) Option[S] { - return &options[S]{InitialState: state, hasState: true} -} - -// WithStore sets the persistence backend for the session. -// If not provided, the session is not persisted and exists only in memory. -func WithStore[S any](store Store[S]) Option[S] { - return &options[S]{Store: store, hasStore: true} -} - -// New creates a new session with the provided options. -// If a store is provided via [WithStore], the session is persisted immediately. -// If no store is provided, the session exists only in memory for the current -// request and can be propagated via context using [NewContext]. -// If no ID is provided, a new UUID is generated. -// If no initial state is provided, the session is created with an empty state. -func New[S any](ctx context.Context, opts ...Option[S]) (*Session[S], error) { - o := &options[S]{} - for _, opt := range opts { - if err := opt.apply(o); err != nil { - return nil, fmt.Errorf("session.New: %w", err) - } - } - - id := o.ID - if !o.hasID { - id = uuid.New().String() - } - - // Only persist if a store was explicitly provided - if o.hasStore { - data := &Data[S]{ - ID: id, - State: o.InitialState, - } - if err := o.Store.Save(ctx, id, data); err != nil { - return nil, fmt.Errorf("session.New: failed to persist initial state: %w", err) - } - } - - return &Session[S]{ - id: id, - state: o.InitialState, - store: o.Store, // nil if no store provided - }, nil -} - -// Load loads an existing session from the store. -// Returns an error if the session is not found or if loading fails. -func Load[S any](ctx context.Context, store Store[S], sessionID string) (*Session[S], error) { - data, err := store.Get(ctx, sessionID) - if err != nil { - return nil, fmt.Errorf("session.Load: %w", err) - } - if data == nil { - return nil, &NotFoundError{SessionID: sessionID} - } - - return &Session[S]{ - id: data.ID, - state: data.State, - store: store, - }, nil -} - -// ID returns the session's unique identifier. -func (s *Session[S]) ID() string { - return s.id -} - -// State returns the current session state. -// The returned value is a copy; modifications do not affect the session. -func (s *Session[S]) State() S { - s.mu.RLock() - defer s.mu.RUnlock() - return deepCopyState(s.state) -} - -// deepCopyState creates a deep copy of the state using JSON marshaling. -// Panics if serialization fails, as this indicates a programming error -// (the state type S must be JSON-serializable per the Session contract). -func deepCopyState[S any](state S) S { - bytes, err := json.Marshal(state) - if err != nil { - panic(fmt.Sprintf("session.State: failed to marshal state: %v", err)) - } - - var copied S - if err := json.Unmarshal(bytes, &copied); err != nil { - panic(fmt.Sprintf("session.State: failed to unmarshal state: %v", err)) - } - - return copied -} - -// UpdateState updates the session state and persists it to the store (if configured). -func (s *Session[S]) UpdateState(ctx context.Context, state S) error { - s.mu.Lock() - defer s.mu.Unlock() - - s.state = state - - if s.store != nil { - data := &Data[S]{ - ID: s.id, - State: state, - } - if err := s.store.Save(ctx, s.id, data); err != nil { - return fmt.Errorf("session.UpdateState: %w", err) - } - } - - return nil -} - -// contextKey is a private type for context keys to avoid collisions. -type contextKey struct{} - -// sessionContextKey is the key used to store sessions in context. -var sessionContextKey = contextKey{} - -// sessionHolder wraps a session with its type erased for context storage. -type sessionHolder struct { - session any -} - -// NewContext returns a new context with the session attached. -func NewContext[S any](ctx context.Context, s *Session[S]) context.Context { - return context.WithValue(ctx, sessionContextKey, &sessionHolder{session: s}) -} - -// FromContext retrieves the current session from context. -// Returns nil if no session is in context or if the type doesn't match. -func FromContext[S any](ctx context.Context) *Session[S] { - holder, ok := ctx.Value(sessionContextKey).(*sessionHolder) - if !ok || holder == nil { - return nil - } - session, ok := holder.session.(*Session[S]) - if !ok { - return nil - } - return session -} - -// stateGetter is an internal interface for retrieving state without type parameters. -type stateGetter interface { - getState() any -} - -// getState implements stateGetter, returning the session state as any. -func (s *Session[S]) getState() any { - return s.State() -} - -// StateFromContext retrieves the current session state from context without -// requiring knowledge of the state type. This is useful for template rendering -// where the state type is not known at compile time. -// Returns nil if no session is in context. -func StateFromContext(ctx context.Context) any { - holder, ok := ctx.Value(sessionContextKey).(*sessionHolder) - if !ok || holder == nil { - return nil - } - if getter, ok := holder.session.(stateGetter); ok { - return getter.getState() - } - return nil -} - -// NotFoundError is returned when a session cannot be found in the store. -type NotFoundError struct { - SessionID string -} - -func (e *NotFoundError) Error() string { - return "session not found: " + e.SessionID -} - -// InMemoryStore is a thread-safe in-memory implementation of Store. -// Useful for testing or single-instance deployments where persistence is not required. -type InMemoryStore[S any] struct { - data map[string]*Data[S] - mu sync.RWMutex -} - -// NewInMemoryStore creates a new in-memory session store. -func NewInMemoryStore[S any]() *InMemoryStore[S] { - return &InMemoryStore[S]{ - data: make(map[string]*Data[S]), - } -} - -// Get retrieves session data by ID. -func (s *InMemoryStore[S]) Get(_ context.Context, sessionID string) (*Data[S], error) { - s.mu.RLock() - defer s.mu.RUnlock() - - data, exists := s.data[sessionID] - if !exists { - return nil, nil - } - - // Return a copy to prevent external modifications - copied, err := copyData(data) - if err != nil { - return nil, err - } - return copied, nil -} - -// Save persists session data. -func (s *InMemoryStore[S]) Save(_ context.Context, sessionID string, data *Data[S]) error { - s.mu.Lock() - defer s.mu.Unlock() - - // Store a copy to prevent external modifications - copied, err := copyData(data) - if err != nil { - return err - } - s.data[sessionID] = copied - return nil -} - -// copyData creates a deep copy of Data using JSON marshaling. -func copyData[S any](data *Data[S]) (*Data[S], error) { - if data == nil { - return nil, nil - } - - bytes, err := json.Marshal(data) - if err != nil { - return nil, fmt.Errorf("copy session data: marshal: %w", err) - } - - var copied Data[S] - if err := json.Unmarshal(bytes, &copied); err != nil { - return nil, fmt.Errorf("copy session data: unmarshal: %w", err) - } - - return &copied, nil -} diff --git a/go/core/x/session/session_test.go b/go/core/x/session/session_test.go deleted file mode 100644 index 55c25d76d1..0000000000 --- a/go/core/x/session/session_test.go +++ /dev/null @@ -1,782 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 - -package session - -import ( - "context" - "errors" - "strings" - "sync" - "testing" -) - -// UserState is a test state type with various field types. -type UserState struct { - Name string `json:"name"` - Count int `json:"count"` - Preferences map[string]string `json:"preferences,omitempty"` -} - -func TestNew_DefaultID(t *testing.T) { - ctx := context.Background() - sess, err := New[UserState](ctx) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - if sess.ID() == "" { - t.Error("Expected session to have a generated ID") - } -} - -func TestNew_WithID(t *testing.T) { - ctx := context.Background() - customID := "my-custom-id" - sess, err := New(ctx, WithID[UserState](customID)) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - if sess.ID() != customID { - t.Errorf("Expected ID %q, got %q", customID, sess.ID()) - } -} - -func TestNew_WithInitialState(t *testing.T) { - ctx := context.Background() - initial := UserState{Name: "Alice", Count: 42} - sess, err := New(ctx, WithInitialState(initial)) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - got := sess.State() - if got.Name != initial.Name { - t.Errorf("Expected Name %q, got %q", initial.Name, got.Name) - } - if got.Count != initial.Count { - t.Errorf("Expected Count %d, got %d", initial.Count, got.Count) - } -} - -func TestNew_WithStore(t *testing.T) { - ctx := context.Background() - store := NewInMemoryStore[UserState]() - sess, err := New(ctx, WithStore(store)) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - if sess.store != store { - t.Error("Expected store to be set") - } -} - -func TestNew_MultipleOptions(t *testing.T) { - ctx := context.Background() - store := NewInMemoryStore[UserState]() - customID := "multi-option-id" - initial := UserState{Name: "Bob", Count: 100} - - sess, err := New(ctx, - WithID[UserState](customID), - WithInitialState(initial), - WithStore(store), - ) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - if sess.ID() != customID { - t.Errorf("Expected ID %q, got %q", customID, sess.ID()) - } - if sess.State().Name != initial.Name { - t.Errorf("Expected Name %q, got %q", initial.Name, sess.State().Name) - } - if sess.store != store { - t.Error("Expected store to be set") - } -} - -func TestNew_DuplicateID(t *testing.T) { - ctx := context.Background() - _, err := New(ctx, - WithID[UserState]("first"), - WithID[UserState]("second"), - ) - if err == nil { - t.Fatal("Expected error for duplicate WithID") - } - if !strings.Contains(err.Error(), "cannot set ID more than once") { - t.Errorf("Expected duplicate ID error, got: %v", err) - } -} - -func TestNew_DuplicateInitialState(t *testing.T) { - ctx := context.Background() - _, err := New(ctx, - WithInitialState(UserState{Name: "First"}), - WithInitialState(UserState{Name: "Second"}), - ) - if err == nil { - t.Fatal("Expected error for duplicate WithInitialState") - } - if !strings.Contains(err.Error(), "cannot set initial state more than once") { - t.Errorf("Expected duplicate state error, got: %v", err) - } -} - -func TestNew_DuplicateStore(t *testing.T) { - ctx := context.Background() - store1 := NewInMemoryStore[UserState]() - store2 := NewInMemoryStore[UserState]() - _, err := New(ctx, - WithStore(store1), - WithStore(store2), - ) - if err == nil { - t.Fatal("Expected error for duplicate WithStore") - } - if !strings.Contains(err.Error(), "cannot set store more than once") { - t.Errorf("Expected duplicate store error, got: %v", err) - } -} - -func TestSession_State(t *testing.T) { - ctx := context.Background() - initial := UserState{ - Name: "Alice", - Count: 10, - Preferences: map[string]string{"theme": "dark"}, - } - sess, err := New(ctx, WithInitialState(initial)) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - t.Run("returns correct values", func(t *testing.T) { - got := sess.State() - if got.Name != initial.Name { - t.Errorf("Expected Name %q, got %q", initial.Name, got.Name) - } - if got.Count != initial.Count { - t.Errorf("Expected Count %d, got %d", initial.Count, got.Count) - } - if got.Preferences["theme"] != "dark" { - t.Errorf("Expected theme %q, got %q", "dark", got.Preferences["theme"]) - } - }) - - t.Run("modifications to returned copy do not affect session", func(t *testing.T) { - // Get a copy of the state - copy1 := sess.State() - - // Modify the map in the returned copy - copy1.Preferences["theme"] = "light" - copy1.Preferences["newKey"] = "newValue" - copy1.Name = "Modified" - copy1.Count = 999 - - // Get another copy and verify the session's internal state is unchanged - copy2 := sess.State() - - if copy2.Name != "Alice" { - t.Errorf("Session state was mutated: expected Name %q, got %q", "Alice", copy2.Name) - } - if copy2.Count != 10 { - t.Errorf("Session state was mutated: expected Count %d, got %d", 10, copy2.Count) - } - if copy2.Preferences["theme"] != "dark" { - t.Errorf("Session state was mutated: expected theme %q, got %q", "dark", copy2.Preferences["theme"]) - } - if _, exists := copy2.Preferences["newKey"]; exists { - t.Errorf("Session state was mutated: unexpected key 'newKey' in Preferences") - } - }) -} - -func TestSession_UpdateState_NoStore(t *testing.T) { - ctx := context.Background() - sess, err := New(ctx, WithInitialState(UserState{Name: "Alice"})) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - // Verify no store is set when not provided - if sess.store != nil { - t.Fatal("Expected no store when not provided") - } - - newState := UserState{Name: "Bob", Count: 5} - if err := sess.UpdateState(ctx, newState); err != nil { - t.Fatalf("UpdateState failed: %v", err) - } - - // State should still be updated in memory - got := sess.State() - if got.Name != newState.Name { - t.Errorf("Expected Name %q, got %q", newState.Name, got.Name) - } - if got.Count != newState.Count { - t.Errorf("Expected Count %d, got %d", newState.Count, got.Count) - } -} - -func TestSession_UpdateState_WithStore(t *testing.T) { - ctx := context.Background() - store := NewInMemoryStore[UserState]() - sess, err := New(ctx, - WithID[UserState]("test-session"), - WithInitialState(UserState{Name: "Alice"}), - WithStore(store), - ) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - newState := UserState{Name: "Bob", Count: 5} - if err := sess.UpdateState(ctx, newState); err != nil { - t.Fatalf("UpdateState failed: %v", err) - } - - // Verify state is updated in session - got := sess.State() - if got.Name != newState.Name { - t.Errorf("Expected Name %q, got %q", newState.Name, got.Name) - } - - // Verify state is persisted in store - data, err := store.Get(ctx, "test-session") - if err != nil { - t.Fatalf("Store.Get failed: %v", err) - } - if data == nil { - t.Fatal("Expected data in store, got nil") - } - if data.State.Name != newState.Name { - t.Errorf("Store: expected Name %q, got %q", newState.Name, data.State.Name) - } -} - -func TestLoad_Success(t *testing.T) { - store := NewInMemoryStore[UserState]() - ctx := context.Background() - - // Save some data - data := &Data[UserState]{ - ID: "existing-session", - State: UserState{Name: "Charlie", Count: 99}, - } - if err := store.Save(ctx, data.ID, data); err != nil { - t.Fatalf("Store.Save failed: %v", err) - } - - // Load the session - loaded, err := Load(ctx, store, "existing-session") - if err != nil { - t.Fatalf("Load failed: %v", err) - } - - if loaded.ID() != "existing-session" { - t.Errorf("Expected ID %q, got %q", "existing-session", loaded.ID()) - } - if loaded.State().Name != "Charlie" { - t.Errorf("Expected Name %q, got %q", "Charlie", loaded.State().Name) - } - if loaded.State().Count != 99 { - t.Errorf("Expected Count %d, got %d", 99, loaded.State().Count) - } -} - -func TestLoad_NotFound(t *testing.T) { - store := NewInMemoryStore[UserState]() - ctx := context.Background() - - _, err := Load(ctx, store, "non-existent") - if err == nil { - t.Fatal("Expected error for non-existent session") - } - - var notFoundErr *NotFoundError - if !errors.As(err, ¬FoundErr) { - t.Errorf("Expected NotFoundError, got %T: %v", err, err) - } - if notFoundErr.SessionID != "non-existent" { - t.Errorf("Expected SessionID %q, got %q", "non-existent", notFoundErr.SessionID) - } -} - -func TestNewContext_FromContext(t *testing.T) { - ctx := context.Background() - sess, err := New(ctx, - WithID[UserState]("ctx-test"), - WithInitialState(UserState{Name: "Diana"}), - ) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - // Attach session to context - ctx = NewContext(ctx, sess) - - // Retrieve from context - retrieved := FromContext[UserState](ctx) - if retrieved == nil { - t.Fatal("Expected session from context, got nil") - } - if retrieved.ID() != "ctx-test" { - t.Errorf("Expected ID %q, got %q", "ctx-test", retrieved.ID()) - } - if retrieved.State().Name != "Diana" { - t.Errorf("Expected Name %q, got %q", "Diana", retrieved.State().Name) - } -} - -func TestStateFromContext(t *testing.T) { - t.Run("returns state when session exists", func(t *testing.T) { - ctx := context.Background() - initial := UserState{ - Name: "Alice", - Count: 42, - Preferences: map[string]string{"theme": "dark"}, - } - sess, err := New(ctx, WithInitialState(initial)) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - ctx = NewContext(ctx, sess) - - state := StateFromContext(ctx) - if state == nil { - t.Fatal("Expected state from context, got nil") - } - - // StateFromContext returns the state as any, so we need to type assert - userState, ok := state.(UserState) - if !ok { - t.Fatalf("Expected UserState, got %T", state) - } - - if userState.Name != "Alice" { - t.Errorf("Expected Name %q, got %q", "Alice", userState.Name) - } - if userState.Count != 42 { - t.Errorf("Expected Count %d, got %d", 42, userState.Count) - } - if userState.Preferences["theme"] != "dark" { - t.Errorf("Expected theme %q, got %q", "dark", userState.Preferences["theme"]) - } - }) - - t.Run("returns nil when no session in context", func(t *testing.T) { - ctx := context.Background() - state := StateFromContext(ctx) - if state != nil { - t.Errorf("Expected nil for empty context, got %v", state) - } - }) - - t.Run("returns deep copy that cannot mutate session", func(t *testing.T) { - ctx := context.Background() - initial := UserState{ - Name: "Bob", - Preferences: map[string]string{"lang": "en"}, - } - sess, err := New(ctx, WithInitialState(initial)) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - ctx = NewContext(ctx, sess) - - // Get state via StateFromContext - state := StateFromContext(ctx) - userState := state.(UserState) - - // Modify the returned state - userState.Name = "Modified" - userState.Preferences["lang"] = "fr" - - // Verify the session's internal state is unchanged - originalState := sess.State() - if originalState.Name != "Bob" { - t.Errorf("Session state was mutated: expected Name %q, got %q", "Bob", originalState.Name) - } - if originalState.Preferences["lang"] != "en" { - t.Errorf("Session state was mutated: expected lang %q, got %q", "en", originalState.Preferences["lang"]) - } - }) -} - -func TestFromContext_NoSession(t *testing.T) { - ctx := context.Background() - - retrieved := FromContext[UserState](ctx) - if retrieved != nil { - t.Errorf("Expected nil for empty context, got %v", retrieved) - } -} - -func TestFromContext_WrongType(t *testing.T) { - ctx := context.Background() - // Create session with one type - type OtherState struct { - Value string - } - sess, err := New(ctx, WithInitialState(OtherState{Value: "test"})) - if err != nil { - t.Fatalf("New failed: %v", err) - } - ctx = NewContext(ctx, sess) - - // Try to retrieve with different type - retrieved := FromContext[UserState](ctx) - if retrieved != nil { - t.Errorf("Expected nil for wrong type, got %v", retrieved) - } -} - -func TestInMemoryStore_GetSave(t *testing.T) { - store := NewInMemoryStore[UserState]() - ctx := context.Background() - - // Initially empty - data, err := store.Get(ctx, "test-id") - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if data != nil { - t.Errorf("Expected nil for non-existent key, got %v", data) - } - - // Save data - original := &Data[UserState]{ - ID: "test-id", - State: UserState{Name: "Eve", Count: 7}, - } - if err := store.Save(ctx, "test-id", original); err != nil { - t.Fatalf("Save failed: %v", err) - } - - // Retrieve data - retrieved, err := store.Get(ctx, "test-id") - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if retrieved == nil { - t.Fatal("Expected data, got nil") - } - if retrieved.ID != original.ID { - t.Errorf("Expected ID %q, got %q", original.ID, retrieved.ID) - } - if retrieved.State.Name != original.State.Name { - t.Errorf("Expected Name %q, got %q", original.State.Name, retrieved.State.Name) - } -} - -func TestInMemoryStore_Isolation(t *testing.T) { - store := NewInMemoryStore[UserState]() - ctx := context.Background() - - // Save data - original := &Data[UserState]{ - ID: "isolation-test", - State: UserState{Name: "Frank", Count: 1}, - } - if err := store.Save(ctx, "isolation-test", original); err != nil { - t.Fatalf("Save failed: %v", err) - } - - // Modify original after save - original.State.Name = "Modified" - - // Retrieved data should not be affected - retrieved, err := store.Get(ctx, "isolation-test") - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if retrieved.State.Name != "Frank" { - t.Errorf("Expected Name %q (isolation), got %q", "Frank", retrieved.State.Name) - } - - // Modify retrieved data - retrieved.State.Name = "Also Modified" - - // Get again - should still be original - retrieved2, err := store.Get(ctx, "isolation-test") - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if retrieved2.State.Name != "Frank" { - t.Errorf("Expected Name %q (isolation), got %q", "Frank", retrieved2.State.Name) - } -} - -func TestInMemoryStore_Overwrite(t *testing.T) { - store := NewInMemoryStore[UserState]() - ctx := context.Background() - - // Save initial data - initial := &Data[UserState]{ - ID: "overwrite-test", - State: UserState{Name: "Grace", Count: 1}, - } - if err := store.Save(ctx, "overwrite-test", initial); err != nil { - t.Fatalf("Save failed: %v", err) - } - - // Overwrite with new data - updated := &Data[UserState]{ - ID: "overwrite-test", - State: UserState{Name: "Grace Updated", Count: 2}, - } - if err := store.Save(ctx, "overwrite-test", updated); err != nil { - t.Fatalf("Save failed: %v", err) - } - - // Retrieve and verify - retrieved, err := store.Get(ctx, "overwrite-test") - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if retrieved.State.Name != "Grace Updated" { - t.Errorf("Expected Name %q, got %q", "Grace Updated", retrieved.State.Name) - } - if retrieved.State.Count != 2 { - t.Errorf("Expected Count %d, got %d", 2, retrieved.State.Count) - } -} - -func TestSession_ConcurrentAccess(t *testing.T) { - ctx := context.Background() - store := NewInMemoryStore[UserState]() - sess, err := New(ctx, - WithID[UserState]("concurrent-test"), - WithInitialState(UserState{Name: "Initial", Count: 0}), - WithStore(store), - ) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - const numGoroutines = 10 - const numUpdates = 100 - - var wg sync.WaitGroup - wg.Add(numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func(id int) { - defer wg.Done() - for j := 0; j < numUpdates; j++ { - // Read state - _ = sess.State() - - // Update state - _ = sess.UpdateState(ctx, UserState{ - Name: "Goroutine", - Count: id*numUpdates + j, - }) - } - }(i) - } - - wg.Wait() - - // Verify no data corruption - state := sess.State() - if state.Name != "Goroutine" { - t.Errorf("Expected Name %q, got %q", "Goroutine", state.Name) - } -} - -func TestInMemoryStore_ConcurrentAccess(t *testing.T) { - store := NewInMemoryStore[UserState]() - ctx := context.Background() - - const numGoroutines = 10 - const numOperations = 100 - - var wg sync.WaitGroup - wg.Add(numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func(id int) { - defer wg.Done() - key := "shared-key" - for j := 0; j < numOperations; j++ { - // Save - data := &Data[UserState]{ - ID: key, - State: UserState{Name: "Concurrent", Count: id*numOperations + j}, - } - _ = store.Save(ctx, key, data) - - // Get - _, _ = store.Get(ctx, key) - } - }(i) - } - - wg.Wait() - - // Verify we can still read - data, err := store.Get(ctx, "shared-key") - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if data == nil { - t.Fatal("Expected data, got nil") - } -} - -func TestNotFoundError(t *testing.T) { - err := &NotFoundError{SessionID: "test-123"} - - expected := "session not found: test-123" - if err.Error() != expected { - t.Errorf("Expected error message %q, got %q", expected, err.Error()) - } -} - -func TestSession_ZeroState(t *testing.T) { - ctx := context.Background() - // Create session without initial state - sess, err := New[UserState](ctx) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - state := sess.State() - if state.Name != "" { - t.Errorf("Expected empty Name, got %q", state.Name) - } - if state.Count != 0 { - t.Errorf("Expected zero Count, got %d", state.Count) - } - if state.Preferences != nil { - t.Errorf("Expected nil Preferences, got %v", state.Preferences) - } -} - -func TestSession_ComplexState(t *testing.T) { - ctx := context.Background() - type NestedState struct { - Inner struct { - Value string `json:"value"` - } `json:"inner"` - List []int `json:"list"` - } - - store := NewInMemoryStore[NestedState]() - initial := NestedState{ - List: []int{1, 2, 3}, - } - initial.Inner.Value = "nested" - - sess, err := New(ctx, - WithID[NestedState]("complex-state"), - WithInitialState(initial), - WithStore(store), - ) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - // Update with nested modifications - newState := NestedState{ - List: []int{4, 5, 6, 7}, - } - newState.Inner.Value = "updated nested" - - if err := sess.UpdateState(ctx, newState); err != nil { - t.Fatalf("UpdateState failed: %v", err) - } - - // Verify nested state is correct - got := sess.State() - if got.Inner.Value != "updated nested" { - t.Errorf("Expected Inner.Value %q, got %q", "updated nested", got.Inner.Value) - } - if len(got.List) != 4 { - t.Errorf("Expected List length %d, got %d", 4, len(got.List)) - } - - // Verify persistence - data, err := store.Get(ctx, "complex-state") - if err != nil { - t.Fatalf("Store.Get failed: %v", err) - } - if data.State.Inner.Value != "updated nested" { - t.Errorf("Store: expected Inner.Value %q, got %q", "updated nested", data.State.Inner.Value) - } -} - -// mockFailingStore is a store that fails on Save for testing error handling. -type mockFailingStore[S any] struct { - saveErr error -} - -func (s *mockFailingStore[S]) Get(_ context.Context, _ string) (*Data[S], error) { - return nil, nil -} - -func (s *mockFailingStore[S]) Save(_ context.Context, _ string, _ *Data[S]) error { - return s.saveErr -} -func TestNew_StoreError(t *testing.T) { - ctx := context.Background() - expectedErr := errors.New("store failure") - store := &mockFailingStore[UserState]{saveErr: expectedErr} - _, err := New(ctx, - WithID[UserState]("error-test"), - WithStore(store), - ) - if err == nil { - t.Fatal("Expected error from failing store") - } - if !strings.Contains(err.Error(), "failed to persist initial state") { - t.Errorf("Expected persist error, got: %v", err) - } - if !errors.Is(err, expectedErr) { - t.Errorf("Expected wrapped error %v, got %v", expectedErr, err) - } -} - -func TestSession_UpdateState_StoreError(t *testing.T) { - ctx := context.Background() - store := NewInMemoryStore[UserState]() - sess, err := New(ctx, - WithID[UserState]("error-test"), - WithStore(store), - ) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - expectedErr := errors.New("store failure") - sess.store = &mockFailingStore[UserState]{saveErr: expectedErr} - - err = sess.UpdateState(ctx, UserState{Name: "Test"}) - if err == nil { - t.Fatal("Expected error from failing store") - } - if !errors.Is(err, expectedErr) { - t.Errorf("Expected error wrapping %v, got %v", expectedErr, err) - } -} diff --git a/go/internal/base/prompt_state.go b/go/internal/base/prompt_state.go new file mode 100644 index 0000000000..32c9918251 --- /dev/null +++ b/go/internal/base/prompt_state.go @@ -0,0 +1,44 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package base + +import "context" + +// promptStateKey holds a getter for the type-erased state that prompt rendering +// injects into templates as {{@state}}. The getter indirection lets a +// higher-level package publish its session state for prompt rendering without a +// circular import: the session lives in a package that imports go/ai, so go/ai +// cannot import it back to read the state directly. +var promptStateKey = NewContextKey[func() any]() + +// WithPromptState returns ctx carrying a getter for the state exposed to prompt +// templates via {{@state}}. getState is evaluated lazily at render time, so it +// observes the latest state rather than a snapshot taken when the context was +// built. A nil getState detaches any state previously attached. +func WithPromptState(ctx context.Context, getState func() any) context.Context { + return promptStateKey.NewContext(ctx, getState) +} + +// PromptStateFromContext returns the state attached by [WithPromptState], or nil +// if none is attached. The getter is invoked on each call. +func PromptStateFromContext(ctx context.Context) any { + getState := promptStateKey.FromContext(ctx) + if getState == nil { + return nil + } + return getState() +} diff --git a/go/plugins/firebase/x/option.go b/go/plugins/firebase/x/option.go index 9de3443d60..03d0889008 100644 --- a/go/plugins/firebase/x/option.go +++ b/go/plugins/firebase/x/option.go @@ -56,11 +56,6 @@ func (o *firestoreOptions) applyStreamManager(opts *streamManagerOptions) error return o.applyFirestore(&opts.firestoreOptions) } -// applySessionStore implements SessionStoreOption for firestoreOptions. -func (o *firestoreOptions) applySessionStore(opts *sessionStoreOptions) error { - return o.applyFirestore(&opts.firestoreOptions) -} - // WithCollection sets the Firestore collection name where documents are stored. // This option is required for all Firestore-based services. func WithCollection(collection string) *firestoreOptions { diff --git a/go/plugins/firebase/x/session_store.go b/go/plugins/firebase/x/session_store.go deleted file mode 100644 index 57ce322f4b..0000000000 --- a/go/plugins/firebase/x/session_store.go +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 - -package x - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "time" - - "cloud.google.com/go/firestore" - "github.com/firebase/genkit/go/core/x/session" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/firebase" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -// SessionStoreOption configures a FirestoreSessionStore. -// Implemented by firestoreOptions (WithCollection, WithTTL). -type SessionStoreOption interface { - applySessionStore(*sessionStoreOptions) error -} - -// sessionStoreOptions holds configuration for FirestoreSessionStore. -type sessionStoreOptions struct { - firestoreOptions -} - -// applySessionStore implements SessionStoreOption for sessionStoreOptions. -func (o *sessionStoreOptions) applySessionStore(opts *sessionStoreOptions) error { - return o.firestoreOptions.applyFirestore(&opts.firestoreOptions) -} - -// FirestoreSessionStore implements [session.Store[S]] using Firestore as the backend. -// Session state is persisted in Firestore documents, allowing sessions to survive -// server restarts and be accessible across multiple instances. -type FirestoreSessionStore[S any] struct { - client *firestore.Client - collection string - ttl time.Duration -} - -// sessionDocument represents the structure of a session document in Firestore. -type sessionDocument struct { - State json.RawMessage `firestore:"state"` - CreatedAt time.Time `firestore:"createdAt"` - UpdatedAt time.Time `firestore:"updatedAt"` - ExpiresAt *time.Time `firestore:"expiresAt,omitempty"` -} - -// NewFirestoreSessionStore creates a Firestore-backed session store. -// Requires the Firebase plugin to be initialized in the Genkit instance. -func NewFirestoreSessionStore[S any](ctx context.Context, g *genkit.Genkit, opts ...SessionStoreOption) (*FirestoreSessionStore[S], error) { - storeOpts := &sessionStoreOptions{} - for _, opt := range opts { - if err := opt.applySessionStore(storeOpts); err != nil { - return nil, fmt.Errorf("firebase.NewFirestoreSessionStore: error applying options: %w", err) - } - } - if storeOpts.Collection == "" { - return nil, errors.New("firebase.NewFirestoreSessionStore: Collection name is required.\n" + - " Specify the Firestore collection where session documents will be stored:\n" + - " firebase.NewFirestoreSessionStore[MyState](ctx, g, firebase.WithCollection(\"genkit-sessions\"))") - } - if storeOpts.TTL == 0 { - storeOpts.TTL = DefaultTTL - } - - plugin := genkit.LookupPlugin(g, "firebase") - if plugin == nil { - return nil, errors.New("firebase.NewFirestoreSessionStore: Firebase plugin not found.\n" + - " Pass the Firebase plugin to genkit.Init():\n" + - " g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: \"your-project\"}))") - } - f, ok := plugin.(*firebase.Firebase) - if !ok { - return nil, fmt.Errorf("firebase.NewFirestoreSessionStore: unexpected plugin type %T", plugin) - } - - client, err := f.Firestore(ctx) - if err != nil { - return nil, fmt.Errorf("firebase.NewFirestoreSessionStore: %w", err) - } - - return &FirestoreSessionStore[S]{ - client: client, - collection: storeOpts.Collection, - ttl: storeOpts.TTL, - }, nil -} - -// Get retrieves session data by ID from Firestore. -// Returns nil if the session does not exist. -func (s *FirestoreSessionStore[S]) Get(ctx context.Context, sessionID string) (*session.Data[S], error) { - docRef := s.client.Collection(s.collection).Doc(sessionID) - - snapshot, err := docRef.Get(ctx) - if err != nil { - if status.Code(err) == codes.NotFound { - return nil, nil - } - return nil, fmt.Errorf("firebase.FirestoreSessionStore.Get: %w", err) - } - if !snapshot.Exists() { - return nil, nil - } - - var doc sessionDocument - if err := snapshot.DataTo(&doc); err != nil { - return nil, fmt.Errorf("firebase.FirestoreSessionStore.Get: failed to parse document: %w", err) - } - - var state S - if len(doc.State) > 0 { - if err := json.Unmarshal(doc.State, &state); err != nil { - return nil, fmt.Errorf("firebase.FirestoreSessionStore.Get: failed to unmarshal state: %w", err) - } - } - - return &session.Data[S]{ - ID: sessionID, - State: state, - }, nil -} - -// Save persists session data to Firestore, creating or updating as needed. -// CreatedAt is only set when the document is first created; subsequent saves -// only update UpdatedAt and ExpiresAt. -func (s *FirestoreSessionStore[S]) Save(ctx context.Context, sessionID string, data *session.Data[S]) error { - docRef := s.client.Collection(s.collection).Doc(sessionID) - - stateJSON, err := json.Marshal(data.State) - if err != nil { - return fmt.Errorf("firebase.FirestoreSessionStore.Save: failed to marshal state: %w", err) - } - - now := time.Now() - expiresAt := now.Add(s.ttl) - - err = s.client.RunTransaction(ctx, func(ctx context.Context, tx *firestore.Transaction) error { - snapshot, err := tx.Get(docRef) - if err != nil && status.Code(err) != codes.NotFound { - return err - } - - if !snapshot.Exists() { - // Document doesn't exist - create it with CreatedAt - return tx.Create(docRef, sessionDocument{ - State: stateJSON, - CreatedAt: now, - UpdatedAt: now, - ExpiresAt: &expiresAt, - }) - } - - // Document exists - update without modifying CreatedAt - return tx.Update(docRef, []firestore.Update{ - {Path: "state", Value: stateJSON}, - {Path: "updatedAt", Value: now}, - {Path: "expiresAt", Value: &expiresAt}, - }) - }) - if err != nil { - return fmt.Errorf("firebase.FirestoreSessionStore.Save: %w", err) - } - - return nil -} diff --git a/go/plugins/firebase/x/session_store_test.go b/go/plugins/firebase/x/session_store_test.go deleted file mode 100644 index 7b7c998e45..0000000000 --- a/go/plugins/firebase/x/session_store_test.go +++ /dev/null @@ -1,439 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 - -package x - -import ( - "context" - "flag" - "testing" - "time" - - "cloud.google.com/go/firestore" - "github.com/firebase/genkit/go/core/x/session" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/firebase" - "google.golang.org/api/iterator" -) - -var ( - testSessionProjectID = flag.String("test-session-project-id", "", "GCP Project ID to use for session store tests") - testSessionCollection = flag.String("test-session-collection", "genkit-sessions", "Firestore collection to use for session store tests") -) - -/* - * Pre-requisites to run this test: - * - * 1. **Option A - Use Firestore Emulator (Recommended for local development):** - * Start the Firestore emulator: - * ```bash - * export FIRESTORE_EMULATOR_HOST=127.0.0.1:8080 - * gcloud emulators firestore start --host-port=127.0.0.1:8080 - * ``` - * - * 2. **Option B - Use a Real Firestore Database:** - * - Set up a Firebase project with Firestore enabled - * - Authenticate using: - * ```bash - * gcloud auth application-default login - * ``` - * - * 3. **Running the Test:** - * ```bash - * go test -test-session-project-id= -test-session-collection=genkit-sessions - * ``` - */ - -// TestState is a test state type with various field types. -type TestState struct { - Name string `json:"name"` - Count int `json:"count"` - Preferences map[string]string `json:"preferences,omitempty"` -} - -func skipIfNoFirestoreSession(t *testing.T) { - if *testSessionProjectID == "" { - t.Skip("Skipping test: -test-session-project-id flag not provided") - } -} - -func setupTestSessionStore(t *testing.T) (*FirestoreSessionStore[TestState], *firestore.Client, func()) { - skipIfNoFirestoreSession(t) - - ctx := context.Background() - g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: *testSessionProjectID})) - - f := genkit.LookupPlugin(g, "firebase").(*firebase.Firebase) - client, err := f.Firestore(ctx) - if err != nil { - t.Fatalf("Failed to get Firestore client: %v", err) - } - - store, err := NewFirestoreSessionStore[TestState](ctx, g, - WithCollection(*testSessionCollection), - ) - if err != nil { - t.Fatalf("Failed to create session store: %v", err) - } - - cleanup := func() { - deleteSessionCollection(ctx, client, *testSessionCollection, t) - } - - return store, client, cleanup -} - -func deleteSessionCollection(ctx context.Context, client *firestore.Client, collectionName string, t *testing.T) { - iter := client.Collection(collectionName).Documents(ctx) - for { - doc, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - t.Logf("Failed to iterate documents for deletion: %v", err) - return - } - _, err = doc.Ref.Delete(ctx) - if err != nil { - t.Logf("Failed to delete document %s: %v", doc.Ref.ID, err) - } - } -} - -func TestNewFirestoreSessionStore_MissingCollection(t *testing.T) { - skipIfNoFirestoreSession(t) - - ctx := context.Background() - g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: *testSessionProjectID})) - - _, err := NewFirestoreSessionStore[TestState](ctx, g) - if err == nil { - t.Fatal("Expected error when collection is missing") - } -} - -func TestFirestoreSessionStore_SaveAndGet(t *testing.T) { - store, _, cleanup := setupTestSessionStore(t) - defer cleanup() - - ctx := context.Background() - sessionID := "test-session-save-get" - - // Initially empty - data, err := store.Get(ctx, sessionID) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if data != nil { - t.Errorf("Expected nil for non-existent session, got %v", data) - } - - // Save data - original := &session.Data[TestState]{ - ID: sessionID, - State: TestState{ - Name: "Alice", - Count: 42, - Preferences: map[string]string{"theme": "dark"}, - }, - } - if err := store.Save(ctx, sessionID, original); err != nil { - t.Fatalf("Save failed: %v", err) - } - - // Retrieve data - retrieved, err := store.Get(ctx, sessionID) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if retrieved == nil { - t.Fatal("Expected data, got nil") - } - if retrieved.ID != sessionID { - t.Errorf("Expected ID %q, got %q", sessionID, retrieved.ID) - } - if retrieved.State.Name != original.State.Name { - t.Errorf("Expected Name %q, got %q", original.State.Name, retrieved.State.Name) - } - if retrieved.State.Count != original.State.Count { - t.Errorf("Expected Count %d, got %d", original.State.Count, retrieved.State.Count) - } - if retrieved.State.Preferences["theme"] != "dark" { - t.Errorf("Expected theme %q, got %q", "dark", retrieved.State.Preferences["theme"]) - } -} - -func TestFirestoreSessionStore_Overwrite(t *testing.T) { - store, client, cleanup := setupTestSessionStore(t) - defer cleanup() - - ctx := context.Background() - sessionID := "test-session-overwrite" - - // Save initial data - initial := &session.Data[TestState]{ - ID: sessionID, - State: TestState{Name: "Alice", Count: 1}, - } - if err := store.Save(ctx, sessionID, initial); err != nil { - t.Fatalf("Save failed: %v", err) - } - - // Get the initial document to capture CreatedAt and UpdatedAt - snapshot1, err := client.Collection(*testSessionCollection).Doc(sessionID).Get(ctx) - if err != nil { - t.Fatalf("Failed to get initial document: %v", err) - } - initialData := snapshot1.Data() - initialCreatedAt, ok := initialData["createdAt"].(time.Time) - if !ok { - t.Fatal("Expected createdAt to be a timestamp") - } - initialUpdatedAt, ok := initialData["updatedAt"].(time.Time) - if !ok { - t.Fatal("Expected updatedAt to be a timestamp") - } - - // Wait a moment to ensure timestamp difference is detectable - time.Sleep(10 * time.Millisecond) - - // Overwrite with new data - updated := &session.Data[TestState]{ - ID: sessionID, - State: TestState{Name: "Alice Updated", Count: 2}, - } - if err := store.Save(ctx, sessionID, updated); err != nil { - t.Fatalf("Save failed: %v", err) - } - - // Get the updated document to verify timestamps - snapshot2, err := client.Collection(*testSessionCollection).Doc(sessionID).Get(ctx) - if err != nil { - t.Fatalf("Failed to get updated document: %v", err) - } - updatedData := snapshot2.Data() - updatedCreatedAt, ok := updatedData["createdAt"].(time.Time) - if !ok { - t.Fatal("Expected createdAt to be a timestamp after update") - } - updatedUpdatedAt, ok := updatedData["updatedAt"].(time.Time) - if !ok { - t.Fatal("Expected updatedAt to be a timestamp after update") - } - - // Verify CreatedAt is preserved (not modified during overwrite) - if !updatedCreatedAt.Equal(initialCreatedAt) { - t.Errorf("CreatedAt was modified during overwrite: initial=%v, after=%v", initialCreatedAt, updatedCreatedAt) - } - - // Verify UpdatedAt is modified (should be later than initial) - if !updatedUpdatedAt.After(initialUpdatedAt) { - t.Errorf("UpdatedAt should be later after overwrite: initial=%v, after=%v", initialUpdatedAt, updatedUpdatedAt) - } - - // Retrieve and verify state data - retrieved, err := store.Get(ctx, sessionID) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if retrieved.State.Name != "Alice Updated" { - t.Errorf("Expected Name %q, got %q", "Alice Updated", retrieved.State.Name) - } - if retrieved.State.Count != 2 { - t.Errorf("Expected Count %d, got %d", 2, retrieved.State.Count) - } -} - -func TestFirestoreSessionStore_ExpiresAt(t *testing.T) { - store, client, cleanup := setupTestSessionStore(t) - defer cleanup() - - ctx := context.Background() - sessionID := "test-session-expires" - - data := &session.Data[TestState]{ - ID: sessionID, - State: TestState{Name: "ExpiresTest"}, - } - if err := store.Save(ctx, sessionID, data); err != nil { - t.Fatalf("Save failed: %v", err) - } - - // Verify expiresAt is set in Firestore - snapshot, err := client.Collection(*testSessionCollection).Doc(sessionID).Get(ctx) - if err != nil { - t.Fatalf("Failed to get document: %v", err) - } - - docData := snapshot.Data() - if docData["expiresAt"] == nil { - t.Error("Expected expiresAt to be set") - } -} - -func TestFirestoreSessionStore_WithTTL(t *testing.T) { - skipIfNoFirestoreSession(t) - - ctx := context.Background() - g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: *testSessionProjectID})) - - f := genkit.LookupPlugin(g, "firebase").(*firebase.Firebase) - client, err := f.Firestore(ctx) - if err != nil { - t.Fatalf("Failed to get Firestore client: %v", err) - } - defer deleteSessionCollection(ctx, client, *testSessionCollection, t) - - customTTL := 1 * time.Hour - store, err := NewFirestoreSessionStore[TestState](ctx, g, - WithCollection(*testSessionCollection), - WithTTL(customTTL), - ) - if err != nil { - t.Fatalf("Failed to create session store: %v", err) - } - - if store.ttl != customTTL { - t.Errorf("Expected TTL %v, got %v", customTTL, store.ttl) - } -} - -func TestFirestoreSessionStore_EmptyState(t *testing.T) { - store, _, cleanup := setupTestSessionStore(t) - defer cleanup() - - ctx := context.Background() - sessionID := "test-session-empty" - - // Save session with zero-value state - data := &session.Data[TestState]{ - ID: sessionID, - State: TestState{}, - } - if err := store.Save(ctx, sessionID, data); err != nil { - t.Fatalf("Save failed: %v", err) - } - - // Retrieve and verify - retrieved, err := store.Get(ctx, sessionID) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if retrieved == nil { - t.Fatal("Expected data, got nil") - } - if retrieved.State.Name != "" { - t.Errorf("Expected empty Name, got %q", retrieved.State.Name) - } - if retrieved.State.Count != 0 { - t.Errorf("Expected zero Count, got %d", retrieved.State.Count) - } -} - -func TestFirestoreSessionStore_ComplexState(t *testing.T) { - skipIfNoFirestoreSession(t) - - ctx := context.Background() - g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: *testSessionProjectID})) - - f := genkit.LookupPlugin(g, "firebase").(*firebase.Firebase) - client, err := f.Firestore(ctx) - if err != nil { - t.Fatalf("Failed to get Firestore client: %v", err) - } - defer deleteSessionCollection(ctx, client, *testSessionCollection, t) - - type NestedState struct { - Inner struct { - Value string `json:"value"` - } `json:"inner"` - List []int `json:"list"` - } - - store, err := NewFirestoreSessionStore[NestedState](ctx, g, - WithCollection(*testSessionCollection), - ) - if err != nil { - t.Fatalf("Failed to create session store: %v", err) - } - - sessionID := "test-session-complex" - - // Save complex state - state := NestedState{ - List: []int{1, 2, 3, 4, 5}, - } - state.Inner.Value = "nested value" - - data := &session.Data[NestedState]{ - ID: sessionID, - State: state, - } - if err := store.Save(ctx, sessionID, data); err != nil { - t.Fatalf("Save failed: %v", err) - } - - // Retrieve and verify - retrieved, err := store.Get(ctx, sessionID) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if retrieved == nil { - t.Fatal("Expected data, got nil") - } - if retrieved.State.Inner.Value != "nested value" { - t.Errorf("Expected Inner.Value %q, got %q", "nested value", retrieved.State.Inner.Value) - } - if len(retrieved.State.List) != 5 { - t.Errorf("Expected List length %d, got %d", 5, len(retrieved.State.List)) - } -} - -func TestFirestoreSessionStore_IntegrationWithSession(t *testing.T) { - store, _, cleanup := setupTestSessionStore(t) - defer cleanup() - - ctx := context.Background() - - // Create a session with the Firestore store - sess, err := session.New(ctx, - session.WithID[TestState]("integration-test"), - session.WithInitialState(TestState{Name: "Integration", Count: 0}), - session.WithStore(store), - ) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - // Update state (should persist to Firestore) - if err := sess.UpdateState(ctx, TestState{Name: "Updated", Count: 10}); err != nil { - t.Fatalf("UpdateState failed: %v", err) - } - - // Load session from store - loaded, err := session.Load(ctx, store, "integration-test") - if err != nil { - t.Fatalf("Load failed: %v", err) - } - - if loaded.State().Name != "Updated" { - t.Errorf("Expected Name %q, got %q", "Updated", loaded.State().Name) - } - if loaded.State().Count != 10 { - t.Errorf("Expected Count %d, got %d", 10, loaded.State().Count) - } -} diff --git a/go/samples/session/main.go b/go/samples/session/main.go deleted file mode 100644 index c1f6d7b0b9..0000000000 --- a/go/samples/session/main.go +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This sample demonstrates how to use sessions to maintain state across -// multiple requests. It implements a shopping cart where items persist -// between calls using the session API. -// -// To run: -// -// go run . -// -// In another terminal, test (items persist across requests): -// -// curl -X POST http://localhost:8080/manageCart \ -// -H "Content-Type: application/json" \ -// -d '{"data": "Add apples and bananas to my cart"}' -// -// curl -X POST http://localhost:8080/manageCart \ -// -H "Content-Type: application/json" \ -// -d '{"data": "What is in my cart?"}' -package main - -import ( - "context" - "fmt" - "log" - "net/http" - - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/core/x/session" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/googlegenai" - "github.com/firebase/genkit/go/plugins/server" - "google.golang.org/genai" -) - -// CartState holds the shopping cart items. -type CartState struct { - Items []string `json:"items"` -} - -func main() { - ctx := context.Background() - g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - - // Create in-memory store (shared across requests). - store := session.NewInMemoryStore[CartState]() - - // Fixed session ID for simplicity. - const sessionID = "shopping-session" - - // Define addToCart tool - adds an item to the cart stored in session state. - addToCartTool := genkit.DefineTool(g, "addToCart", - "Adds items to the shopping cart", - func(ctx *ai.ToolContext, input struct{ Items []string }) ([]string, error) { - sess := session.FromContext[CartState](ctx.Context) - if sess == nil { - return nil, fmt.Errorf("no session in context") - } - state := sess.State() - state.Items = append(state.Items, input.Items...) - if err := sess.UpdateState(ctx.Context, state); err != nil { - return nil, err - } - return state.Items, nil - }, - ) - - // Define getCart tool - returns all items currently in the cart. - getCartTool := genkit.DefineTool(g, "getCart", - "Returns all items currently in the shopping cart", - func(ctx *ai.ToolContext, input struct{}) ([]string, error) { - sess := session.FromContext[CartState](ctx.Context) - if sess == nil { - return nil, fmt.Errorf("no session in context") - } - return sess.State().Items, nil - }, - ) - - // Define flow that uses session to maintain cart state across requests. - genkit.DefineFlow(g, "manageCart", func(ctx context.Context, input string) (string, error) { - // Load existing session or create new one. - sess, err := session.Load(ctx, store, sessionID) - if err != nil { - // Session doesn't exist, create it. - sess, err = session.New(ctx, - session.WithID[CartState](sessionID), - session.WithStore(store), - session.WithInitialState(CartState{Items: []string{}}), - ) - if err != nil { - return "", err - } - } - - // Attach session to context for tools. - ctx = session.NewContext(ctx, sess) - - return genkit.GenerateText(ctx, g, - ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ - ThinkingConfig: &genai.ThinkingConfig{ - ThinkingBudget: genai.Ptr[int32](0), - }, - })), - ai.WithSystem("You are a helpful shopping assistant. Use the provided tools to manage the user's cart."), - ai.WithTools(addToCartTool, getCartTool), - ai.WithPrompt(input), - ) - }) - - // Start server. - mux := http.NewServeMux() - for _, a := range genkit.ListFlows(g) { - mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) - } - log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) -} From bda6531ed6c9380da3078659b4dab69bd41cc639 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 18 Jun 2026 21:10:45 -0700 Subject: [PATCH 120/141] feat(go/exp): detect orphaned detached turns via snapshot heartbeat Ports the JS heartbeat/expired feature (genkit-ai/genkit 411a8ca) and aligns the session store with the JS caller-managed model. Heartbeat / expired: - A detached (background) turn stamps an initial heartbeat on its pending snapshot and refreshes it on an interval. A read surfaces a pending snapshot whose heartbeat has gone stale as the new `expired` status (its background worker is presumed dead); the raw row stays pending, so the status is computed on read and never persisted. - Adds `SessionSnapshot.heartbeatAt` and `SnapshotStatusExpired` to the shared schema, regenerated for Go and Python. Store model (caller-managed timestamps, fewer specialized methods): - SaveSnapshot persists caller-set createdAt/updatedAt verbatim and no longer stamps them; the store still owns snapshotId/sessionId and the empty-status default. This lets a heartbeat carry the existing row through unchanged but for heartbeatAt, so it never bumps updatedAt. - GetLatestSnapshot resolves by createdAt (the file store reads the field rather than file mtime), so a later rewrite of an older row does not move it ahead of a newer-created sibling. - Abort and heartbeat are now ordinary SaveSnapshot mutators (abortPendingSnapshot, beatHeartbeat); the store interface keeps only the optional SnapshotSubscriber (status observation), renamed from SnapshotAborter. AbortSnapshot and RefreshHeartbeat are removed. --- genkit-tools/common/src/types/agent.ts | 19 +- genkit-tools/genkit-schema.json | 6 +- go/ai/exp/agent.go | 178 +++++++++- go/ai/exp/agent_test.go | 315 ++++++++++++++++-- go/ai/exp/gen.go | 26 +- go/ai/exp/localstore/file.go | 132 +++----- go/ai/exp/localstore/file_test.go | 68 ++-- go/ai/exp/localstore/inmemory.go | 46 +-- go/ai/exp/localstore/inmemory_test.go | 39 ++- go/ai/exp/localstore/store_test.go | 171 +++++++++- go/ai/exp/option.go | 2 +- go/ai/exp/session.go | 93 +++--- go/ai/exp/teststore_test.go | 33 +- go/core/schemas.config | 31 +- go/samples/basic-agents/cli.go | 8 +- .../genkit/src/genkit/_core/_typing.py | 2 + 16 files changed, 874 insertions(+), 295 deletions(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 9c53079c3e..36def1d841 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -48,12 +48,17 @@ export type Artifact = z.infer; * `abortSnapshot` companion action while detached. * - `failed`: the invocation terminated with an error. The snapshot's `error` * field describes the failure and resume is rejected with that same error. + * - `expired`: a `pending` snapshot whose detached background worker is + * presumed dead because its heartbeat went stale. Computed on read from a + * stale `heartbeatAt`; never persisted (the dead worker can no longer write + * a terminal status itself). */ export const SnapshotStatusSchema = z.enum([ 'pending', 'completed', 'aborted', 'failed', + 'expired', ]); export type SnapshotStatus = z.infer; @@ -363,8 +368,20 @@ export const SessionSnapshotSchema = z.object({ parentId: z.string().optional(), /** When the snapshot was first written (RFC 3339). */ createdAt: z.string(), - /** When the snapshot was last written (RFC 3339). Equals `createdAt` until rewritten. */ + /** + * When the snapshot's state was last written (RFC 3339). Equals `createdAt` + * until rewritten; a heartbeat refresh on a pending snapshot does not advance + * it, so liveness stays distinct from state changes. + */ updatedAt: z.string().optional(), + /** + * Heartbeat timestamp (RFC 3339) refreshed periodically while a detached + * (background) turn is in flight. Used to detect a dead background worker: + * if a `pending` snapshot's heartbeat goes stale (older than the configured + * timeout), reads surface its status as `expired` (the dead worker can no + * longer persist a terminal status itself). + */ + heartbeatAt: z.string().optional(), /** Lifecycle state of this snapshot. Empty is treated as `completed`. */ status: SnapshotStatusSchema.optional(), /** diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 3f54191bbb..7a82bd85ad 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -263,6 +263,9 @@ "updatedAt": { "type": "string" }, + "heartbeatAt": { + "type": "string" + }, "status": { "$ref": "#/$defs/SnapshotStatus" }, @@ -310,7 +313,8 @@ "pending", "completed", "aborted", - "failed" + "failed", + "expired" ] }, "TurnEnd": { diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 6ff830fa67..922442b845 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -31,6 +31,7 @@ import ( "runtime/debug" "sync" "sync/atomic" + "time" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core" @@ -40,6 +41,35 @@ import ( "github.com/google/uuid" ) +// --- Heartbeat --- + +// A detached (background) turn refreshes its pending snapshot's heartbeat on an +// interval so a reader can tell a live background worker from a dead one. Each +// beat is a store write; if the beats stop (the worker died) the heartbeat goes +// stale and reads surface the pending snapshot as [SnapshotStatusExpired] +// rather than leaving it pending forever. +const ( + // defaultHeartbeatInterval is how often a detached turn refreshes its + // pending snapshot's heartbeat. + defaultHeartbeatInterval = 30 * time.Second + // defaultHeartbeatTimeout is the staleness threshold after which a pending + // snapshot whose heartbeat has not advanced is reported as expired on read. + // It is comfortably larger than defaultHeartbeatInterval so a single missed + // beat does not trip expiry. + defaultHeartbeatTimeout = 60 * time.Second +) + +// isHeartbeatExpired reports whether snap is a pending (detached, in-flight) +// snapshot whose heartbeat is older than timeout, i.e. its background worker is +// presumed dead. A pending snapshot that has not yet written a first heartbeat +// is not considered expired (the beat may simply not have fired yet). +func isHeartbeatExpired[State any](snap *SessionSnapshot[State], timeout time.Duration) bool { + if snap.Status != SnapshotStatusPending || snap.HeartbeatAt == nil { + return false + } + return time.Since(*snap.HeartbeatAt) > timeout +} + // --- SessionRunner --- // SessionRunner extends Session with agent-runtime functionality: @@ -289,6 +319,9 @@ func (s *SessionRunner[State]) snapshotTurnEnd(ctx context.Context, finishReason parentID := s.lastSnapshotID sessionID := s.SessionID() + // Timestamps are caller-managed (the store persists them verbatim); a fresh + // turn-end snapshot is created now, so CreatedAt and UpdatedAt are equal. + now := time.Now() saved, err := s.store.SaveSnapshot(ctx, "", func(_ *SessionSnapshot[State]) (*SessionSnapshot[State], error) { return &SessionSnapshot[State]{ @@ -297,6 +330,8 @@ func (s *SessionRunner[State]) snapshotTurnEnd(ctx context.Context, finishReason Status: SnapshotStatusCompleted, FinishReason: finishReason, State: &state, + CreatedAt: now, + UpdatedAt: now, }, nil }) if err != nil { @@ -420,11 +455,11 @@ func (a *Agent[State]) GetSnapshotAction() api.Action { // which asks the background work behind a pending snapshot (e.g. a // detached invocation) to stop (input [AbortSnapshotRequest], output // [AbortSnapshotResponse]). It returns nil when the agent has no -// [SessionStore] or the store does not implement [SnapshotAborter]. +// [SessionStore] or the store does not implement [SnapshotSubscriber]. // // Use it to expose aborting over a transport (e.g. mount it with -// genkit.Handler next to the agent itself); local Go code should call the -// store's [SnapshotAborter.AbortSnapshot] directly. +// genkit.Handler next to the agent itself); local Go code aborts by writing +// the aborted status through the store's [SnapshotWriter.SaveSnapshot]. func (a *Agent[State]) AbortSnapshotAction() api.Action { return a.abortSnapshot } @@ -436,7 +471,7 @@ func (a *Agent[State]) AbortSnapshotAction() api.Action { // // The store is returned as the [SessionStore] interface, not its concrete // type; a caller needing a store-specific capability (e.g. -// [SnapshotAborter]) type-asserts for it. +// [SnapshotSubscriber]) type-asserts for it. func (a *Agent[State]) Store() SessionStore[State] { return a.store } @@ -656,7 +691,9 @@ func agentMetadataFor[State any](store SessionStore[State]) AgentMetadata { abortable := false if store != nil { mgmt = AgentStateManagementServer - _, abortable = store.(SnapshotAborter) + // Abortable iff the runtime can observe the abort it writes via + // SaveSnapshot, i.e. the store can subscribe to status changes. + _, abortable = store.(SnapshotSubscriber) } return AgentMetadata{ StateManagement: mgmt, @@ -901,17 +938,17 @@ func (rt *agentRuntime[State]) run( // checkDetachCapabilities reports whether the configured store is capable // of supporting detach. Detach requires a writable store (to persist the -// pending snapshot) and a [SnapshotAborter] (which bundles both abort -// triggering and status-change subscription so the runtime can react to -// the abort without polling). +// pending snapshot, and to abort it and refresh its heartbeat via ordinary +// SaveSnapshot writes) and a [SnapshotSubscriber] (so the runtime can observe +// the abort flip and promptly cancel the background work without polling). func (rt *agentRuntime[State]) checkDetachCapabilities() error { if rt.cfg.store == nil { return core.NewError(core.FAILED_PRECONDITION, "agent %q: detach requires a session store", rt.name) } - if _, ok := rt.cfg.store.(SnapshotAborter); !ok { + if _, ok := rt.cfg.store.(SnapshotSubscriber); !ok { return core.NewError(core.FAILED_PRECONDITION, - "agent %q: detach requires a session store implementing SnapshotAborter", rt.name) + "agent %q: detach requires a session store implementing SnapshotSubscriber", rt.name) } return nil } @@ -1060,12 +1097,25 @@ func (rt *agentRuntime[State]) handleDetach( // Detach intends to outlive the client connection. If clientCtx was // already cancelled (or cancels mid-write), we still want the pending // row durable so observers can find it later. Decouple this write. + // + // checkDetachCapabilities (run before detach is honored) guarantees the + // store is a SnapshotSubscriber, so the runtime can observe the abort flip. + subscriber := rt.cfg.store.(SnapshotSubscriber) + + // Stamp the pending row's timestamps and an initial heartbeat (refreshed on + // an interval below). Timestamps are caller-managed; a reader treats a + // pending snapshot whose heartbeat has gone stale as expired (its background + // worker is presumed dead). + now := time.Now() pending, err := rt.cfg.store.SaveSnapshot(context.WithoutCancel(clientCtx), "", func(_ *SessionSnapshot[State]) (*SessionSnapshot[State], error) { return &SessionSnapshot[State]{ - SessionID: sessionID, - ParentID: parentID, - Status: SnapshotStatusPending, + SessionID: sessionID, + ParentID: parentID, + Status: SnapshotStatusPending, + CreatedAt: now, + UpdatedAt: now, + HeartbeatAt: &now, }, nil }) if err != nil { @@ -1078,14 +1128,20 @@ func (rt *agentRuntime[State]) handleDetach( // trashes any further chunks. rt.router.stopAndWait() + // Refresh the heartbeat on an interval, decoupled from clientCtx (the work + // outlives the client connection); stopped when the turn settles or an abort + // lands, both below. + hbCtx, stopHeartbeat := context.WithCancel(context.WithoutCancel(clientCtx)) + go rt.runHeartbeat(hbCtx, pending.SnapshotID) + abortedByUser := &atomic.Bool{} subCtx, stopSub := context.WithCancel(workCtx) - aborter := rt.cfg.store.(SnapshotAborter) // safe: checkDetachCapabilities ran already - statusCh := aborter.OnSnapshotStatusChange(subCtx, pending.SnapshotID) + statusCh := subscriber.OnSnapshotStatusChange(subCtx, pending.SnapshotID) go func() { for status := range statusCh { if status == SnapshotStatusAborted { abortedByUser.Store(true) + stopHeartbeat() cancelWork() return } @@ -1096,6 +1152,10 @@ func (rt *agentRuntime[State]) handleDetach( go func() { res := <-rt.fnDone stopSub() + // The turn has settled; stop refreshing the heartbeat before the + // finalize write so no beat races it. (A stray beat would be a no-op + // anyway: the mutator only touches a still-pending row.) + stopHeartbeat() rt.intake.stopAndWait() rt.router.close() rt.finalizePendingSnapshot(finalizeCtx, pending, res.result, res.err, abortedByUser.Load()) @@ -1112,6 +1172,79 @@ func (rt *agentRuntime[State]) handleDetach( }, nil } +// runHeartbeat refreshes the detached pending snapshot's heartbeat every +// defaultHeartbeatInterval until ctx is cancelled (the turn settled or an abort +// landed). A transient store error is logged and the loop continues; a +// persistently failing worker simply stops beating, which is exactly the +// staleness a reader detects as expired. +func (rt *agentRuntime[State]) runHeartbeat(ctx context.Context, snapshotID string) { + ticker := time.NewTicker(defaultHeartbeatInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := beatHeartbeat(ctx, rt.cfg.store, snapshotID); err != nil { + logger.FromContext(ctx).Debug("agent: heartbeat refresh failed", + "snapshotId", snapshotID, "err", err) + } + } + } +} + +// beatHeartbeat refreshes a pending snapshot's HeartbeatAt via an ordinary +// SaveSnapshot: the mutator carries the existing row through unchanged but for +// HeartbeatAt, so the caller-managed CreatedAt/UpdatedAt are preserved and a +// beat does not register as a state change - no dedicated store method needed. +// It only touches a still-pending row (returning nil otherwise), so a beat +// never resurrects a terminal snapshot or clobbers a concurrent abort/finalize. +// Shared by runHeartbeat and exercised directly in tests. +func beatHeartbeat[State any](ctx context.Context, store SnapshotWriter[State], snapshotID string) error { + now := time.Now() + _, err := store.SaveSnapshot(ctx, snapshotID, + func(existing *SessionSnapshot[State]) (*SessionSnapshot[State], error) { + if existing == nil || existing.Status != SnapshotStatusPending { + return nil, nil + } + updated := *existing + updated.HeartbeatAt = &now + return &updated, nil + }) + return err +} + +// abortSnapshot flips a pending snapshot to aborted via an ordinary +// SaveSnapshot and returns the resulting status: aborted when the row was +// pending, the existing terminal status when it was already settled (a no-op +// verbatim rewrite), or "" when the snapshot does not exist. SaveSnapshot's +// atomic read-mutate-write makes the flip safe against a racing terminal write, +// and the status change drives any [SnapshotSubscriber.OnSnapshotStatusChange] +// subscription, so the store needs no dedicated abort method. +func abortPendingSnapshot[State any](ctx context.Context, store SnapshotWriter[State], snapshotID string) (SnapshotStatus, error) { + now := time.Now() + saved, err := store.SaveSnapshot(ctx, snapshotID, + func(existing *SessionSnapshot[State]) (*SessionSnapshot[State], error) { + if existing == nil { + return nil, nil // not found + } + if existing.Status != SnapshotStatusPending { + return existing, nil // already terminal: re-persist so the return carries its status + } + updated := *existing + updated.Status = SnapshotStatusAborted + updated.UpdatedAt = now + return &updated, nil + }) + if err != nil { + return "", err + } + if saved == nil { + return "", nil + } + return saved.Status, nil +} + // finalizePendingSnapshot rewrites the pending snapshot row with the // terminal state and status. abortedByUser distinguishes a context // cancellation from abortSnapshot (status=aborted) from an internal @@ -1128,16 +1261,17 @@ func (rt *agentRuntime[State]) finalizePendingSnapshot( ) { finalState := *rt.session.State() // Captured outside the SaveSnapshot callback (which must stay pure): the - // finalizer runs after fn returned, so this is stable. The abort/error + // finalizer runs after fn returned, so these are stable. The abort/error // branches below own their reasons and ignore this clean-success default. completedReason := rt.sess.invocationReason(result) + now := time.Now() _, err := rt.cfg.store.SaveSnapshot(ctx, pending.SnapshotID, func(existing *SessionSnapshot[State]) (*SessionSnapshot[State], error) { // Late abort wins over the terminal we were about to land: keep // the aborted status and whatever state the abort left, but // stamp the aborted finish reason so the snapshot is - // self-describing. (AbortSnapshot only flips status; the runtime + // self-describing. (The abort write only flips status; the runtime // owns the semantic reason.) Skip the write once already stamped. if existing != nil && existing.Status == SnapshotStatusAborted { if existing.FinishReason == AgentFinishReasonAborted { @@ -1145,6 +1279,11 @@ func (rt *agentRuntime[State]) finalizePendingSnapshot( } annotated := *existing annotated.FinishReason = AgentFinishReasonAborted + annotated.UpdatedAt = now + // The row is terminal now; drop the liveness heartbeat so it + // does not linger on a settled snapshot. CreatedAt is preserved + // from the copy, so recency ordering is unaffected. + annotated.HeartbeatAt = nil return &annotated, nil } @@ -1167,6 +1306,9 @@ func (rt *agentRuntime[State]) finalizePendingSnapshot( snapErr = core.AsGenkitError(fnErr) } + // Preserve the pending row's CreatedAt (so the finalize does not + // move it ahead of newer rows in createdAt-ordered resolution) and + // advance UpdatedAt: this rewrite is a real state change. return &SessionSnapshot[State]{ SessionID: pending.SessionID, ParentID: pending.ParentID, @@ -1174,6 +1316,8 @@ func (rt *agentRuntime[State]) finalizePendingSnapshot( FinishReason: finishReason, Error: snapErr, State: &finalState, + CreatedAt: pending.CreatedAt, + UpdatedAt: now, }, nil }) if err != nil { diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index a0060948ce..b5d8108b5a 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -2672,6 +2672,257 @@ func TestAgent_Detach_PendingThenComplete(t *testing.T) { } } +// getSnapshotViaAction invokes the agent's getSnapshot companion action the +// way a remote client (or the Dev UI) would. Unlike a direct store read it +// resolves compute-on-read status such as expired. +func getSnapshotViaAction[State any](t *testing.T, reg *registry.Registry, agentName, snapshotID string) *SessionSnapshot[State] { + t.Helper() + action := core.ResolveActionFor[*GetSnapshotRequest, *SessionSnapshot[State], struct{}]( + reg, api.ActionTypeAgentSnapshot, agentName) + if action == nil { + t.Fatalf("getSnapshot action not registered for %q", agentName) + } + resp, err := action.Run(context.Background(), &GetSnapshotRequest{SnapshotID: snapshotID}, nil) + if err != nil { + t.Fatalf("getSnapshot action: %v", err) + } + return resp +} + +// savePendingWithHeartbeat writes a pending snapshot carrying the given +// heartbeat (nil for none) directly into the store, returning its ID. It +// stands in for an orphaned detached invocation whose worker died without +// finalizing the row. +func savePendingWithHeartbeat(t *testing.T, store *testInMemStore[testState], sessionID string, heartbeat *time.Time) string { + t.Helper() + now := time.Now() + saved, err := store.SaveSnapshot(context.Background(), "", + func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{ + Status: SnapshotStatusPending, + HeartbeatAt: heartbeat, + CreatedAt: now, + UpdatedAt: now, + State: &SessionState[testState]{SessionID: sessionID, Custom: testState{Counter: 7}}, + }, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot pending: %v", err) + } + return saved.SnapshotID +} + +// defineNoopHeartbeatAgent registers a trivial server-managed agent so its +// getSnapshot companion action exists; the heartbeat compute-on-read tests +// drive that action against snapshots they write into the store directly. +func defineNoopHeartbeatAgent(t *testing.T, reg *registry.Registry, name string, store *testInMemStore[testState]) { + t.Helper() + DefineCustomAgent(reg, name, + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + return nil, nil + }) + }, + WithSessionStore(store), + ) +} + +func TestAgent_Detach_StampsHeartbeatOnPendingSnapshot(t *testing.T) { + // A detached invocation's pending snapshot carries an initial heartbeat, + // which the background refresh loop later advances and a reader uses to + // tell a live worker from a dead one. + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + release := make(chan struct{}) + entered := make(chan struct{}) + + af := DefineCustomAgent(reg, "heartbeatStamp", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + select { + case entered <- struct{}{}: + case <-ctx.Done(): + } + <-release + sess.AddMessages(ai.NewModelTextMessage("done")) + return nil, nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(context.Background()) + if err != nil { + t.Fatalf("StreamBidi: %v", err) + } + drainInBackground(conn) + + sendText(t, conn, "go") + if err := conn.Detach(); err != nil { + t.Fatalf("Detach: %v", err) + } + + select { + case <-entered: + case <-time.After(2 * time.Second): + t.Fatal("flow did not enter work phase") + } + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + if out.SnapshotID == "" { + t.Fatal("expected snapshot ID after detach") + } + + pending, err := store.GetSnapshot(context.Background(), out.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot pending: %v", err) + } + if pending == nil { + t.Fatal("pending snapshot not written") + } + if pending.Status != SnapshotStatusPending { + t.Errorf("status = %q, want %q", pending.Status, SnapshotStatusPending) + } + if pending.HeartbeatAt == nil { + t.Error("pending detached snapshot should carry a heartbeatAt") + } + + // Release; the finalizer stops the heartbeat and rewrites the row. + close(release) + waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { + return s.Status == SnapshotStatusCompleted + }) +} + +func TestAgent_GetSnapshotAction_StaleHeartbeatReportsExpired(t *testing.T) { + // An orphaned pending snapshot (heartbeat older than the expiry timeout) is + // surfaced as expired on read, while the raw store row stays pending: + // compute-on-read never writes the expired status back. + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + defineNoopHeartbeatAgent(t, reg, "heartbeatExpired", store) + + stale := time.Now().Add(-2 * defaultHeartbeatTimeout) + id := savePendingWithHeartbeat(t, store, "sess-expired", &stale) + + raw, err := store.GetSnapshot(context.Background(), id) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if raw.Status != SnapshotStatusPending { + t.Errorf("raw store status = %q, want %q (compute-on-read must not write back)", raw.Status, SnapshotStatusPending) + } + + viaAction := getSnapshotViaAction[testState](t, reg, "heartbeatExpired", id) + if viaAction.Status != SnapshotStatusExpired { + t.Errorf("getSnapshot status = %q, want %q", viaAction.Status, SnapshotStatusExpired) + } +} + +func TestAgent_GetSnapshotAction_FreshHeartbeatStaysPending(t *testing.T) { + // A pending snapshot whose heartbeat is fresh is reported as pending: its + // worker is presumed alive. + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + defineNoopHeartbeatAgent(t, reg, "heartbeatFresh", store) + + fresh := time.Now() + id := savePendingWithHeartbeat(t, store, "sess-fresh", &fresh) + + viaAction := getSnapshotViaAction[testState](t, reg, "heartbeatFresh", id) + if viaAction.Status != SnapshotStatusPending { + t.Errorf("getSnapshot status = %q, want %q", viaAction.Status, SnapshotStatusPending) + } +} + +func TestAgent_GetSnapshotAction_NoHeartbeatStaysPending(t *testing.T) { + // A pending snapshot that has not written a first heartbeat yet is not + // expired: the beat may simply not have fired. + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + defineNoopHeartbeatAgent(t, reg, "heartbeatNone", store) + + id := savePendingWithHeartbeat(t, store, "sess-noheartbeat", nil) + + viaAction := getSnapshotViaAction[testState](t, reg, "heartbeatNone", id) + if viaAction.Status != SnapshotStatusPending { + t.Errorf("getSnapshot status = %q, want %q", viaAction.Status, SnapshotStatusPending) + } +} + +func TestAgent_Heartbeat_BeatAdvancesHeartbeatNotUpdatedAt(t *testing.T) { + // A beat advances the liveness timestamp but must NOT advance UpdatedAt: a + // heartbeat is not a state change, which is the whole reason HeartbeatAt is + // a field of its own rather than reusing UpdatedAt. + store := newTestInMemStore[testState]() + old := time.Now().Add(-time.Hour) + id := savePendingWithHeartbeat(t, store, "sess", &old) + before, err := store.GetSnapshot(context.Background(), id) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + + if err := beatHeartbeat(context.Background(), store, id); err != nil { + t.Fatalf("beatHeartbeat: %v", err) + } + + after, err := store.GetSnapshot(context.Background(), id) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if after.HeartbeatAt == nil || !after.HeartbeatAt.After(*before.HeartbeatAt) { + t.Errorf("HeartbeatAt did not advance: before=%v after=%v", before.HeartbeatAt, after.HeartbeatAt) + } + if !after.UpdatedAt.Equal(before.UpdatedAt) { + t.Errorf("UpdatedAt changed on heartbeat: before=%v after=%v", before.UpdatedAt, after.UpdatedAt) + } + if after.Status != SnapshotStatusPending { + t.Errorf("status = %q, want %q", after.Status, SnapshotStatusPending) + } +} + +func TestAgent_Heartbeat_BeatIsNoopOnTerminalSnapshot(t *testing.T) { + // A stray beat against a settled snapshot must never resurrect or mutate + // it: the mutator's pending-guard returns nil, so no write happens. + store := newTestInMemStore[testState]() + saved, err := store.SaveSnapshot(context.Background(), "", + func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{ + Status: SnapshotStatusCompleted, + State: &SessionState[testState]{SessionID: "sess"}, + }, nil + }) + if err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + before, err := store.GetSnapshot(context.Background(), saved.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + + if err := beatHeartbeat(context.Background(), store, saved.SnapshotID); err != nil { + t.Fatalf("beatHeartbeat: %v", err) + } + + after, err := store.GetSnapshot(context.Background(), saved.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if after.Status != SnapshotStatusCompleted { + t.Errorf("status = %q, want %q (beat must not change it)", after.Status, SnapshotStatusCompleted) + } + if after.HeartbeatAt != nil { + t.Errorf("beat stamped a heartbeat on a terminal snapshot: %v", after.HeartbeatAt) + } + if !after.UpdatedAt.Equal(before.UpdatedAt) { + t.Errorf("beat bumped UpdatedAt on a terminal snapshot: before=%v after=%v", before.UpdatedAt, after.UpdatedAt) + } +} + func TestAgent_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { // SendArtifact must behave the same way regardless of whether detach // has landed: the artifact is added to the session and shows up in @@ -2851,7 +3102,7 @@ func TestAgent_Detach_AbortSnapshotStopsFlow(t *testing.T) { // Abort via the store. The local caller already has the store // reference from WithSessionStore. - status, err := store.AbortSnapshot(context.Background(), out.SnapshotID) + status, err := abortPendingSnapshot(context.Background(), store, out.SnapshotID) if err != nil { t.Fatalf("AbortSnapshot: %v", err) } @@ -3194,12 +3445,15 @@ func TestAgent_GetSnapshotAction_BySessionID(t *testing.T) { // The session-ID lookup returns the latest row whatever its status, so a // reconnecting client can observe a failed/pending tip (unlike resume, // which rejects it). + failedAt := time.Now() failed, err := store.SaveSnapshot(ctx, "", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ SessionID: out1.SessionID, ParentID: out2.SnapshotID, Status: SnapshotStatusFailed, FinishReason: AgentFinishReasonFailed, + CreatedAt: failedAt, + UpdatedAt: failedAt, }, nil }) if err != nil { @@ -3423,7 +3677,7 @@ func (s wrongSessionStore[State]) GetLatestSnapshot(ctx context.Context, session return s.SessionStore.GetSnapshot(ctx, s.snapshotID) } -// minimalStore is a SessionStore that does NOT implement SnapshotAborter. +// minimalStore is a SessionStore that does NOT implement SnapshotSubscriber. // Used to verify the abort action stays unregistered for stores that // lack the capability. type minimalStore[State any] struct{} @@ -3570,11 +3824,11 @@ func TestAgent_AgentMetadata_StateSchema(t *testing.T) { func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { // Verify the abort companion action is only registered when the - // store implements SnapshotAborter. The getSnapshot action is + // store implements SnapshotSubscriber. The getSnapshot action is // registered regardless. t.Run("aborter capability → both registered", func(t *testing.T) { reg := newTestRegistry(t) - store := newTestInMemStore[testState]() // implements SnapshotAborter + store := newTestInMemStore[testState]() // implements SnapshotSubscriber DefineCustomAgent(reg, "fullCaps", func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { return nil, nil @@ -3589,7 +3843,7 @@ func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { abortAction := core.ResolveActionFor[*AbortSnapshotRequest, *AbortSnapshotResponse, struct{}]( reg, api.ActionTypeAgentAbort, "fullCaps") if abortAction == nil { - t.Error("abortSnapshot action should be registered when store implements SnapshotAborter") + t.Error("abortSnapshot action should be registered when store implements SnapshotSubscriber") } }) @@ -3604,12 +3858,12 @@ func TestAgent_AbortAction_GatedOnCapabilities(t *testing.T) { getAction := core.ResolveActionFor[*GetSnapshotRequest, *SessionSnapshot[testState], struct{}]( reg, api.ActionTypeAgentSnapshot, "minCaps") if getAction == nil { - t.Error("getSnapshot action should be registered even when store lacks SnapshotAborter") + t.Error("getSnapshot action should be registered even when store lacks SnapshotSubscriber") } abortAction := core.ResolveActionFor[*AbortSnapshotRequest, *AbortSnapshotResponse, struct{}]( reg, api.ActionTypeAgentAbort, "minCaps") if abortAction != nil { - t.Error("abortSnapshot action should NOT be registered when store lacks SnapshotAborter") + t.Error("abortSnapshot action should NOT be registered when store lacks SnapshotSubscriber") } }) } @@ -3677,8 +3931,8 @@ func TestAgent_Store(t *testing.T) { } // The returned store is usable directly, and store-specific // capabilities are reachable by type assertion. - if _, ok := af.Store().(SnapshotAborter); !ok { - t.Error("expected the configured store to satisfy SnapshotAborter") + if _, ok := af.Store().(SnapshotSubscriber); !ok { + t.Error("expected the configured store to satisfy SnapshotSubscriber") } }) @@ -3926,29 +4180,34 @@ func TestAgent_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { } } -func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { +func TestAbortPendingSnapshot_AtomicAndIdempotent(t *testing.T) { + // Aborting is an ordinary SaveSnapshot flip (abortPendingSnapshot); this + // exercises its CAS semantics against the store's atomic read-mutate-write. ctx := context.Background() store := newTestInMemStore[testState]() // Abort on missing snapshot returns empty status, no error. - if status, err := store.AbortSnapshot(ctx, "nope"); err != nil || status != "" { - t.Fatalf("AbortSnapshot(missing) = %q, %v; want \"\", nil", status, err) + if status, err := abortPendingSnapshot(ctx, store, "nope"); err != nil || status != "" { + t.Fatalf("abort(missing) = %q, %v; want \"\", nil", status, err) } // Pending → aborted, UpdatedAt advances (verified via GetSnapshot). + seedNow := time.Now() pending, err := store.SaveSnapshot(ctx, "snap-cas", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ - Status: SnapshotStatusPending, + Status: SnapshotStatusPending, + CreatedAt: seedNow, + UpdatedAt: seedNow, }, nil }) if err != nil { t.Fatalf("SaveSnapshot: %v", err) } time.Sleep(time.Millisecond) // ensure measurable UpdatedAt delta - status, err := store.AbortSnapshot(ctx, "snap-cas") + status, err := abortPendingSnapshot(ctx, store, "snap-cas") if err != nil { - t.Fatalf("AbortSnapshot: %v", err) + t.Fatalf("abort: %v", err) } if status != SnapshotStatusAborted { t.Errorf("status after first abort = %q, want aborted", status) @@ -3963,9 +4222,9 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { // Idempotent: second abort returns aborted, no error, no further mutation. firstUpdate := afterFirst.UpdatedAt - status2, err := store.AbortSnapshot(ctx, "snap-cas") + status2, err := abortPendingSnapshot(ctx, store, "snap-cas") if err != nil { - t.Fatalf("AbortSnapshot (second): %v", err) + t.Fatalf("abort (second): %v", err) } if status2 != SnapshotStatusAborted { t.Errorf("status after second abort = %q, want aborted", status2) @@ -3987,9 +4246,9 @@ func TestInMemorySessionStore_AbortSnapshot_AtomicAndIdempotent(t *testing.T) { }); err != nil { t.Fatalf("SaveSnapshot: %v", err) } - status3, err := store.AbortSnapshot(ctx, "snap-complete") + status3, err := abortPendingSnapshot(ctx, store, "snap-complete") if err != nil { - t.Fatalf("AbortSnapshot on complete: %v", err) + t.Fatalf("abort on complete: %v", err) } if status3 != SnapshotStatusCompleted { t.Errorf("abort on complete returned status=%q, want completed", status3) @@ -4043,7 +4302,7 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { } // Externally abort before releasing fn. - if _, err := store.AbortSnapshot(context.Background(), out.SnapshotID); err != nil { + if _, err := abortPendingSnapshot(context.Background(), store, out.SnapshotID); err != nil { t.Fatalf("AbortSnapshot: %v", err) } @@ -4098,7 +4357,7 @@ func TestInMemorySessionStore_OnSnapshotStatusChange(t *testing.T) { } // Abort flips status; subscriber observes aborted. - if _, err := store.AbortSnapshot(ctx, "snap-sub"); err != nil { + if _, err := abortPendingSnapshot(ctx, store, "snap-sub"); err != nil { t.Fatalf("AbortSnapshot: %v", err) } select { @@ -4147,7 +4406,7 @@ func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { t.Fatalf("RunText: %v", err) } - status, err := store.AbortSnapshot(ctx, out.SnapshotID) + status, err := abortPendingSnapshot(ctx, store, out.SnapshotID) if err != nil { t.Fatalf("AbortSnapshot: %v", err) } @@ -4672,7 +4931,7 @@ func TestAgent_Detach_FinishReasons(t *testing.T) { if err != nil { t.Fatalf("Output: %v", err) } - if _, err := store.AbortSnapshot(context.Background(), out.SnapshotID); err != nil { + if _, err := abortPendingSnapshot(context.Background(), store, out.SnapshotID); err != nil { t.Fatalf("AbortSnapshot: %v", err) } // AbortSnapshot flips status=aborted (finishReason still empty); the @@ -5195,13 +5454,17 @@ func TestAgent_ResumeFromSessionID_FailedTipRejected(t *testing.T) { t.Fatalf("RunText: %v", err) } // A failed detach-style row chained off the tip, as a background - // invocation that failed would leave behind. + // invocation that failed would leave behind. Created after out1, so it is + // the session's latest by CreatedAt. + failedAt := time.Now() if _, err := store.SaveSnapshot(ctx, "", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ SessionID: out1.SessionID, ParentID: out1.SnapshotID, Status: SnapshotStatusFailed, FinishReason: AgentFinishReasonFailed, + CreatedAt: failedAt, + UpdatedAt: failedAt, }, nil }); err != nil { t.Fatalf("SaveSnapshot failed row: %v", err) @@ -5353,11 +5616,15 @@ func TestAgent_ResumeFromSessionID_PendingTipRejected(t *testing.T) { if err != nil { t.Fatalf("RunText: %v", err) } + // Chained off the tip and created later, so it is the session's latest. + pendingAt := time.Now() if _, err := store.SaveSnapshot(ctx, "", func(_ *SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { return &SessionSnapshot[testState]{ SessionID: out1.SessionID, ParentID: out1.SnapshotID, Status: SnapshotStatusPending, + CreatedAt: pendingAt, + UpdatedAt: pendingAt, }, nil }); err != nil { t.Fatalf("SaveSnapshot pending row: %v", err) diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index 4953518d4c..50d991fe41 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -148,7 +148,7 @@ type ToolResume struct { // when the store doesn't support it). type AgentMetadata struct { // Abortable reports whether the agent's invocations can be aborted - // (true when the store implements [SnapshotAborter]). + // (true when the store implements [SnapshotSubscriber]). Abortable bool `json:"abortable,omitempty"` // StateManagement reports who owns session state. StateManagement AgentStateManagement `json:"stateManagement,omitempty"` @@ -330,6 +330,14 @@ type SessionSnapshot[State any] struct { // background task can report how it ended without re-deriving it from the // messages. FinishReason AgentFinishReason `json:"finishReason,omitempty"` + // HeartbeatAt is refreshed periodically while a detached (background) turn is + // in flight, so a reader can detect a dead background worker: if a pending + // snapshot's heartbeat goes stale (older than the configured timeout), reads + // surface its status as [SnapshotStatusExpired] (the dead worker can no longer + // persist a terminal status itself). Refreshing it does not advance + // [SessionSnapshot.UpdatedAt], so liveness stays distinct from state changes. Nil + // on snapshots that never detached. + HeartbeatAt *time.Time `json:"heartbeatAt,omitempty"` // ParentID is the ID of the previous snapshot in this timeline. It is // informational lineage (for debugging and UI history trees) and plays // no part in resolving a session's latest snapshot. @@ -349,9 +357,10 @@ type SessionSnapshot[State any] struct { // Status is the lifecycle state of this snapshot. Empty is treated as // [SnapshotStatusCompleted] for backwards compatibility. Status SnapshotStatus `json:"status,omitempty"` - // UpdatedAt is when the snapshot was last written. For pending snapshots - // it equals CreatedAt; once the snapshot is finalized it reflects the - // terminal write. + // UpdatedAt is when the snapshot's state was last written. A heartbeat refresh on + // a pending snapshot deliberately does not advance it, so a pending snapshot's + // UpdatedAt equals CreatedAt until the snapshot is finalized, whose terminal write + // advances it. UpdatedAt time.Time `json:"updatedAt,omitempty"` } @@ -385,6 +394,10 @@ type SessionState[State any] struct { // [SnapshotStatusFailed] when the agent finishes, or with // [SnapshotStatusAborted] if the client called abortSnapshot in the // meantime. +// +// [SnapshotStatusExpired] is never persisted: it is computed on read for a +// pending snapshot whose background worker is presumed dead (its heartbeat +// went stale), surfacing the orphan rather than leaving it pending forever. type SnapshotStatus string const ( @@ -401,6 +414,11 @@ const ( // The snapshot's Error field describes the failure and resume is // rejected with that same error. SnapshotStatusFailed SnapshotStatus = "failed" + // SnapshotStatusExpired indicates a pending snapshot whose detached background + // worker is presumed dead: its [SessionSnapshot.HeartbeatAt] went stale. It is + // computed on read (never persisted), so the raw store row stays + // [SnapshotStatusPending] while a read surfaces it as expired. + SnapshotStatusExpired SnapshotStatus = "expired" ) // TurnEnd groups the signals emitted when an agent turn finishes. diff --git a/go/ai/exp/localstore/file.go b/go/ai/exp/localstore/file.go index 88f67615b6..ebee7fd1a2 100644 --- a/go/ai/exp/localstore/file.go +++ b/go/ai/exp/localstore/file.go @@ -108,16 +108,11 @@ func (s *FileSessionStore[State]) SaveSnapshot( } next.SnapshotID = id - now := time.Now() - if existing != nil { - next.CreatedAt = existing.CreatedAt - if existing.SessionID != "" { - next.SessionID = existing.SessionID // a row's session never changes - } - } else { - next.CreatedAt = now + // SessionID is preserved (a row's session never changes); CreatedAt, + // UpdatedAt, and HeartbeatAt are caller-managed and persisted verbatim. + if existing != nil && existing.SessionID != "" { + next.SessionID = existing.SessionID } - next.UpdatedAt = now if next.Status == "" { next.Status = exp.SnapshotStatusCompleted } @@ -131,29 +126,36 @@ func (s *FileSessionStore[State]) SaveSnapshot( return next, nil } -// snapshotHeader is the subset of snapshot fields needed to match a row to -// a session during the latest-snapshot scan. Decoding only these avoids -// materializing every row's full conversation state during the scan. +// snapshotHeader is the subset of snapshot fields needed to pick a session's +// latest row during the scan. Decoding only these avoids materializing every +// row's full conversation state; only the winning row is fully decoded. type snapshotHeader struct { - SessionID string `json:"sessionId"` + SessionID string `json:"sessionId"` + CreatedAt time.Time `json:"createdAt"` } -// GetLatestSnapshot returns the session's most recently updated snapshot +// GetLatestSnapshot returns the session's most recently created snapshot // regardless of status, per the [exp.SnapshotReader.GetLatestSnapshot] // contract. // -// Recency is judged by file mtime, which for snapshots written by this package -// advances with [exp.SessionSnapshot.UpdatedAt]; if a file is touched -// externally, mtime wins. A file that fails to parse or vanishes mid-scan is -// skipped, so one corrupted row cannot hide every other session. +// Recency is judged by the [exp.SessionSnapshot.CreatedAt] field (not file +// mtime), so a later rewrite of an older row - which preserves CreatedAt - does +// not move it ahead of a newer-created sibling. Ties are broken by snapshot ID. +// A file that fails to parse or vanishes mid-scan is skipped, so one corrupted +// row cannot hide every other session. func (s *FileSessionStore[State]) GetLatestSnapshot(_ context.Context, sessionID string) (*exp.SessionSnapshot[State], error) { if sessionID == "" { return nil, errors.New("FileSessionStore: session ID is empty") } - names, err := s.snapshotFilesNewestFirst() + names, err := s.snapshotFileNames() if err != nil { return nil, err } + var ( + bestName string + bestAt time.Time + found bool + ) for _, name := range names { s.mu.Lock() data, err := os.ReadFile(filepath.Join(s.dir, name)) @@ -168,24 +170,38 @@ func (s *FileSessionStore[State]) GetLatestSnapshot(_ context.Context, sessionID if h.SessionID != sessionID { continue } - var snap exp.SessionSnapshot[State] - if err := json.Unmarshal(data, &snap); err != nil { - continue + // Most recently created wins; the file name is ".json", so + // a name compare is a deterministic SnapshotID tie-break. + if !found || h.CreatedAt.After(bestAt) || + (h.CreatedAt.Equal(bestAt) && name > bestName) { + bestName, bestAt, found = name, h.CreatedAt, true } - return &snap, nil } - return nil, nil + if !found { + return nil, nil + } + // Fully decode only the winner. CreatedAt is preserved across rewrites, so a + // concurrent rewrite of this row between scan and read still yields the + // right snapshot (with possibly fresher state). + s.mu.Lock() + data, err := os.ReadFile(filepath.Join(s.dir, bestName)) + s.mu.Unlock() + if err != nil { + return nil, nil + } + var snap exp.SessionSnapshot[State] + if err := json.Unmarshal(data, &snap); err != nil { + return nil, nil + } + return &snap, nil } -// snapshotFilesNewestFirst returns the names of the directory's snapshot -// files (non-directory *.json entries; writeLocked's ".*.tmp" temp -// files never match) sorted by modification time, newest first, with -// name as a deterministic tie-break. Entries that vanish between the -// directory read and the stat are skipped. Returns nil if the directory -// does not exist. The listing is not atomic with respect to concurrent -// writes; a snapshot that appears or disappears mid-scan may or may not -// be observed. -func (s *FileSessionStore[State]) snapshotFilesNewestFirst() ([]string, error) { +// snapshotFileNames returns the names of the directory's snapshot files +// (non-directory *.json entries; writeLocked's ".*.tmp" temp files never +// match). Returns nil if the directory does not exist. The listing is not +// atomic with respect to concurrent writes; a snapshot that appears or +// disappears mid-scan may or may not be observed. +func (s *FileSessionStore[State]) snapshotFileNames() ([]string, error) { entries, err := os.ReadDir(s.dir) if err != nil { if errors.Is(err, os.ErrNotExist) { @@ -193,62 +209,16 @@ func (s *FileSessionStore[State]) snapshotFilesNewestFirst() ([]string, error) { } return nil, fmt.Errorf("FileSessionStore: list dir: %w", err) } - type candidate struct { - name string - modTime time.Time - } - var cands []candidate + var names []string for _, e := range entries { if e.IsDir() || !strings.HasSuffix(e.Name(), ".json") { continue } - info, err := e.Info() - if err != nil { - continue - } - cands = append(cands, candidate{e.Name(), info.ModTime()}) - } - slices.SortFunc(cands, func(a, b candidate) int { - if c := b.modTime.Compare(a.modTime); c != 0 { // newest first - return c - } - return strings.Compare(b.name, a.name) - }) - names := make([]string, len(cands)) - for i, c := range cands { - names[i] = c.name + names = append(names, e.Name()) } return names, nil } -// AbortSnapshot atomically flips a pending snapshot to aborted. If the -// snapshot is already terminal the existing status is returned unchanged. -// Returns an empty status if the snapshot is not found. -func (s *FileSessionStore[State]) AbortSnapshot(_ context.Context, snapshotID string) (exp.SnapshotStatus, error) { - if err := validateSnapshotID(snapshotID); err != nil { - return "", err - } - s.mu.Lock() - defer s.mu.Unlock() - - snap, err := s.readLocked(snapshotID) - if err != nil { - return "", err - } - if snap == nil { - return "", nil - } - if snap.Status == exp.SnapshotStatusPending { - snap.Status = exp.SnapshotStatusAborted - snap.UpdatedAt = time.Now() - if err := s.writeLocked(snap); err != nil { - return "", err - } - s.notifyLocked(snapshotID, snap.Status) - } - return snap.Status, nil -} - // OnSnapshotStatusChange subscribes to status changes for a snapshot. The // returned channel yields the current status (if any) and any subsequent // changes triggered by calls on this store instance, until ctx is cancelled. diff --git a/go/ai/exp/localstore/file_test.go b/go/ai/exp/localstore/file_test.go index df5bd98789..c356f96951 100644 --- a/go/ai/exp/localstore/file_test.go +++ b/go/ai/exp/localstore/file_test.go @@ -66,14 +66,17 @@ func TestFileSessionStore(t *testing.T) { t.Run("SaveWithFixedID", func(t *testing.T) { store := newFileStore(t) + now := time.Now() saved, err := store.SaveSnapshot(context.Background(), "snap-1", func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { if existing != nil { t.Errorf("expected nil existing on first save, got %+v", existing) } return &exp.SessionSnapshot[testState]{ - Status: exp.SnapshotStatusCompleted, - State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, + Status: exp.SnapshotStatusCompleted, + State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, + CreatedAt: now, + UpdatedAt: now, }, nil }) if err != nil { @@ -82,9 +85,10 @@ func TestFileSessionStore(t *testing.T) { if saved.SnapshotID != "snap-1" { t.Errorf("saved SnapshotID = %q, want %q", saved.SnapshotID, "snap-1") } - if saved.CreatedAt.IsZero() || saved.UpdatedAt.IsZero() { - t.Errorf("expected CreatedAt/UpdatedAt stamped, got created=%v updated=%v", - saved.CreatedAt, saved.UpdatedAt) + // Timestamps are caller-managed: the store persists them verbatim. + if !saved.CreatedAt.Equal(now) || !saved.UpdatedAt.Equal(now) { + t.Errorf("expected caller-set timestamps persisted, got created=%v updated=%v want %v", + saved.CreatedAt, saved.UpdatedAt, now) } }) @@ -160,24 +164,31 @@ func TestFileSessionStore(t *testing.T) { } }) - t.Run("PreservesCreatedAtOnUpdate", func(t *testing.T) { + t.Run("PersistsCallerTimestamps", func(t *testing.T) { + // Timestamps are caller-managed: the store round-trips them verbatim + // (it does not stamp). The caller preserves CreatedAt and advances + // UpdatedAt across a rewrite. store := newFileStore(t) + created := time.Now() saved, err := store.SaveSnapshot(context.Background(), "snap-1", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { - return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusCompleted}, nil + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusCompleted, CreatedAt: created, UpdatedAt: created}, nil }) if err != nil { t.Fatalf("seed: %v", err) } time.Sleep(time.Millisecond) + later := time.Now() updated, err := store.SaveSnapshot(context.Background(), "snap-1", func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { if existing == nil { t.Fatal("expected non-nil existing on update") } return &exp.SessionSnapshot[testState]{ - Status: exp.SnapshotStatusCompleted, - State: &exp.SessionState[testState]{Custom: testState{Counter: 2}}, + Status: exp.SnapshotStatusCompleted, + State: &exp.SessionState[testState]{Custom: testState{Counter: 2}}, + CreatedAt: existing.CreatedAt, + UpdatedAt: later, }, nil }) if err != nil { @@ -199,11 +210,7 @@ func TestFileSessionStore(t *testing.T) { }); err != nil { t.Fatalf("seed: %v", err) } - status, err := store.AbortSnapshot(context.Background(), "snap-1") - if err != nil { - t.Fatalf("AbortSnapshot: %v", err) - } - if status != exp.SnapshotStatusAborted { + if status := abortViaSave(t, store, "snap-1"); status != exp.SnapshotStatusAborted { t.Errorf("status = %q, want %q", status, exp.SnapshotStatusAborted) } snap, _ := store.GetSnapshot(context.Background(), "snap-1") @@ -220,22 +227,14 @@ func TestFileSessionStore(t *testing.T) { }); err != nil { t.Fatalf("seed: %v", err) } - status, err := store.AbortSnapshot(context.Background(), "snap-1") - if err != nil { - t.Fatalf("AbortSnapshot: %v", err) - } - if status != exp.SnapshotStatusCompleted { + if status := abortViaSave(t, store, "snap-1"); status != exp.SnapshotStatusCompleted { t.Errorf("status = %q, want %q (no-op on terminal)", status, exp.SnapshotStatusCompleted) } }) t.Run("AbortMissingReturnsEmpty", func(t *testing.T) { store := newFileStore(t) - status, err := store.AbortSnapshot(context.Background(), "nonexistent") - if err != nil { - t.Fatalf("AbortSnapshot: %v", err) - } - if status != "" { + if status := abortViaSave(t, store, "nonexistent"); status != "" { t.Errorf("status = %q, want empty (not found)", status) } }) @@ -261,9 +260,7 @@ func TestFileSessionStore(t *testing.T) { t.Fatal("timeout waiting for initial status") } - if _, err := store.AbortSnapshot(context.Background(), "snap-1"); err != nil { - t.Fatalf("AbortSnapshot: %v", err) - } + abortViaSave(t, store, "snap-1") select { case s := <-ch: if s != exp.SnapshotStatusAborted { @@ -368,9 +365,6 @@ func TestFileSessionStore(t *testing.T) { }); err == nil { t.Errorf("SaveSnapshot(%q): expected error, got nil", id) } - if _, err := store.AbortSnapshot(context.Background(), id); err == nil { - t.Errorf("AbortSnapshot(%q): expected error, got nil", id) - } }) } }) @@ -392,9 +386,9 @@ func TestFileSessionStore(t *testing.T) { } }) - t.Run("ImplementsSessionStoreAndAborter", func(t *testing.T) { + t.Run("ImplementsSessionStoreAndSubscriber", func(t *testing.T) { var _ exp.SessionStore[testState] = (*FileSessionStore[testState])(nil) - var _ exp.SnapshotAborter = (*FileSessionStore[testState])(nil) + var _ exp.SnapshotSubscriber = (*FileSessionStore[testState])(nil) }) } @@ -445,6 +439,16 @@ func TestFileSessionStore_SessionIDs(t *testing.T) { }) } +func TestFileSessionStore_Heartbeat(t *testing.T) { + runHeartbeatStoreTests(t, func(t *testing.T) exp.SessionStore[testState] { + store, err := NewFileSessionStore[testState](t.TempDir()) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + return store + }) +} + func TestFileSessionStore_GetLatestSnapshot_SkipsUnparseableFiles(t *testing.T) { // A stray unparseable .json file (crash artifact, partial copy, // hand-edited row) must not poison session resolution for the healthy diff --git a/go/ai/exp/localstore/inmemory.go b/go/ai/exp/localstore/inmemory.go index de56dd9bf3..15df0ec54f 100644 --- a/go/ai/exp/localstore/inmemory.go +++ b/go/ai/exp/localstore/inmemory.go @@ -27,7 +27,6 @@ import ( "fmt" "slices" "sync" - "time" "github.com/firebase/genkit/go/ai/exp" "github.com/google/uuid" @@ -37,7 +36,7 @@ import ( // is lost when the process exits; use [FileSessionStore] or a real backend // when persistence is needed. // -// It implements [exp.SessionStore] and [exp.SnapshotAborter]. +// It implements [exp.SessionStore] and [exp.SnapshotSubscriber]. type InMemorySessionStore[State any] struct { // mu is RWMutex so GetSnapshot (which JSON-marshals while holding the // lock) can run concurrently with other readers. All writers (Save, @@ -66,9 +65,9 @@ func (s *InMemorySessionStore[State]) GetSnapshot(_ context.Context, snapshotID return copySnapshot(snap) } -// GetLatestSnapshot returns the session's most recently updated snapshot +// GetLatestSnapshot returns the session's most recently created snapshot // regardless of status, per the [exp.SnapshotReader.GetLatestSnapshot] -// contract. Ties on UpdatedAt are broken by SnapshotID so resolution is +// contract. Ties on CreatedAt are broken by SnapshotID so resolution is // deterministic. The returned snapshot is a deep copy. func (s *InMemorySessionStore[State]) GetLatestSnapshot(_ context.Context, sessionID string) (*exp.SessionSnapshot[State], error) { if sessionID == "" { @@ -81,8 +80,8 @@ func (s *InMemorySessionStore[State]) GetLatestSnapshot(_ context.Context, sessi if snap.SessionID != sessionID { continue } - if latest == nil || snap.UpdatedAt.After(latest.UpdatedAt) || - (snap.UpdatedAt.Equal(latest.UpdatedAt) && snap.SnapshotID > latest.SnapshotID) { + if latest == nil || snap.CreatedAt.After(latest.CreatedAt) || + (snap.CreatedAt.Equal(latest.CreatedAt) && snap.SnapshotID > latest.SnapshotID) { latest = snap } } @@ -92,24 +91,6 @@ func (s *InMemorySessionStore[State]) GetLatestSnapshot(_ context.Context, sessi return copySnapshot(latest) } -// AbortSnapshot atomically flips a pending snapshot to aborted. If the -// snapshot is already terminal the existing status is returned unchanged. -// Returns an empty status if the snapshot is not found. -func (s *InMemorySessionStore[State]) AbortSnapshot(_ context.Context, snapshotID string) (exp.SnapshotStatus, error) { - s.mu.Lock() - defer s.mu.Unlock() - snap, ok := s.snapshots[snapshotID] - if !ok { - return "", nil - } - if snap.Status == exp.SnapshotStatusPending { - snap.Status = exp.SnapshotStatusAborted - snap.UpdatedAt = time.Now() - s.notifyLocked(snapshotID, snap.Status) - } - return snap.Status, nil -} - // SaveSnapshot atomically reads, applies fn, and persists. See // [exp.SnapshotWriter] for the full contract; this implementation calls fn // exactly once per call. @@ -143,16 +124,11 @@ func (s *InMemorySessionStore[State]) SaveSnapshot( } next.SnapshotID = id - now := time.Now() - if existing != nil { - next.CreatedAt = existing.CreatedAt - if existing.SessionID != "" { - next.SessionID = existing.SessionID // a row's session never changes - } - } else { - next.CreatedAt = now + // SessionID is preserved (a row's session never changes); CreatedAt, + // UpdatedAt, and HeartbeatAt are caller-managed and persisted verbatim. + if existing != nil && existing.SessionID != "" { + next.SessionID = existing.SessionID } - next.UpdatedAt = now if next.Status == "" { next.Status = exp.SnapshotStatusCompleted } @@ -167,9 +143,7 @@ func (s *InMemorySessionStore[State]) SaveSnapshot( } // Return next (the freshly-allocated struct from fn) rather than // copied: copied is the pointer the store retains, so returning it - // would alias the caller's view with the stored row and let future - // in-place mutations (e.g. AbortSnapshot updating UpdatedAt) leak - // through. + // would alias the caller's view with the stored row. return next, nil } diff --git a/go/ai/exp/localstore/inmemory_test.go b/go/ai/exp/localstore/inmemory_test.go index 32321ff273..19c61c0ec4 100644 --- a/go/ai/exp/localstore/inmemory_test.go +++ b/go/ai/exp/localstore/inmemory_test.go @@ -38,14 +38,17 @@ func TestInMemorySessionStore(t *testing.T) { t.Run("SaveWithFixedID", func(t *testing.T) { store := NewInMemorySessionStore[testState]() + now := time.Now() saved, err := store.SaveSnapshot(context.Background(), "snap-1", func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { if existing != nil { t.Errorf("expected nil existing on first save, got %+v", existing) } return &exp.SessionSnapshot[testState]{ - Status: exp.SnapshotStatusCompleted, - State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, + Status: exp.SnapshotStatusCompleted, + State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, + CreatedAt: now, + UpdatedAt: now, }, nil }) if err != nil { @@ -54,9 +57,10 @@ func TestInMemorySessionStore(t *testing.T) { if saved.SnapshotID != "snap-1" { t.Errorf("saved SnapshotID = %q, want %q", saved.SnapshotID, "snap-1") } - if saved.CreatedAt.IsZero() || saved.UpdatedAt.IsZero() { - t.Errorf("expected CreatedAt/UpdatedAt stamped, got created=%v updated=%v", - saved.CreatedAt, saved.UpdatedAt) + // Timestamps are caller-managed: the store persists them verbatim. + if !saved.CreatedAt.Equal(now) || !saved.UpdatedAt.Equal(now) { + t.Errorf("expected caller-set timestamps persisted, got created=%v updated=%v want %v", + saved.CreatedAt, saved.UpdatedAt, now) } }) @@ -121,24 +125,31 @@ func TestInMemorySessionStore(t *testing.T) { } }) - t.Run("PreservesCreatedAtOnUpdate", func(t *testing.T) { + t.Run("PersistsCallerTimestamps", func(t *testing.T) { + // Timestamps are caller-managed: the store round-trips them verbatim + // (it does not stamp). The caller preserves CreatedAt and advances + // UpdatedAt across a rewrite. store := NewInMemorySessionStore[testState]() + created := time.Now() saved, err := store.SaveSnapshot(context.Background(), "snap-1", func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { - return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusCompleted}, nil + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusCompleted, CreatedAt: created, UpdatedAt: created}, nil }) if err != nil { t.Fatalf("seed: %v", err) } time.Sleep(time.Millisecond) // ensure measurable UpdatedAt delta + later := time.Now() updated, err := store.SaveSnapshot(context.Background(), "snap-1", func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { if existing == nil { t.Fatal("expected non-nil existing on update") } return &exp.SessionSnapshot[testState]{ - Status: exp.SnapshotStatusCompleted, - State: &exp.SessionState[testState]{Custom: testState{Counter: 2}}, + Status: exp.SnapshotStatusCompleted, + State: &exp.SessionState[testState]{Custom: testState{Counter: 2}}, + CreatedAt: existing.CreatedAt, + UpdatedAt: later, }, nil }) if err != nil { @@ -152,9 +163,15 @@ func TestInMemorySessionStore(t *testing.T) { } }) - t.Run("ImplementsSessionStoreAndAborter", func(t *testing.T) { + t.Run("ImplementsSessionStoreAndSubscriber", func(t *testing.T) { var _ exp.SessionStore[testState] = (*InMemorySessionStore[testState])(nil) - var _ exp.SnapshotAborter = (*InMemorySessionStore[testState])(nil) + var _ exp.SnapshotSubscriber = (*InMemorySessionStore[testState])(nil) + }) +} + +func TestInMemorySessionStore_Heartbeat(t *testing.T) { + runHeartbeatStoreTests(t, func(t *testing.T) exp.SessionStore[testState] { + return NewInMemorySessionStore[testState]() }) } diff --git a/go/ai/exp/localstore/store_test.go b/go/ai/exp/localstore/store_test.go index 26bf83830b..d5adf1c8d0 100644 --- a/go/ai/exp/localstore/store_test.go +++ b/go/ai/exp/localstore/store_test.go @@ -39,6 +39,8 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio saveRow := func(t *testing.T, store exp.SessionStore[testState], id, sessionID, parentID string, status exp.SnapshotStatus) *exp.SessionSnapshot[testState] { t.Helper() + // Timestamps are caller-managed; a fresh row is created now. + now := time.Now() saved, err := store.SaveSnapshot(ctx, id, func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { return &exp.SessionSnapshot[testState]{ @@ -46,6 +48,8 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio ParentID: parentID, Status: status, State: &exp.SessionState[testState]{Custom: testState{Counter: 1}}, + CreatedAt: now, + UpdatedAt: now, }, nil }) if err != nil { @@ -54,9 +58,8 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio return saved } - // tick spaces consecutive writes far enough apart that UpdatedAt (and - // the file store's mtimes) order them unambiguously even on coarse - // clocks. + // tick spaces consecutive writes far enough apart that CreatedAt orders + // them unambiguously even on coarse clocks. tick := func() { time.Sleep(2 * time.Millisecond) } t.Run("SessionIDKeptWhenProvided", func(t *testing.T) { @@ -140,10 +143,10 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio } }) - t.Run("GetLatestSnapshotUpdateWins", func(t *testing.T) { - // Recency is judged by UpdatedAt, not creation order: rewriting a - // row (e.g. a detach finalize landing after other branches were - // written) moves it to the front. + t.Run("GetLatestSnapshotByCreatedAt", func(t *testing.T) { + // Recency is judged by CreatedAt: the newest-created leaf wins, and a + // later rewrite of an older row (e.g. a detach finalize) does not move + // it ahead, because the rewrite preserves CreatedAt. store := newStore(t) saveRow(t, store, "root", "sess-1", "", exp.SnapshotStatusCompleted) tick() @@ -151,11 +154,13 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio tick() saveRow(t, store, "b2", "sess-1", "root", exp.SnapshotStatusCompleted) tick() + // Finalize the older row b1; the copy preserves its CreatedAt. if _, err := store.SaveSnapshot(ctx, "b1", func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { rewritten := *existing rewritten.Status = exp.SnapshotStatusCompleted rewritten.State = &exp.SessionState[testState]{Custom: testState{Counter: 2}} + rewritten.UpdatedAt = time.Now() return &rewritten, nil }); err != nil { t.Fatalf("SaveSnapshot finalize: %v", err) @@ -165,8 +170,8 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio if err != nil { t.Fatalf("GetLatestSnapshot: %v", err) } - if latest == nil || latest.SnapshotID != "b1" { - t.Errorf("latest = %+v, want freshly finalized snapshot b1", latest) + if latest == nil || latest.SnapshotID != "b2" { + t.Errorf("latest = %+v, want newest-created snapshot b2 (finalize must not move b1 ahead)", latest) } }) @@ -228,3 +233,151 @@ func runSessionIDStoreTests(t *testing.T, newStore func(t *testing.T) exp.Sessio } }) } + +// abortViaSave flips a pending snapshot to aborted via SaveSnapshot, mirroring +// the agent runtime's abort (the store has no dedicated abort method). Returns +// the resulting status: aborted when it was pending, the existing terminal +// status when already settled, or "" when the snapshot does not exist. +func abortViaSave(t *testing.T, store exp.SessionStore[testState], id string) exp.SnapshotStatus { + t.Helper() + now := time.Now() + saved, err := store.SaveSnapshot(context.Background(), id, + func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + if existing == nil { + return nil, nil + } + if existing.Status != exp.SnapshotStatusPending { + return existing, nil + } + updated := *existing + updated.Status = exp.SnapshotStatusAborted + updated.UpdatedAt = now + return &updated, nil + }) + if err != nil { + t.Fatalf("abortViaSave(%q): %v", id, err) + } + if saved == nil { + return "" + } + return saved.Status +} + +// runHeartbeatStoreTests exercises a heartbeat refresh - an ordinary +// SaveSnapshot that touches only HeartbeatAt on a still-pending row - against +// any store, so the in-memory and file stores stay behaviorally aligned. The +// central property: a heartbeat is a liveness signal, not a state change, so it +// advances HeartbeatAt but touches neither UpdatedAt nor the store's recency +// ordering. +func runHeartbeatStoreTests(t *testing.T, newStore func(t *testing.T) exp.SessionStore[testState]) { + ctx := context.Background() + tick := func() { time.Sleep(2 * time.Millisecond) } + + // beat refreshes a pending snapshot's heartbeat the way the agent runtime + // does: an ordinary SaveSnapshot carrying the existing row through unchanged + // but for HeartbeatAt (so caller-managed timestamps are preserved), touching + // only a still-pending row. + beat := func(t *testing.T, store exp.SessionStore[testState], id string) { + t.Helper() + now := time.Now() + if _, err := store.SaveSnapshot(ctx, id, + func(existing *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + if existing == nil || existing.Status != exp.SnapshotStatusPending { + return nil, nil + } + updated := *existing + updated.HeartbeatAt = &now + return &updated, nil + }); err != nil { + t.Fatalf("heartbeat SaveSnapshot: %v", err) + } + } + savePending := func(t *testing.T, store exp.SessionStore[testState], id, sessionID, parentID string) { + t.Helper() + now := time.Now() + if _, err := store.SaveSnapshot(ctx, id, + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{SessionID: sessionID, ParentID: parentID, Status: exp.SnapshotStatusPending, CreatedAt: now, UpdatedAt: now}, nil + }); err != nil { + t.Fatalf("SaveSnapshot(%q): %v", id, err) + } + } + + t.Run("AdvancesHeartbeatNotUpdatedAt", func(t *testing.T) { + store := newStore(t) + savePending(t, store, "p", "sess", "") + before, err := store.GetSnapshot(ctx, "p") + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + tick() + beat(t, store, "p") + after, err := store.GetSnapshot(ctx, "p") + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if after.HeartbeatAt == nil { + t.Error("HeartbeatAt was not stamped") + } + if !after.UpdatedAt.Equal(before.UpdatedAt) { + t.Errorf("heartbeat advanced UpdatedAt: before=%v after=%v", before.UpdatedAt, after.UpdatedAt) + } + if after.Status != exp.SnapshotStatusPending { + t.Errorf("status = %q, want pending", after.Status) + } + }) + + t.Run("NoopOnTerminal", func(t *testing.T) { + store := newStore(t) + now := time.Now() + if _, err := store.SaveSnapshot(ctx, "c", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{SessionID: "sess", Status: exp.SnapshotStatusCompleted, CreatedAt: now, UpdatedAt: now}, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + before, err := store.GetSnapshot(ctx, "c") + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + beat(t, store, "c") + after, err := store.GetSnapshot(ctx, "c") + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if after.HeartbeatAt != nil { + t.Errorf("beat stamped a heartbeat on a terminal row: %v", after.HeartbeatAt) + } + if !after.UpdatedAt.Equal(before.UpdatedAt) { + t.Errorf("beat bumped UpdatedAt on a terminal row: before=%v after=%v", before.UpdatedAt, after.UpdatedAt) + } + }) + + t.Run("DoesNotChangeRecency", func(t *testing.T) { + // A heartbeat must not move a pending row ahead of a newer row in + // GetLatestSnapshot: recency is by CreatedAt, and a beat must not touch + // it (nor anything else but HeartbeatAt). + store := newStore(t) + savePending(t, store, "old", "sess", "") + tick() + now := time.Now() + if _, err := store.SaveSnapshot(ctx, "new", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{SessionID: "sess", ParentID: "old", Status: exp.SnapshotStatusCompleted, State: &exp.SessionState[testState]{Custom: testState{Counter: 9}}, CreatedAt: now, UpdatedAt: now}, nil + }); err != nil { + t.Fatalf("SaveSnapshot(new): %v", err) + } + tick() + for i := 0; i < 3; i++ { + beat(t, store, "old") + tick() + } + latest, err := store.GetLatestSnapshot(ctx, "sess") + if err != nil { + t.Fatalf("GetLatestSnapshot: %v", err) + } + if latest == nil || latest.SnapshotID != "new" { + t.Errorf("latest = %+v, want \"new\" (a heartbeat must not affect recency)", latest) + } + }) +} diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index 617f856221..e620281622 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -72,7 +72,7 @@ func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { // WithSessionStore sets the store for persisting snapshots. The store must // implement [SnapshotReader] and [SnapshotWriter] at minimum. Detach -// support also requires [SnapshotAborter]; detach attempts on a store +// support also requires [SnapshotSubscriber]; detach attempts on a store // that lacks that interface are rejected at runtime. func WithSessionStore[State any](store SessionStore[State]) AgentOption[State] { return &agentOptions[State]{store: store} diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 35e5a248bb..b2311e3823 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -47,19 +47,20 @@ type SnapshotReader[State any] interface { // GetSnapshot retrieves a snapshot by ID. Returns nil if not found. GetSnapshot(ctx context.Context, snapshotID string) (*SessionSnapshot[State], error) - // GetLatestSnapshot returns the session's most recently updated + // GetLatestSnapshot returns the session's most recently created // snapshot, whatever its status: a pending, failed, or aborted row is // returned like any other, and the caller applies its own policy. // Returns nil if the session has no rows, and an error if sessionID is // empty. // - // "Most recently updated" means the greatest [SessionSnapshot.UpdatedAt], - // falling back to CreatedAt on rows that lack one; break ties - // deterministically (e.g. by SnapshotID). This is a plain max-timestamp - // lookup, implementable as a single indexed query (e.g. WHERE sessionId = ? - // ORDER BY updatedAt DESC LIMIT 1). ParentID is informational lineage and - // plays no part in resolution: when history forks, the most recently - // updated branch wins. + // "Most recently created" means the greatest [SessionSnapshot.CreatedAt]; + // break ties deterministically (e.g. by SnapshotID). This is a plain + // max-timestamp lookup, implementable as a single indexed query (e.g. + // WHERE sessionId = ? ORDER BY createdAt DESC LIMIT 1). A later rewrite of + // an older row (e.g. a detach finalize) does not move it ahead of a + // newer-created sibling, since CreatedAt is preserved across rewrites. + // ParentID is informational lineage and plays no part in resolution: when + // history forks, the most recently created branch wins. GetLatestSnapshot(ctx context.Context, sessionID string) (*SessionSnapshot[State], error) } @@ -67,8 +68,8 @@ type SnapshotReader[State any] interface { // implement to be used with [WithSessionStore]. type SnapshotWriter[State any] interface { // SaveSnapshot atomically reads the snapshot at id (if any), applies - // fn, and persists the result. The store owns identity and - // lifecycle-timestamp fields: + // fn, and persists the result largely verbatim. The store owns only + // identity; the caller (fn) owns the lifecycle timestamps and status: // // - SnapshotID: if id is empty, the store generates a fresh ID; // otherwise the store uses id (any SnapshotID populated by fn is @@ -77,9 +78,14 @@ type SnapshotWriter[State any] interface { // belongs to: preserved from the existing row on update (a row's // session never changes once set), otherwise taken from fn's row // as-is. Stores never mint or infer session IDs. - // - CreatedAt: stamped to the wall clock on first write; preserved - // from the existing row on update. - // - UpdatedAt: stamped to the wall clock on every commit. + // - CreatedAt / UpdatedAt: caller-managed. The store persists whatever fn + // returns and never stamps them. fn sets CreatedAt and UpdatedAt to the + // current time on a new row, preserves CreatedAt and advances UpdatedAt + // on a state-changing rewrite, and preserves both on a non-state write + // (e.g. a heartbeat refresh, which carries the existing snapshot through + // unchanged but for HeartbeatAt). Keeping timestamps with the caller is + // what lets a heartbeat advance liveness without registering as a state + // change - the store has no special heartbeat path. // - Status: if the snapshot returned by fn has Status="", it is // defaulted to [SnapshotStatusCompleted] (the common case for // synchronous turn-end writes). Callers writing a pending row must @@ -103,25 +109,17 @@ type SnapshotWriter[State any] interface { ) (*SessionSnapshot[State], error) } -// SnapshotAborter is the optional capability layered on [SessionStore] that -// lets an agent's invocations be aborted. The two methods work together: -// [SnapshotAborter.AbortSnapshot] flips a pending snapshot's status to aborted, -// and [SnapshotAborter.OnSnapshotStatusChange] lets the agent runtime observe -// the flip without polling so it can promptly cancel the work context. A store -// must implement both or neither. -type SnapshotAborter interface { - // AbortSnapshot atomically transitions a snapshot from - // [SnapshotStatusPending] to [SnapshotStatusAborted] and returns the - // resulting status. If the snapshot is in any other status the - // operation is a no-op and the existing status is returned. Returns - // an empty status with a nil error if the snapshot is not found, so - // callers can distinguish "not found" from a real error. - // - // Implementations must perform the read-and-write atomically (e.g., a - // transaction or a compare-and-swap) so a racing terminal write cannot - // clobber the pending row. - AbortSnapshot(ctx context.Context, snapshotID string) (SnapshotStatus, error) - +// SnapshotSubscriber is the optional capability layered on [SessionStore] that +// lets the agent runtime observe a snapshot's status changes without polling. +// It is what makes a detached invocation abortable: aborting is an ordinary +// [SnapshotWriter.SaveSnapshot] that flips a pending row to +// [SnapshotStatusAborted], and the runtime reacts to that flip through this +// subscription, promptly cancelling the background work context. +// +// A store that does not implement it cannot support detach (there is no way to +// signal the background work to stop); see the runtime's detach precondition +// check. +type SnapshotSubscriber interface { // OnSnapshotStatusChange returns a channel that yields the snapshot's // status whenever it changes. The first value (if any) reflects the // status at subscription time. The channel is closed when ctx is @@ -134,10 +132,10 @@ type SnapshotAborter interface { } // SessionStore is the minimum store interface required by -// [WithSessionStore]. The abort lifecycle is layered as the optional -// [SnapshotAborter] capability and checked at runtime: a store wired +// [WithSessionStore]. Status-change observation is layered as the optional +// [SnapshotSubscriber] capability and checked at runtime: a store wired // into an agent that intends to support detach must also implement -// [SnapshotAborter], or the runtime will reject detach attempts. +// [SnapshotSubscriber], or the runtime will reject detach attempts. type SessionStore[State any] interface { SnapshotReader[State] SnapshotWriter[State] @@ -185,9 +183,8 @@ func cloneArtifacts(arts []*Artifact) []*Artifact { // and non-Go clients. Local Go callers use the store reference directly. // // - The agent's name under [api.ActionTypeAgentAbort] — abortSnapshot, -// created only when the store also implements [SnapshotAborter] (which -// bundles both the abort trigger and the status-change subscription -// needed for the runtime to react). +// created only when the store also implements [SnapshotSubscriber], so the +// runtime can react to the abort it writes via SaveSnapshot. // // When the agent is client-managed (no store configured), neither action // is created: there is no server-side snapshot to fetch or abort. @@ -249,6 +246,15 @@ func newSnapshotActions[State any]( // A failed snapshot's state is its last-good state, so it is // returned like any other. resp := *snap + // Surface a pending snapshot whose heartbeat has gone stale as + // expired: its detached background worker is presumed dead, so + // report the orphan rather than leaving it pending forever. This is + // computed on read only, never written back to the store, so the + // raw row stays pending. Checked before the empty-status default + // below, which applies only to a row carrying no status at all. + if isHeartbeatExpired(snap, defaultHeartbeatTimeout) { + resp.Status = SnapshotStatusExpired + } if resp.Status == "" { resp.Status = SnapshotStatusCompleted } @@ -270,10 +276,9 @@ func newSnapshotActions[State any]( return &resp, nil }) - aborter, ok := store.(SnapshotAborter) - if !ok { - // Store doesn't support the abort lifecycle. Don't surface the - // action. + if _, ok := store.(SnapshotSubscriber); !ok { + // Without a subscriber the runtime cannot react to an abort, so the + // abort lifecycle is unsupported; don't surface the action. return getSnapshotAction, nil } abortSnapshotAction := core.NewAction(agentName, api.ActionTypeAgentAbort, nil, nil, @@ -281,7 +286,9 @@ func newSnapshotActions[State any]( if req == nil || req.SnapshotID == "" { return nil, core.NewError(core.INVALID_ARGUMENT, "abortSnapshot: snapshotId is required") } - status, err := aborter.AbortSnapshot(ctx, req.SnapshotID) + // Aborting is an ordinary SaveSnapshot that flips a pending row to + // aborted; the store has no dedicated abort method. + status, err := abortPendingSnapshot(ctx, store, req.SnapshotID) if err != nil { return nil, core.NewError(core.INTERNAL, "abortSnapshot: %v", err) } diff --git a/go/ai/exp/teststore_test.go b/go/ai/exp/teststore_test.go index 969a169953..fcc4824784 100644 --- a/go/ai/exp/teststore_test.go +++ b/go/ai/exp/teststore_test.go @@ -30,7 +30,6 @@ import ( "fmt" "slices" "sync" - "time" "github.com/google/uuid" ) @@ -79,8 +78,8 @@ func (s *testInMemStore[State]) GetLatestSnapshot(_ context.Context, sessionID s if snap.SessionID != sessionID { continue } - if latest == nil || snap.UpdatedAt.After(latest.UpdatedAt) || - (snap.UpdatedAt.Equal(latest.UpdatedAt) && snap.SnapshotID > latest.SnapshotID) { + if latest == nil || snap.CreatedAt.After(latest.CreatedAt) || + (snap.CreatedAt.Equal(latest.CreatedAt) && snap.SnapshotID > latest.SnapshotID) { latest = snap } } @@ -90,21 +89,6 @@ func (s *testInMemStore[State]) GetLatestSnapshot(_ context.Context, sessionID s return testCopySnapshot(latest) } -func (s *testInMemStore[State]) AbortSnapshot(_ context.Context, snapshotID string) (SnapshotStatus, error) { - s.mu.Lock() - defer s.mu.Unlock() - snap, ok := s.snapshots[snapshotID] - if !ok { - return "", nil - } - if snap.Status == SnapshotStatusPending { - snap.Status = SnapshotStatusAborted - snap.UpdatedAt = time.Now() - s.notifyLocked(snapshotID, snap.Status) - } - return snap.Status, nil -} - func (s *testInMemStore[State]) SaveSnapshot( _ context.Context, id string, @@ -135,16 +119,11 @@ func (s *testInMemStore[State]) SaveSnapshot( } next.SnapshotID = id - now := time.Now() - if existing != nil { - next.CreatedAt = existing.CreatedAt - if existing.SessionID != "" { - next.SessionID = existing.SessionID // a row's session never changes - } - } else { - next.CreatedAt = now + // SessionID is preserved (a row's session never changes); CreatedAt, + // UpdatedAt, and HeartbeatAt are caller-managed and persisted verbatim. + if existing != nil && existing.SessionID != "" { + next.SessionID = existing.SessionID } - next.UpdatedAt = now if next.Status == "" { next.Status = SnapshotStatusCompleted } diff --git a/go/core/schemas.config b/go/core/schemas.config index d6045ac8da..fcfd3fc19c 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1619,9 +1619,21 @@ CreatedAt is when the snapshot was created. SessionSnapshot.updatedAt type time.Time SessionSnapshot.updatedAt doc -UpdatedAt is when the snapshot was last written. For pending snapshots -it equals CreatedAt; once the snapshot is finalized it reflects the -terminal write. +UpdatedAt is when the snapshot's state was last written. A heartbeat refresh on +a pending snapshot deliberately does not advance it, so a pending snapshot's +UpdatedAt equals CreatedAt until the snapshot is finalized, whose terminal write +advances it. +. + +SessionSnapshot.heartbeatAt type *time.Time +SessionSnapshot.heartbeatAt doc +HeartbeatAt is refreshed periodically while a detached (background) turn is +in flight, so a reader can detect a dead background worker: if a pending +snapshot's heartbeat goes stale (older than the configured timeout), reads +surface its status as [SnapshotStatusExpired] (the dead worker can no longer +persist a terminal status itself). Refreshing it does not advance +[SessionSnapshot.UpdatedAt], so liveness stays distinct from state changes. Nil +on snapshots that never detached. . SessionSnapshot.status doc @@ -1671,6 +1683,10 @@ with the cumulative final state and [SnapshotStatusCompleted] / [SnapshotStatusFailed] when the agent finishes, or with [SnapshotStatusAborted] if the client called abortSnapshot in the meantime. + +[SnapshotStatusExpired] is never persisted: it is computed on read for a +pending snapshot whose background worker is presumed dead (its heartbeat +went stale), surfacing the orphan rather than leaving it pending forever. . SnapshotStatusPending doc @@ -1694,6 +1710,13 @@ The snapshot's Error field describes the failure and resume is rejected with that same error. . +SnapshotStatusExpired doc +SnapshotStatusExpired indicates a pending snapshot whose detached background +worker is presumed dead: its [SessionSnapshot.HeartbeatAt] went stale. It is +computed on read (never persisted), so the raw store row stays +[SnapshotStatusPending] while a read surfaces it as expired. +. + # ---------------------------------------------------------------------------- # AgentFinishReason # ---------------------------------------------------------------------------- @@ -1831,7 +1854,7 @@ StateManagement reports who owns session state. AgentMetadata.abortable doc Abortable reports whether the agent's invocations can be aborted -(true when the store implements [SnapshotAborter]). +(true when the store implements [SnapshotSubscriber]). . AgentMetadata.stateSchema type map[string]any diff --git a/go/samples/basic-agents/cli.go b/go/samples/basic-agents/cli.go index 2fe76ab9db..051e1ed437 100644 --- a/go/samples/basic-agents/cli.go +++ b/go/samples/basic-agents/cli.go @@ -459,18 +459,18 @@ func summarizeLatest(ctx context.Context, a *aix.Agent[any], sessionID string) s // status first, so a snapshot that finalized just before the // subscription is observed immediately. // -// The status subscription is the optional SnapshotAborter half of the -// store contract. A store without it cannot stream background progress, +// The status subscription is the optional SnapshotSubscriber capability of +// the store contract. A store without it cannot stream background progress, // so we fall back to reading the snapshot once and returning it as-is. func waitForFinalize(ctx context.Context, a *aix.Agent[any], snapshotID string) (*aix.SessionSnapshot[any], error) { store := a.Store() - aborter, ok := store.(aix.SnapshotAborter) + subscriber, ok := store.(aix.SnapshotSubscriber) if !ok { return store.GetSnapshot(ctx, snapshotID) } subCtx, cancel := context.WithCancel(ctx) defer cancel() - statusCh := aborter.OnSnapshotStatusChange(subCtx, snapshotID) + statusCh := subscriber.OnSnapshotStatusChange(subCtx, snapshotID) for { select { case <-ctx.Done(): diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index d988fc55ef..fa47e13c7a 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -73,6 +73,7 @@ class SnapshotStatus(StrEnum): COMPLETED = 'completed' ABORTED = 'aborted' FAILED = 'failed' + EXPIRED = 'expired' class EvalStatusEnum(StrEnum): @@ -229,6 +230,7 @@ class SessionSnapshot(GenkitModel): parent_id: str | None = None created_at: str = Field(...) updated_at: str | None = None + heartbeat_at: str | None = None status: SnapshotStatus | None = None finish_reason: AgentFinishReason | None = None error: GenkitRuntimeError | None = None From 2a48ce2f41e4e4138a5fc687878254e20e6cf010 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 18 Jun 2026 21:18:25 -0700 Subject: [PATCH 121/141] refactor(go/core): rename bidi connection methods for clarity StreamBidi/StreamBidiJSON only open a connection; nothing streams until the caller drives the returned handle. Rename them to Connect/ConnectJSON. Also rename the one-shot RunWithInit to RunBidi so it pairs with RunBidiJSON, and BidiSessionOptions to BidiJSONOptions since it configures the JSON call surface (one-shot and streaming), not only sessions. --- go/core/api/action.go | 18 +++++----- go/core/bidi.go | 18 +++++----- go/core/bidi_test.go | 72 +++++++++++++++++++------------------- go/genkit/reflection.go | 2 +- go/genkit/reflection_v2.go | 2 +- 5 files changed, 57 insertions(+), 55 deletions(-) diff --git a/go/core/api/action.go b/go/core/api/action.go index 5818108a30..0bec4487a0 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -41,13 +41,15 @@ type Action interface { Desc() ActionDesc } -// BidiSessionOptions configures a bidirectional session started through the -// JSON interfaces. A nil value is equivalent to zero options. The struct may -// gain fields over time; construct it by field name. +// BidiJSONOptions carries the options for a JSON-encoded call to a bidi +// action, used by both the one-shot [BidiAction.RunBidiJSON] and the streaming +// [BidiAction.ConnectJSON]. It is the JSON counterpart of the typed init +// argument. A nil value is equivalent to zero options. The struct may gain +// fields over time; construct it by field name. // // Experimental: bidirectional streaming is experimental and subject to change. -type BidiSessionOptions struct { - // Init is the JSON-encoded initial configuration for the session, +type BidiJSONOptions struct { + // Init is the JSON-encoded initial configuration for the call, // decoded into the action's Init type and validated against its // InitSchema. Empty or JSON-null means no init (the zero Init value). Init json.RawMessage @@ -64,10 +66,10 @@ type BidiAction interface { // RunBidiJSON runs the bidi action as a single one-shot call: input is // delivered as the only chunk on the input stream, outgoing chunks are // forwarded to cb, and opts carries the session init. - RunBidiJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error, opts *BidiSessionOptions) (*ActionRunResult[json.RawMessage], error) - // StreamBidiJSON starts a bidirectional streaming session using + RunBidiJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error, opts *BidiJSONOptions) (*ActionRunResult[json.RawMessage], error) + // ConnectJSON starts a bidirectional streaming session using // JSON-encoded messages. - StreamBidiJSON(ctx context.Context, opts *BidiSessionOptions) (BidiJSONConnection, error) + ConnectJSON(ctx context.Context, opts *BidiJSONOptions) (BidiJSONConnection, error) } // BidiJSONConnection is a JSON-encoded view of an active bidirectional diff --git a/go/core/bidi.go b/go/core/bidi.go index cad8fc3cc9..469e0743f7 100644 --- a/go/core/bidi.go +++ b/go/core/bidi.go @@ -54,7 +54,7 @@ type BidiFunc[In, Out, Stream, Init any] = func(ctx context.Context, init Init, // // BidiAction embeds [Action], so it can also be invoked through the regular // unary surface (Run, RunJSON): the input is delivered as a single chunk on -// the input stream with the zero Init value. Use [BidiAction.RunWithInit] or +// the input stream with the zero Init value. Use [BidiAction.RunBidi] or // [BidiAction.RunBidiJSON] for one-shot calls that supply init. // // For internal use only. @@ -200,13 +200,13 @@ func (b *BidiAction[In, Out, Stream, Init]) spanInitValue(init Init) any { return init } -// RunWithInit executes the bidi action as a single one-shot call with the +// RunBidi executes the bidi action as a single one-shot call with the // given initial configuration: input is delivered as the only chunk on the // input stream and outgoing chunks are forwarded to cb. Returns an error if // init fails validation against the action's InitSchema. // // Experimental: bidirectional streaming is experimental and subject to change. -func (b *BidiAction[In, Out, Stream, Init]) RunWithInit(ctx context.Context, init Init, input In, cb StreamCallback[Stream]) (Out, error) { +func (b *BidiAction[In, Out, Stream, Init]) RunBidi(ctx context.Context, init Init, input In, cb StreamCallback[Stream]) (Out, error) { r, err := b.Action.runWithTelemetry(ctx, input, cb, b.oneShotFn(init), b.spanInitValue(init)) if err != nil { return base.Zero[Out](), err @@ -220,7 +220,7 @@ func (b *BidiAction[In, Out, Stream, Init]) RunWithInit(ctx context.Context, ini // init fails to decode or validate. // // Experimental: bidirectional streaming is experimental and subject to change. -func (b *BidiAction[In, Out, Stream, Init]) RunBidiJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage], opts *api.BidiSessionOptions) (*api.ActionRunResult[json.RawMessage], error) { +func (b *BidiAction[In, Out, Stream, Init]) RunBidiJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage], opts *api.BidiJSONOptions) (*api.ActionRunResult[json.RawMessage], error) { init, hasInit, err := b.decodeInit(opts) if err != nil { return nil, err @@ -232,26 +232,26 @@ func (b *BidiAction[In, Out, Stream, Init]) RunBidiJSON(ctx context.Context, inp return b.Action.runJSONWithTelemetry(ctx, input, cb, b.oneShotFn(init), spanInit) } -// StreamBidi starts a bidirectional streaming connection with the given +// Connect starts a bidirectional streaming connection with the given // initial configuration. For actions whose Init type is struct{} (no init), // pass struct{}{}. Returns an error if init fails validation against the // action's InitSchema. // A trace span is created that remains open for the lifetime of the connection. // // Experimental: bidirectional streaming is experimental and subject to change. -func (b *BidiAction[In, Out, Stream, Init]) StreamBidi(ctx context.Context, init Init) (*BidiConnection[In, Out, Stream], error) { +func (b *BidiAction[In, Out, Stream, Init]) Connect(ctx context.Context, init Init) (*BidiConnection[In, Out, Stream], error) { if err := b.validateInit(init); err != nil { return nil, err } return b.startBidi(ctx, init, b.spanInitValue(init)), nil } -// StreamBidiJSON starts a bidirectional streaming session using JSON-encoded +// ConnectJSON starts a bidirectional streaming session using JSON-encoded // messages. Returns an error if the init carried by opts fails to decode or // validate. // // Experimental: bidirectional streaming is experimental and subject to change. -func (b *BidiAction[In, Out, Stream, Init]) StreamBidiJSON(ctx context.Context, opts *api.BidiSessionOptions) (api.BidiJSONConnection, error) { +func (b *BidiAction[In, Out, Stream, Init]) ConnectJSON(ctx context.Context, opts *api.BidiJSONOptions) (api.BidiJSONConnection, error) { init, hasInit, err := b.decodeInit(opts) if err != nil { return nil, err @@ -288,7 +288,7 @@ func (b *BidiAction[In, Out, Stream, Init]) StreamBidiJSON(ctx context.Context, // type. Returns hasInit=false when opts is nil or the payload is empty or // JSON null, so transports can pass the request's init field through // unconditionally. -func (b *BidiAction[In, Out, Stream, Init]) decodeInit(opts *api.BidiSessionOptions) (Init, bool, error) { +func (b *BidiAction[In, Out, Stream, Init]) decodeInit(opts *api.BidiJSONOptions) (Init, bool, error) { var init Init if opts == nil || !base.HasJSONValue(opts.Init) { return init, false, nil diff --git a/go/core/bidi_test.go b/go/core/bidi_test.go index 91fea75ebf..d04bd585dd 100644 --- a/go/core/bidi_test.go +++ b/go/core/bidi_test.go @@ -47,7 +47,7 @@ func TestBidiActionEcho(t *testing.T) { }, ) - conn, err := action.StreamBidi(ctx, struct{}{}) + conn, err := action.Connect(ctx, struct{}{}) if err != nil { t.Fatal(err) } @@ -104,7 +104,7 @@ func TestBidiActionWithConfig(t *testing.T) { }, ) - conn, err := action.StreamBidi(ctx, Config{Prefix: "INFO"}) + conn, err := action.Connect(ctx, Config{Prefix: "INFO"}) if err != nil { t.Fatal(err) } @@ -127,9 +127,9 @@ func TestBidiActionWithConfig(t *testing.T) { } } -// TestRunWithInit verifies the typed one-shot path: input is delivered as a +// TestRunBidi verifies the typed one-shot path: input is delivered as a // single chunk and init configures the session. -func TestRunWithInit(t *testing.T) { +func TestRunBidi(t *testing.T) { ctx := context.Background() type Config struct{ Prefix string } @@ -145,9 +145,9 @@ func TestRunWithInit(t *testing.T) { }, ) - got, err := action.RunWithInit(ctx, Config{Prefix: ">> "}, "hello", nil) + got, err := action.RunBidi(ctx, Config{Prefix: ">> "}, "hello", nil) if err != nil { - t.Fatalf("RunWithInit: %v", err) + t.Fatalf("RunBidi: %v", err) } if got != ">> hello" { t.Errorf("output = %q, want %q", got, ">> hello") @@ -220,7 +220,7 @@ func TestRunBidiJSON(t *testing.T) { } r, err := action.RunBidiJSON(ctx, json.RawMessage(`"hello"`), cb, - &api.BidiSessionOptions{Init: json.RawMessage(`{"prefix":">> "}`)}) + &api.BidiJSONOptions{Init: json.RawMessage(`{"prefix":">> "}`)}) if err != nil { t.Fatalf("RunBidiJSON: %v", err) } @@ -255,7 +255,7 @@ func TestRunBidiJSONInvalidInit(t *testing.T) { ) _, err := action.RunBidiJSON(ctx, json.RawMessage(`"in"`), nil, - &api.BidiSessionOptions{Init: json.RawMessage(`{not json`)}) + &api.BidiJSONOptions{Init: json.RawMessage(`{not json`)}) if err == nil { t.Fatal("expected error for invalid JSON, got nil") } @@ -268,16 +268,16 @@ func TestRunBidiJSONInvalidInit(t *testing.T) { } } -// TestStreamBidiJSONNullInit verifies that nil options and a JSON-null init +// TestConnectJSONNullInit verifies that nil options and a JSON-null init // payload are both treated as no init (the zero Init value). -func TestStreamBidiJSONNullInit(t *testing.T) { +func TestConnectJSONNullInit(t *testing.T) { ctx := context.Background() type Config struct { Prefix string `json:"prefix"` } - for _, opts := range []*api.BidiSessionOptions{nil, {Init: json.RawMessage(`null`)}} { + for _, opts := range []*api.BidiJSONOptions{nil, {Init: json.RawMessage(`null`)}} { var sawInit Config action := NewBidiAction( "null-init", api.ActionTypeCustom, nil, @@ -289,9 +289,9 @@ func TestStreamBidiJSONNullInit(t *testing.T) { }, ) - conn, err := action.StreamBidiJSON(ctx, opts) + conn, err := action.ConnectJSON(ctx, opts) if err != nil { - t.Fatalf("StreamBidiJSON(%v): %v", opts, err) + t.Fatalf("ConnectJSON(%v): %v", opts, err) } conn.Close() if _, err := conn.Output(); err != nil { @@ -326,7 +326,7 @@ func TestInitSchemaValidationRejectsBadInit(t *testing.T) { ) // Missing required "prefix" field. - _, err := action.StreamBidi(ctx, map[string]any{"other": 1}) + _, err := action.Connect(ctx, map[string]any{"other": 1}) if err == nil { t.Fatal("expected validation error, got nil") } @@ -362,9 +362,9 @@ func TestInitSchemaValidationAcceptsGoodInit(t *testing.T) { }, ) - conn, err := action.StreamBidi(ctx, map[string]any{"prefix": ">> "}) + conn, err := action.Connect(ctx, map[string]any{"prefix": ">> "}) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } conn.Close() out, err := conn.Output() @@ -388,7 +388,7 @@ func TestBidiConnectionSendAfterClose(t *testing.T) { }, ) - conn, err := action.StreamBidi(ctx, struct{}{}) + conn, err := action.Connect(ctx, struct{}{}) if err != nil { t.Fatal(err) } @@ -413,7 +413,7 @@ func TestBidiConnectionContextCancellation(t *testing.T) { }, ) - conn, err := action.StreamBidi(ctx, struct{}{}) + conn, err := action.Connect(ctx, struct{}{}) if err != nil { t.Fatal(err) } @@ -468,7 +468,7 @@ func TestBidiActionDone(t *testing.T) { }, ) - conn, err := action.StreamBidi(ctx, struct{}{}) + conn, err := action.Connect(ctx, struct{}{}) if err != nil { t.Fatal(err) } @@ -535,7 +535,7 @@ func TestBidiActionPanicRecovered(t *testing.T) { }, ) - conn, err := action.StreamBidi(ctx, struct{}{}) + conn, err := action.Connect(ctx, struct{}{}) if err != nil { t.Fatal(err) } @@ -562,7 +562,7 @@ func TestBidiActionClosingOutChIsError(t *testing.T) { }, ) - conn, err := action.StreamBidi(ctx, struct{}{}) + conn, err := action.Connect(ctx, struct{}{}) if err != nil { t.Fatal(err) } @@ -597,7 +597,7 @@ func TestBidiReceiveBreakDoesNotCancelSession(t *testing.T) { }, ) - conn, err := action.StreamBidi(ctx, struct{}{}) + conn, err := action.Connect(ctx, struct{}{}) if err != nil { t.Fatal(err) } @@ -640,7 +640,7 @@ func TestBidiConnectionCancel(t *testing.T) { }, ) - conn, err := action.StreamBidi(context.Background(), struct{}{}) + conn, err := action.Connect(context.Background(), struct{}{}) if err != nil { t.Fatal(err) } @@ -667,7 +667,7 @@ func TestBidiOutputAfterCompletionNotCancelled(t *testing.T) { return "done", nil }, ) - conn, err := action.StreamBidi(ctx, struct{}{}) + conn, err := action.Connect(ctx, struct{}{}) if err != nil { t.Fatal(err) } @@ -713,7 +713,7 @@ func TestBidiJSONConnSendValidatesChunks(t *testing.T) { ) t.Run("valid chunk delivered", func(t *testing.T) { - conn, err := action.StreamBidiJSON(ctx, nil) + conn, err := action.ConnectJSON(ctx, nil) if err != nil { t.Fatal(err) } @@ -733,7 +733,7 @@ func TestBidiJSONConnSendValidatesChunks(t *testing.T) { }) t.Run("invalid chunk fails the session", func(t *testing.T) { - conn, err := action.StreamBidiJSON(ctx, nil) + conn, err := action.ConnectJSON(ctx, nil) if err != nil { t.Fatal(err) } @@ -751,7 +751,7 @@ func TestBidiJSONConnSendValidatesChunks(t *testing.T) { }) t.Run("null chunk validated like any payload", func(t *testing.T) { - conn, err := action.StreamBidiJSON(ctx, nil) + conn, err := action.ConnectJSON(ctx, nil) if err != nil { t.Fatal(err) } @@ -777,7 +777,7 @@ func TestBidiOutputSchemaValidatedOnConnection(t *testing.T) { }, ) - conn, err := action.StreamBidi(ctx, struct{}{}) + conn, err := action.Connect(ctx, struct{}{}) if err != nil { t.Fatal(err) } @@ -808,7 +808,7 @@ func TestBidiJSONConnReceiveMarshalErrorAbortsSession(t *testing.T) { }, ) - conn, err := action.StreamBidiJSON(ctx, nil) + conn, err := action.ConnectJSON(ctx, nil) if err != nil { t.Fatal(err) } @@ -849,7 +849,7 @@ func TestBidiInvalidChunkFailsCtxObliviousSession(t *testing.T) { }, ) - conn, err := action.StreamBidiJSON(ctx, nil) + conn, err := action.ConnectJSON(ctx, nil) if err != nil { t.Fatal(err) } @@ -876,7 +876,7 @@ func TestBidiSendAfterCompletionFails(t *testing.T) { }, ) - conn, err := action.StreamBidi(context.Background(), struct{}{}) + conn, err := action.Connect(context.Background(), struct{}{}) if err != nil { t.Fatal(err) } @@ -912,7 +912,7 @@ func TestBidiSessionWrapperPanicNotMislabeled(t *testing.T) { }, ) - conn, err := action.StreamBidi(context.Background(), struct{}{}) + conn, err := action.Connect(context.Background(), struct{}{}) if err != nil { t.Fatal(err) } @@ -1000,7 +1000,7 @@ func TestBidiJSONConnEmptyChunkValidated(t *testing.T) { }, ) - conn, err := action.StreamBidiJSON(ctx, nil) + conn, err := action.ConnectJSON(ctx, nil) if err != nil { t.Fatal(err) } @@ -1080,7 +1080,7 @@ func TestBidiEchoStress(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - conn, err := action.StreamBidi(ctx, struct{}{}) + conn, err := action.Connect(ctx, struct{}{}) if err != nil { t.Error(err) return @@ -1140,9 +1140,9 @@ func TestResolveBidiActionFor(t *testing.T) { if resolved == nil { t.Fatal("ResolveBidiActionFor returned nil") } - got, err := resolved.RunWithInit(ctx, Config{Prefix: ">> "}, "hello", nil) + got, err := resolved.RunBidi(ctx, Config{Prefix: ">> "}, "hello", nil) if err != nil { - t.Fatalf("RunWithInit: %v", err) + t.Fatalf("RunBidi: %v", err) } if got != ">> hello" { t.Errorf("output = %q, want %q", got, ">> hello") diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index b28e0be48e..bbdf15043a 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -776,7 +776,7 @@ func runActionWithOptionalInit(ctx context.Context, a api.Action, input, init js return nil, err } if bidi, ok := a.(api.BidiAction); ok && base.HasJSONValue(init) { - return bidi.RunBidiJSON(ctx, input, cb, &api.BidiSessionOptions{Init: init}) + return bidi.RunBidiJSON(ctx, input, cb, &api.BidiJSONOptions{Init: init}) } return a.RunJSONWithTelemetry(ctx, input, cb) } diff --git a/go/genkit/reflection_v2.go b/go/genkit/reflection_v2.go index 158dc47134..fa2e233495 100644 --- a/go/genkit/reflection_v2.go +++ b/go/genkit/reflection_v2.go @@ -545,7 +545,7 @@ func (s *reflectionServerV2) handleRunActionBidi(ctx context.Context, req *jsonR actionCtx = core.WithActionContext(actionCtx, contextMap) } - conn, err := bidi.StreamBidiJSON(actionCtx, &api.BidiSessionOptions{Init: params.Init}) + conn, err := bidi.ConnectJSON(actionCtx, &api.BidiJSONOptions{Init: params.Init}) if err != nil { s.sendErrorResponse(req.ID, jsonRPCServerError, err.Error(), nil) return From 6a59525cdf9a366c3a281a9dfe56425139abe863 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 18 Jun 2026 21:23:08 -0700 Subject: [PATCH 122/141] refactor(go/exp): rename agent StreamBidi to Connect Follow the core rename (StreamBidi/StreamBidiJSON -> Connect/ConnectJSON, BidiSessionOptions -> BidiJSONOptions) through the Agent surface, its tests, and the basic-agents sample. Agent.RunBidiJSON keeps its name; only its options type changed. --- go/ai/exp/agent.go | 24 +-- go/ai/exp/agent_test.go | 252 ++++++++++++++++---------------- go/ai/exp/custompatch_test.go | 24 +-- go/ai/exp/option.go | 2 +- go/samples/basic-agents/cli.go | 2 +- go/samples/basic-agents/main.go | 2 +- 6 files changed, 153 insertions(+), 153 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 922442b845..7c380438cd 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -411,7 +411,7 @@ type AgentFunc[State any] = func(ctx context.Context, resp Responder, sess *Sess // // Agent implements [api.BidiAction], so generic transports accept it directly // (e.g. pass it to genkit.Handler to serve it over HTTP, one turn per request). -// [Agent.Run], [Agent.RunText], and [Agent.StreamBidi] are typed conveniences +// [Agent.Run], [Agent.RunText], and [Agent.Connect] are typed conveniences // over the same underlying action. // // Server-managed agents (those with a [SessionStore] configured) also @@ -533,15 +533,15 @@ func (a *Agent[State]) RunJSONWithTelemetry(ctx context.Context, input json.RawM // counterpart of the [InvocationOption] values) rides in opts: input is // delivered as the only chunk on the input stream and outgoing chunks are // forwarded to cb. -func (a *Agent[State]) RunBidiJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error, opts *api.BidiSessionOptions) (*api.ActionRunResult[json.RawMessage], error) { +func (a *Agent[State]) RunBidiJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error, opts *api.BidiJSONOptions) (*api.ActionRunResult[json.RawMessage], error) { return a.action.RunBidiJSON(ctx, input, cb, opts) } -// StreamBidiJSON starts a bidirectional streaming session using +// ConnectJSON starts a bidirectional streaming session using // JSON-encoded messages. Local Go callers should prefer the typed -// [Agent.StreamBidi]. -func (a *Agent[State]) StreamBidiJSON(ctx context.Context, opts *api.BidiSessionOptions) (api.BidiJSONConnection, error) { - return a.action.StreamBidiJSON(ctx, opts) +// [Agent.Connect]. +func (a *Agent[State]) ConnectJSON(ctx context.Context, opts *api.BidiJSONOptions) (api.BidiJSONConnection, error) { + return a.action.ConnectJSON(ctx, opts) } // DefineAgent defines a prompt-backed agent and registers it. Each turn @@ -2164,10 +2164,10 @@ func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) Ag // --- Agent client API --- -// StreamBidi starts a new agent invocation with bidirectional streaming. +// Connect starts a new agent invocation with bidirectional streaming. // Use this for multi-turn interactions where you need to send multiple inputs // and receive streaming chunks. For single-turn usage, see Run and RunText. -func (a *Agent[State]) StreamBidi( +func (a *Agent[State]) Connect( ctx context.Context, opts ...InvocationOption[State], ) (*AgentConnection[State], error) { @@ -2175,7 +2175,7 @@ func (a *Agent[State]) StreamBidi( if err != nil { return nil, err } - conn, err := a.action.StreamBidi(ctx, init) + conn, err := a.action.Connect(ctx, init) if err != nil { return nil, err } @@ -2184,7 +2184,7 @@ func (a *Agent[State]) StreamBidi( // Run starts a single-turn agent invocation with the given input. // It sends the input, waits for the agent to complete, and returns the output. -// For multi-turn interactions or streaming, use StreamBidi instead. +// For multi-turn interactions or streaming, use Connect instead. // // In-band failures (e.g. a failed turn) resolve as a failed [AgentOutput] // rather than an error; a rejected init payload fails with an error, since @@ -2194,7 +2194,7 @@ func (a *Agent[State]) Run( input *AgentInput, opts ...InvocationOption[State], ) (*AgentOutput[State], error) { - conn, err := a.StreamBidi(ctx, opts...) + conn, err := a.Connect(ctx, opts...) if err != nil { return nil, err } @@ -2324,7 +2324,7 @@ func (c *AgentConnection[State]) Close() error { // of the iterator does not cancel the connection; multi-turn callers // routinely break on [TurnEnd], send the next input, then call Receive // again to consume the next batch. Call [AgentConnection.Output] to -// finish the invocation, or cancel the ctx passed to StreamBidi to +// finish the invocation, or cancel the ctx passed to Connect to // abort it. // // Each yielded chunk's [AgentStreamChunk.CustomPatch] is applied to the diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index b5d8108b5a..3bfebe218d 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -99,9 +99,9 @@ func TestAgent_BasicMultiTurn(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } // Turn 1. @@ -168,9 +168,9 @@ func TestAgentConnection_Custom_TracksStreamedPatches(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } // Before any patch arrives, Custom() is the zero value. @@ -217,9 +217,9 @@ func TestAgent_WithSessionStore(t *testing.T) { af := defineCounterAgent(reg, "snapshotFlow", WithSessionStore(store)) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } sendText(t, conn, "turn1") @@ -266,9 +266,9 @@ func TestAgent_ResumeFromSnapshot(t *testing.T) { af := defineCounterAgent(reg, "resumeFlow", WithSessionStore(store)) // First invocation: create a snapshot. - conn1, err := af.StreamBidi(ctx) + conn1, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } sendText(t, conn1, "first message") for chunk, err := range conn1.Receive() { @@ -289,9 +289,9 @@ func TestAgent_ResumeFromSnapshot(t *testing.T) { } // Second invocation: resume from snapshot. - conn2, err := af.StreamBidi(ctx, WithSnapshotID[testState](resp1.SnapshotID)) + conn2, err := af.Connect(ctx, WithSnapshotID[testState](resp1.SnapshotID)) if err != nil { - t.Fatalf("StreamBidi with snapshot failed: %v", err) + t.Fatalf("Connect with snapshot failed: %v", err) } sendTurn(t, conn2, "continued message") conn2.Close() @@ -341,9 +341,9 @@ func TestAgent_ClientManagedState(t *testing.T) { Custom: testState{Counter: 5}, } - conn, err := af.StreamBidi(ctx, WithState(clientState)) + conn, err := af.Connect(ctx, WithState(clientState)) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } sendTurn(t, conn, "new message") @@ -403,9 +403,9 @@ func TestAgent_Artifacts(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } sendText(t, conn, "generate code") @@ -512,9 +512,9 @@ func TestAgent_InputMessageCloned(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } sent := ai.NewUserTextMessage("original") @@ -638,9 +638,9 @@ func TestAgent_SendMessage(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } // Send a message via SendMessage. @@ -686,9 +686,9 @@ func TestAgent_SessionContext(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } sendTurn(t, conn, "test") @@ -712,9 +712,9 @@ func TestAgent_ErrorInTurn(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } sendText(t, conn, "trigger error") @@ -802,9 +802,9 @@ func TestAgent_FailedTurn_ClientManagedReturnsLastGoodState(t *testing.T) { af := defineLastGoodTestAgent(reg, "lastGoodClient") - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } for _, text := range []string{"one", "two", "boom"} { sendText(t, conn, text) @@ -899,9 +899,9 @@ func TestAgent_FailedTurn_ServerManagedReturnsLastTurnSnapshot(t *testing.T) { af := defineLastGoodTestAgent(reg, "recoveryDedup", WithSessionStore[testState](store)) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } turn0 := sendTurn(t, conn, "one") if turn0.SnapshotID == "" { @@ -989,9 +989,9 @@ func TestAgent_FailedTurn_EmitsFailedTurnEnd(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } sendText(t, conn, "hi") @@ -1040,9 +1040,9 @@ func TestAgent_CustomAgentContinuesAfterFailedTurn(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } for _, text := range []string{"one", "boom", "two"} { sendText(t, conn, text) @@ -1150,9 +1150,9 @@ func TestAgent_SetMessages(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } sendTurn(t, conn, "original") @@ -1203,9 +1203,9 @@ func TestAgent_TurnSpanOutput(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } // Two turns. @@ -1268,9 +1268,9 @@ func TestAgent_TurnSpanOutput_WithSnapshots(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } sendText(t, conn, "hello") @@ -1351,9 +1351,9 @@ func TestPromptAgent_Basic(t *testing.T) { af := DefineAgent[testState](reg, "testPrompt", FromPrompt()) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } // Turn 1. @@ -1426,9 +1426,9 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { af := DefineAgent[testState](reg, "historyPrompt", FromPrompt()) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } // Turn 1. @@ -1502,9 +1502,9 @@ func TestPromptAgent_SnapshotResumePreservesHistory(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } sendTurn(t, conn, "hello") @@ -1518,9 +1518,9 @@ func TestPromptAgent_SnapshotResumePreservesHistory(t *testing.T) { t.Fatal("expected snapshot ID") } - conn2, err := af.StreamBidi(ctx, WithSnapshotID[testState](resp.SnapshotID)) + conn2, err := af.Connect(ctx, WithSnapshotID[testState](resp.SnapshotID)) if err != nil { - t.Fatalf("StreamBidi with snapshot failed: %v", err) + t.Fatalf("Connect with snapshot failed: %v", err) } sendTurn(t, conn2, "continued") @@ -1625,9 +1625,9 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { af := DefineAgent[testState](reg, "toolPrompt", FromPrompt()) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } sendTurn(t, conn, "go") @@ -2014,9 +2014,9 @@ func TestPromptAgent_RejectsResumeForUnrequestedTool(t *testing.T) { af := DefineAgent[testState](reg, "plainPrompt", FromPrompt()) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } // Turn 1: a plain text reply, so no tool request lands in history. @@ -2092,9 +2092,9 @@ func TestAgent_MultiTurnSnapshot(t *testing.T) { af := defineCounterAgent(reg, "multiDedupFlow", WithSessionStore(store)) // Multi-turn: one snapshot per turn; the output reuses the last one. - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } var snapshotIDs []string @@ -2199,9 +2199,9 @@ func TestAgent_FnPanicResolvesAsFailedOutput(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } sendText(t, conn, "trigger") @@ -2252,9 +2252,9 @@ func TestAgent_CancelDuringStreamReleasesGoroutine(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } sendText(t, conn, "go") @@ -2354,9 +2354,9 @@ func TestAgent_TurnEnd_CarriesSnapshotID(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } var observed []TurnEnd @@ -2416,9 +2416,9 @@ func TestAgent_Detach_SuspendsTurnSnapshotsAndProcessesQueue(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } // Drain stream chunks in the background. @@ -2501,9 +2501,9 @@ func TestAgent_Detach_AfterPriorTurns_ChainsParent(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } // Background drainer. @@ -2563,9 +2563,9 @@ func TestAgent_Detach_RequiresStore(t *testing.T) { }, ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } if err := conn.Detach(); err != nil { t.Fatalf("Detach send: %v", err) @@ -2615,9 +2615,9 @@ func TestAgent_Detach_PendingThenComplete(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } // Drain chunks so the responder isn't blocked. @@ -2752,9 +2752,9 @@ func TestAgent_Detach_StampsHeartbeatOnPendingSnapshot(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } drainInBackground(conn) @@ -2957,9 +2957,9 @@ func TestAgent_Detach_SendArtifactPostDetachLandsInSnapshot(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } drainInBackground(conn) @@ -3017,9 +3017,9 @@ func TestAgent_Detach_FlowErrorsBecomesError(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } drainInBackground(conn) @@ -3080,9 +3080,9 @@ func TestAgent_Detach_AbortSnapshotStopsFlow(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } drainInBackground(conn) @@ -3139,9 +3139,9 @@ func TestAgent_Detach_NormalCompletionStillEmitsTurnEnd(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } sendText(t, conn, "hi") @@ -3191,9 +3191,9 @@ func TestAgent_Detach_ClientDisconnectBeforeDetachCancels(t *testing.T) { ) ctx, cancel := context.WithCancel(context.Background()) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } drainInBackground(conn) @@ -4145,9 +4145,9 @@ func TestAgent_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { // First invocation: detach to write a pending snapshot, then wait // for finalize. - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } drainInBackground(conn) sendText(t, conn, "turn 1") @@ -4284,9 +4284,9 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } drainInBackground(conn) @@ -4544,9 +4544,9 @@ func TestAgent_FinishReason_TurnAndInvocation(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } sendText(t, conn, "hi") @@ -4591,9 +4591,9 @@ func TestAgent_FinishReason_OmittedWhenNil(t *testing.T) { }, ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } turnEnd := sendTurn(t, conn, "hi") if turnEnd.FinishReason != "" { @@ -4629,9 +4629,9 @@ func TestAgent_FinishReason_InvocationOverride(t *testing.T) { }, ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } turnEnd := sendTurn(t, conn, "hi") if turnEnd.FinishReason != AgentFinishReasonStop { @@ -4668,9 +4668,9 @@ func TestAgent_FinishReason_MultiTurnDistinct(t *testing.T) { }, ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } var got []AgentFinishReason @@ -4713,9 +4713,9 @@ func TestPromptAgent_ForwardsFinishReason(t *testing.T) { af := DefineAgent[testState](reg, "lengthPrompt", FromPrompt()) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } turnEnd := sendTurn(t, conn, "hi") if turnEnd.FinishReason != AgentFinishReasonLength { @@ -4764,9 +4764,9 @@ func TestAgent_Detach_BackgroundWorkSurvivesActionReturn(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("iteration %d: StreamBidi: %v", i, err) + t.Fatalf("iteration %d: Connect: %v", i, err) } drainInBackground(conn) if err := conn.SendText("go"); err != nil { @@ -4821,9 +4821,9 @@ func TestAgent_Detach_FinishReasons(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } drainInBackground(conn) sendText(t, conn, "go") @@ -4869,9 +4869,9 @@ func TestAgent_Detach_FinishReasons(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } drainInBackground(conn) sendText(t, conn, "go") @@ -4916,9 +4916,9 @@ func TestAgent_Detach_FinishReasons(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } drainInBackground(conn) sendText(t, conn, "go") @@ -5016,9 +5016,9 @@ func TestAgent_FinishReason_MultiTurnDistinct_Persisted(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } var ids []string for i := 0; i < len(reasons); i++ { @@ -5064,9 +5064,9 @@ func TestAgent_FinishReason_OmittedPersisted(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } sendText(t, conn, "hi") snapID := nextTurnEnd(t, conn).SnapshotID @@ -5116,9 +5116,9 @@ func TestPromptAgent_ForwardsInterruptedFinishReason(t *testing.T) { af := DefineAgent[testState](reg, "interruptPrompt", FromPrompt()) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } sendText(t, conn, "do it") var ( @@ -5184,9 +5184,9 @@ func TestAgent_Detach_CompletedHonorsResultOverride(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } drainInBackground(conn) sendText(t, conn, "go") @@ -5222,9 +5222,9 @@ func TestAgent_SessionID_AssignedAndStable(t *testing.T) { store := newTestInMemStore[testState]() af := defineLastGoodTestAgent(reg, "sessionAssignFlow", WithSessionStore(store)) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } var snapshotIDs []string @@ -5555,9 +5555,9 @@ func TestAgent_ResumeFromSessionID_AfterFailureResumesLastTurn(t *testing.T) { store := newTestInMemStore[testState]() af := defineLastGoodTestAgent(reg, "sessionRecoveryFlow", WithSessionStore[testState](store)) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } for _, text := range []string{"one", "two", "boom"} { sendText(t, conn, text) @@ -5800,9 +5800,9 @@ func TestAgent_Detach_AssignsSessionID(t *testing.T) { WithSessionStore(store), ) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } drainInBackground(conn) sendText(t, conn, "go") @@ -5873,9 +5873,9 @@ func TestAgent_Detach_WaitsForInFlightTurnSnapshot(t *testing.T) { } af := defineLastGoodTestAgent(reg, "detachMidWrite", WithSessionStore[testState](store)) - conn, err := af.StreamBidi(context.Background()) + conn, err := af.Connect(context.Background()) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } drainInBackground(conn) sendText(t, conn, "one") @@ -5938,9 +5938,9 @@ func TestAgent_FailedTurn_OutputCarriesSessionID(t *testing.T) { store := newTestInMemStore[testState]() af := defineLastGoodTestAgent(reg, "failedSessionFlow", WithSessionStore(store)) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } sendTurn(t, conn, "turn one") if err := conn.SendText("boom"); err != nil && !errors.Is(err, core.ErrActionCompleted) { @@ -6018,24 +6018,24 @@ func TestAgent_WithSessionID_OptionValidation(t *testing.T) { store := newTestInMemStore[testState]() af := defineLastGoodTestAgent(reg, "sessionOptFlow", WithSessionStore(store)) - if _, err := af.StreamBidi(ctx, WithState(&SessionState[testState]{}), WithSnapshotID[testState]("x")); err == nil || + if _, err := af.Connect(ctx, WithState(&SessionState[testState]{}), WithSnapshotID[testState]("x")); err == nil || !strings.Contains(err.Error(), "mutually exclusive") { t.Errorf("WithState+WithSnapshotID: expected mutual-exclusion error, got %v", err) } - if _, err := af.StreamBidi(ctx, WithSessionID[testState]("s"), WithSessionID[testState]("s2")); err == nil || + if _, err := af.Connect(ctx, WithSessionID[testState]("s"), WithSessionID[testState]("s2")); err == nil || !strings.Contains(err.Error(), "more than once") { t.Errorf("WithSessionID twice: expected duplicate-option error, got %v", err) } // An empty session ID is an explicit error, not a silent no-op: a // pipelined AgentOutput.SessionID from a storeless invocation must not // quietly start a fresh conversation. - if _, err := af.StreamBidi(ctx, WithSessionID[testState]("")); err == nil || + if _, err := af.Connect(ctx, WithSessionID[testState]("")); err == nil || !strings.Contains(err.Error(), "session ID is empty") { t.Errorf("WithSessionID(\"\"): expected empty-ID error, got %v", err) } // WithSessionID composes with WithSnapshotID: the option layer accepts // the pair and the init-level checks (here: unknown snapshot) decide. - conn, err := af.StreamBidi(ctx, WithSessionID[testState]("s"), WithSnapshotID[testState]("x")) + conn, err := af.Connect(ctx, WithSessionID[testState]("s"), WithSnapshotID[testState]("x")) if err != nil { t.Fatalf("WithSessionID+WithSnapshotID: expected option layer to accept, got %v", err) } @@ -6153,9 +6153,9 @@ func TestAgent_SendNilInput_Rejected(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } if err := conn.Send(nil); err == nil { @@ -6251,9 +6251,9 @@ func TestAgent_ClientCancelMidStream(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } sendText(t, conn, "hello") // Close the input side so sess.Run ends cleanly and fn returns @@ -6312,9 +6312,9 @@ func TestAgent_OutputUnblocksOnCancel(t *testing.T) { ) ctx, cancel := context.WithCancel(context.Background()) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi failed: %v", err) + t.Fatalf("Connect failed: %v", err) } cancel() diff --git a/go/ai/exp/custompatch_test.go b/go/ai/exp/custompatch_test.go index 0a35600cd6..ced24c9f94 100644 --- a/go/ai/exp/custompatch_test.go +++ b/go/ai/exp/custompatch_test.go @@ -63,9 +63,9 @@ func TestCustomPatch_PerTurnRebaseAndIncremental(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } defer conn.Output() @@ -117,9 +117,9 @@ func TestCustomPatch_ClientTracksLiveCustom(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } // Before any patch, Custom is the zero value. @@ -179,9 +179,9 @@ func TestCustomPatch_HonorsStateTransform(t *testing.T) { }), ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } defer conn.Output() @@ -227,9 +227,9 @@ func TestCustomPatch_ConcurrentMutations(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } conn.SendText("go") @@ -268,9 +268,9 @@ func TestCustomPatch_NoMutationNoPatch(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } defer conn.Output() @@ -297,9 +297,9 @@ func TestCustomPatch_EmptyDiffEmitsNothing(t *testing.T) { }, ) - conn, err := af.StreamBidi(ctx) + conn, err := af.Connect(ctx) if err != nil { - t.Fatalf("StreamBidi: %v", err) + t.Fatalf("Connect: %v", err) } defer conn.Output() diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index e620281622..1ed5264129 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -93,7 +93,7 @@ func WithDescription[State any](description string) AgentOption[State] { // --- InvocationOption --- -// InvocationOption configures an agent invocation (StreamBidi, Run, or RunText). +// InvocationOption configures an agent invocation (Connect, Run, or RunText). type InvocationOption[State any] interface { applyInvocation(*invocationOptions[State]) error } diff --git a/go/samples/basic-agents/cli.go b/go/samples/basic-agents/cli.go index 051e1ed437..c675a02ac9 100644 --- a/go/samples/basic-agents/cli.go +++ b/go/samples/basic-agents/cli.go @@ -311,7 +311,7 @@ func runChat(ctx context.Context, inputCh <-chan string, a *aix.Agent[any], resu if resume != nil { opts = append(opts, resumeOption(ctx, a, resume)) } - conn, err := a.StreamBidi(ctx, opts...) + conn, err := a.Connect(ctx, opts...) if err != nil { return prevSessionID, fmt.Errorf("open agent %q: %w", a.Name(), err) } diff --git a/go/samples/basic-agents/main.go b/go/samples/basic-agents/main.go index e38e505d4d..d3ffad4db2 100644 --- a/go/samples/basic-agents/main.go +++ b/go/samples/basic-agents/main.go @@ -81,7 +81,7 @@ func main() { // Each define function registers an agent and returns it. The CLI // drives all three through the same surface: a.Name() and - // a.Desc().Description for the list view, a.StreamBidi(...) to chat, + // a.Desc().Description for the list view, a.Connect(...) to chat, // and a.Store() for snapshot reads. Nothing the CLI does is tied to a // concrete store type, so swapping in a different SessionStore would // not touch a line of it. From 79608f436007d41ece3de9a18c9891e79954b115 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Sat, 20 Jun 2026 13:39:42 -0700 Subject: [PATCH 123/141] feat(go/exp): add agent snapshot facade for transform-applied reads Route snapshot reads and aborts through the Agent rather than the raw store: Agent.GetSnapshot, Agent.GetLatestSnapshot, and Agent.AbortSnapshot. The read methods apply the configured WithStateTransform, fixing a silent gap where a direct store.GetSnapshot returned raw, unredacted state. A shared readSnapshot helper backs both the methods and the getSnapshot companion action, so Go callers, the Dev UI, and non-Go clients see identical shaping. The store interface stays exported and minimal (third-party stores implement it); abort stays an unexported SaveSnapshot flip reached via the agent method and the abortSnapshot companion action. --- go/ai/exp/agent.go | 85 ++++++++++++++++--- go/ai/exp/agent_test.go | 175 +++++++++++++++++++++++++++++++++++++--- go/ai/exp/session.go | 147 ++++++++++++++++++--------------- 3 files changed, 322 insertions(+), 85 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 7c380438cd..0ecf36905b 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -429,6 +429,11 @@ type Agent[State any] struct { // agent. Retained so callers can reach it via Store without threading // a separate reference. store SessionStore[State] + // transform shapes session state on the way out to a client; see + // [WithStateTransform]. Retained so the typed read facade ([Agent.GetSnapshot], + // [Agent.GetLatestSnapshot]) applies it, matching the getSnapshot companion + // action. Nil when none was configured. + transform StateTransform[State] } // Name returns the agent's registered name. This is also the name under @@ -445,8 +450,8 @@ func (a *Agent[State]) Name() string { // snapshot to fetch. // // Use it to expose snapshot polling over a transport (e.g. mount it with -// genkit.Handler next to the agent itself); local Go code should read -// from the store directly. +// genkit.Handler next to the agent itself); local Go code should use +// [Agent.GetSnapshot], which applies the configured state transform. func (a *Agent[State]) GetSnapshotAction() api.Action { return a.getSnapshot } @@ -458,16 +463,20 @@ func (a *Agent[State]) GetSnapshotAction() api.Action { // [SessionStore] or the store does not implement [SnapshotSubscriber]. // // Use it to expose aborting over a transport (e.g. mount it with -// genkit.Handler next to the agent itself); local Go code aborts by writing -// the aborted status through the store's [SnapshotWriter.SaveSnapshot]. +// genkit.Handler next to the agent itself); local Go code aborts with +// [Agent.AbortSnapshot]; a store-only caller uses this companion action. func (a *Agent[State]) AbortSnapshotAction() api.Action { return a.abortSnapshot } // Store returns the [SessionStore] the agent was configured with via // [WithSessionStore], or nil when the agent is client-managed (no store). -// It lets local Go code read and write snapshots directly given an agent -// reference, without threading a separate store variable. +// +// For reads and aborts prefer the typed facade [Agent.GetSnapshot], +// [Agent.GetLatestSnapshot], and [Agent.AbortSnapshot]: they apply the +// configured [WithStateTransform] and read-time shaping. Store exposes the raw +// backend for advanced use; a direct [SnapshotReader.GetSnapshot] returns +// untransformed state. // // The store is returned as the [SessionStore] interface, not its concrete // type; a caller needing a store-specific capability (e.g. @@ -476,6 +485,61 @@ func (a *Agent[State]) Store() SessionStore[State] { return a.store } +// GetSnapshot fetches a session snapshot by ID through the agent, applying the +// configured [WithStateTransform] and the same read-time shaping the getSnapshot +// companion action performs (a stale-heartbeat pending row is surfaced as +// [SnapshotStatusExpired]; an empty status or zero UpdatedAt is defaulted). +// Prefer it to reading [Agent.Store] directly, which returns raw, untransformed +// state. +// +// It returns FAILED_PRECONDITION on a client-managed agent (no store) and +// INVALID_ARGUMENT when snapshotID is empty; a missing snapshot is NOT_FOUND. +func (a *Agent[State]) GetSnapshot(ctx context.Context, snapshotID string) (*SessionSnapshot[State], error) { + if a.store == nil { + return nil, core.NewError(core.FAILED_PRECONDITION, "agent %q: GetSnapshot requires a session store", a.Name()) + } + if snapshotID == "" { + return nil, core.NewError(core.INVALID_ARGUMENT, "agent %q: GetSnapshot: snapshotID is required", a.Name()) + } + return readSnapshot(ctx, a.store, a.transform, snapshotID, "") +} + +// GetLatestSnapshot fetches a session's most recently created snapshot (whatever +// its status) through the agent, with the same transform and shaping as +// [Agent.GetSnapshot]. It is the transform-applying counterpart to +// [SnapshotReader.GetLatestSnapshot] and backs resume-by-session lookups. +// +// It returns FAILED_PRECONDITION on a client-managed agent and INVALID_ARGUMENT +// when sessionID is empty; an unknown session is NOT_FOUND. +func (a *Agent[State]) GetLatestSnapshot(ctx context.Context, sessionID string) (*SessionSnapshot[State], error) { + if a.store == nil { + return nil, core.NewError(core.FAILED_PRECONDITION, "agent %q: GetLatestSnapshot requires a session store", a.Name()) + } + if sessionID == "" { + return nil, core.NewError(core.INVALID_ARGUMENT, "agent %q: GetLatestSnapshot: sessionID is required", a.Name()) + } + return readSnapshot(ctx, a.store, a.transform, "", sessionID) +} + +// AbortSnapshot aborts the detached invocation behind a pending snapshot by +// flipping it to [SnapshotStatusAborted]; the runtime observes the flip and +// cancels the background work. A caller that has only a store (no agent) aborts +// through the abortSnapshot companion action instead. It is a no-op on a missing +// snapshot (returns +// "") or an already-terminal one (returns the existing status). +// +// It returns FAILED_PRECONDITION on a client-managed agent and INVALID_ARGUMENT +// when snapshotID is empty. +func (a *Agent[State]) AbortSnapshot(ctx context.Context, snapshotID string) (SnapshotStatus, error) { + if a.store == nil { + return "", core.NewError(core.FAILED_PRECONDITION, "agent %q: AbortSnapshot requires a session store", a.Name()) + } + if snapshotID == "" { + return "", core.NewError(core.INVALID_ARGUMENT, "agent %q: AbortSnapshot: snapshotID is required", a.Name()) + } + return abortPendingSnapshot(ctx, a.store, snapshotID) +} + // --- api.BidiAction implementation --- // Agent is itself an [api.BidiAction]: transports that accept an @@ -658,6 +722,7 @@ func NewCustomAgent[State any]( getSnapshot: getSnapshot, abortSnapshot: abortSnapshot, store: cfg.store, + transform: cfg.transform, } } @@ -1214,13 +1279,15 @@ func beatHeartbeat[State any](ctx context.Context, store SnapshotWriter[State], return err } -// abortSnapshot flips a pending snapshot to aborted via an ordinary +// abortPendingSnapshot flips a pending snapshot to aborted via an ordinary // SaveSnapshot and returns the resulting status: aborted when the row was // pending, the existing terminal status when it was already settled (a no-op // verbatim rewrite), or "" when the snapshot does not exist. SaveSnapshot's // atomic read-mutate-write makes the flip safe against a racing terminal write, -// and the status change drives any [SnapshotSubscriber.OnSnapshotStatusChange] -// subscription, so the store needs no dedicated abort method. +// and the status change drives the runtime's +// [SnapshotSubscriber.OnSnapshotStatusChange] subscription, so the store needs +// no dedicated abort method. It backs [Agent.AbortSnapshot] and the abortSnapshot +// companion action. func abortPendingSnapshot[State any](ctx context.Context, store SnapshotWriter[State], snapshotID string) (SnapshotStatus, error) { now := time.Now() saved, err := store.SaveSnapshot(ctx, snapshotID, diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 3bfebe218d..adff34a3b8 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -3058,7 +3058,7 @@ func TestAgent_Detach_FlowErrorsBecomesError(t *testing.T) { } func TestAgent_Detach_AbortSnapshotStopsFlow(t *testing.T) { - // Client detaches, then calls AbortSnapshot. The store's status + // Client detaches, then calls abortPendingSnapshot. The store's status // subscriber notifies the runtime, which cancels the work context, and // the finalizer rewrites the snapshot with status=aborted. reg := newTestRegistry(t) @@ -3104,10 +3104,10 @@ func TestAgent_Detach_AbortSnapshotStopsFlow(t *testing.T) { // reference from WithSessionStore. status, err := abortPendingSnapshot(context.Background(), store, out.SnapshotID) if err != nil { - t.Fatalf("AbortSnapshot: %v", err) + t.Fatalf("abortPendingSnapshot: %v", err) } if status != SnapshotStatusAborted { - t.Errorf("AbortSnapshot status = %q, want aborted", status) + t.Errorf("abortPendingSnapshot status = %q, want aborted", status) } // The subscriber wakes the runtime, cancels work, and the finalizer @@ -3116,7 +3116,7 @@ func TestAgent_Detach_AbortSnapshotStopsFlow(t *testing.T) { return s.Status == SnapshotStatusAborted && s.UpdatedAt.After(s.CreatedAt) }) // The flow only blocked on ctx — no state mutation expected. State - // may be nil (when AbortSnapshot landed before the finalizer's write + // may be nil (when abortPendingSnapshot landed before the finalizer's write // could populate it) or a populated zero-value struct. if finalSnap.State != nil && finalSnap.State.Custom.Counter != 0 { t.Errorf("unexpected counter value in aborted snapshot: %d", finalSnap.State.Custom.Counter) @@ -3339,6 +3339,161 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { } } +// assertGenkitStatus fails the test unless err is a *core.GenkitError with the +// given status. +func assertGenkitStatus(t *testing.T, err error, want core.StatusName, label string) { + t.Helper() + var ge *core.GenkitError + if !errors.As(err, &ge) { + t.Errorf("%s: expected *core.GenkitError, got %v", label, err) + return + } + if ge.Status != want { + t.Errorf("%s: status = %q, want %q", label, ge.Status, want) + } +} + +// TestAgent_GetSnapshot_FacadeTransformsRawStoreDoesNot is the crux of the +// agent-as-facade design: reads through the agent apply the configured +// [WithStateTransform], while a direct store read returns raw state. A caller +// that reaches past the agent into the store therefore sees unredacted data, +// which is exactly why GetSnapshot/GetLatestSnapshot live on the agent. +func TestAgent_GetSnapshot_FacadeTransformsRawStoreDoesNot(t *testing.T) { + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + transform := func(_ context.Context, s *SessionState[testState]) *SessionState[testState] { + for _, msg := range s.Messages { + for _, p := range msg.Content { + if p.Text != "" { + p.Text = strings.ReplaceAll(p.Text, "secret", "[REDACTED]") + } + } + } + return s + } + + af := DefineCustomAgent(reg, "facadeTransform", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("the secret is out")) + return nil, nil + }) + }, + WithSessionStore(store), + WithStateTransform[testState](transform), + ) + + ctx := context.Background() + out, err := af.RunText(ctx, "tell me the secret") + if err != nil { + t.Fatalf("RunText: %v", err) + } + + hasSecret := func(snap *SessionSnapshot[testState]) bool { + for _, msg := range snap.State.Messages { + for _, p := range msg.Content { + if strings.Contains(p.Text, "secret") { + return true + } + } + } + return false + } + + // Through the agent: transformed (redacted). + viaAgent, err := af.GetSnapshot(ctx, out.SnapshotID) + if err != nil { + t.Fatalf("agent.GetSnapshot: %v", err) + } + if hasSecret(viaAgent) { + t.Error("agent.GetSnapshot returned untransformed state (found 'secret')") + } + + // GetLatestSnapshot resolves the same row by session and also transforms. + viaLatest, err := af.GetLatestSnapshot(ctx, out.SessionID) + if err != nil { + t.Fatalf("agent.GetLatestSnapshot: %v", err) + } + if viaLatest.SnapshotID != out.SnapshotID { + t.Errorf("GetLatestSnapshot id = %q, want %q", viaLatest.SnapshotID, out.SnapshotID) + } + if hasSecret(viaLatest) { + t.Error("agent.GetLatestSnapshot returned untransformed state (found 'secret')") + } + + // Directly via the store: raw, untransformed. This contrast is the point. + raw, err := store.GetSnapshot(ctx, out.SnapshotID) + if err != nil { + t.Fatalf("store.GetSnapshot: %v", err) + } + if !hasSecret(raw) { + t.Error("raw store.GetSnapshot should retain the original 'secret' text") + } +} + +// TestAgent_SnapshotFacade_Errors covers the guard rails on the agent facade: +// a client-managed agent has no store, and empty IDs / missing rows are +// rejected before touching the store. +func TestAgent_SnapshotFacade_Errors(t *testing.T) { + reg := newTestRegistry(t) + ctx := context.Background() + + // Client-managed agent (no store): every facade method is FAILED_PRECONDITION. + client := defineCounterAgent(reg, "facadeClient") + _, e1 := client.GetSnapshot(ctx, "x") + assertGenkitStatus(t, e1, core.FAILED_PRECONDITION, "client GetSnapshot") + _, e2 := client.GetLatestSnapshot(ctx, "x") + assertGenkitStatus(t, e2, core.FAILED_PRECONDITION, "client GetLatestSnapshot") + _, e3 := client.AbortSnapshot(ctx, "x") + assertGenkitStatus(t, e3, core.FAILED_PRECONDITION, "client abortPendingSnapshot") + + // Server-managed agent: empty IDs are INVALID_ARGUMENT, missing rows NOT_FOUND. + store := newTestInMemStore[testState]() + server := defineCounterAgent(reg, "facadeServer", WithSessionStore(store)) + _, e4 := server.GetSnapshot(ctx, "") + assertGenkitStatus(t, e4, core.INVALID_ARGUMENT, "empty GetSnapshot") + _, e5 := server.GetLatestSnapshot(ctx, "") + assertGenkitStatus(t, e5, core.INVALID_ARGUMENT, "empty GetLatestSnapshot") + _, e6 := server.AbortSnapshot(ctx, "") + assertGenkitStatus(t, e6, core.INVALID_ARGUMENT, "empty abortPendingSnapshot") + _, e7 := server.GetSnapshot(ctx, "missing") + assertGenkitStatus(t, e7, core.NOT_FOUND, "missing GetSnapshot") +} + +// TestAgent_AbortSnapshot_Method verifies the in-process convenience flips a +// pending row to aborted through the store, mirroring the package-level +// [abortPendingSnapshot] and the abortSnapshot companion action. +func TestAgent_AbortSnapshot_Method(t *testing.T) { + reg := newTestRegistry(t) + ctx := context.Background() + store := newTestInMemStore[testState]() + af := defineCounterAgent(reg, "facadeAbort", WithSessionStore(store)) + + // Seed a pending row the way a detached turn would. + pending, err := store.SaveSnapshot(ctx, "", func(*SessionSnapshot[testState]) (*SessionSnapshot[testState], error) { + return &SessionSnapshot[testState]{SessionID: "s1", Status: SnapshotStatusPending}, nil + }) + if err != nil { + t.Fatalf("seed pending: %v", err) + } + + status, err := af.AbortSnapshot(ctx, pending.SnapshotID) + if err != nil { + t.Fatalf("agent.AbortSnapshot: %v", err) + } + if status != SnapshotStatusAborted { + t.Errorf("returned status = %q, want aborted", status) + } + got, err := store.GetSnapshot(ctx, pending.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if got.Status != SnapshotStatusAborted { + t.Errorf("stored status = %q, want aborted", got.Status) + } +} + // TestAgent_GetSnapshotAction_ReturnsFinishReason verifies the remote // getSnapshot companion action surfaces the persisted finish reason, so a // non-Go client or the Dev UI polling a detached/background invocation can @@ -4303,7 +4458,7 @@ func TestAgent_Detach_FinalizeRespectsConcurrentAbort(t *testing.T) { // Externally abort before releasing fn. if _, err := abortPendingSnapshot(context.Background(), store, out.SnapshotID); err != nil { - t.Fatalf("AbortSnapshot: %v", err) + t.Fatalf("abortPendingSnapshot: %v", err) } close(fnRelease) @@ -4358,7 +4513,7 @@ func TestInMemorySessionStore_OnSnapshotStatusChange(t *testing.T) { // Abort flips status; subscriber observes aborted. if _, err := abortPendingSnapshot(ctx, store, "snap-sub"); err != nil { - t.Fatalf("AbortSnapshot: %v", err) + t.Fatalf("abortPendingSnapshot: %v", err) } select { case status, ok := <-statusCh: @@ -4385,7 +4540,7 @@ func TestInMemorySessionStore_OnSnapshotStatusChange(t *testing.T) { } func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { - // Calling AbortSnapshot on an already-terminal snapshot is a no-op + // Calling abortPendingSnapshot on an already-terminal snapshot is a no-op // that returns the existing status. reg := newTestRegistry(t) store := newTestInMemStore[testState]() @@ -4408,7 +4563,7 @@ func TestAgent_AbortSnapshot_NoOpOnTerminal(t *testing.T) { status, err := abortPendingSnapshot(ctx, store, out.SnapshotID) if err != nil { - t.Fatalf("AbortSnapshot: %v", err) + t.Fatalf("abortPendingSnapshot: %v", err) } if status != SnapshotStatusCompleted { t.Errorf("expected status=%q (existing terminal), got %q", SnapshotStatusCompleted, status) @@ -4932,9 +5087,9 @@ func TestAgent_Detach_FinishReasons(t *testing.T) { t.Fatalf("Output: %v", err) } if _, err := abortPendingSnapshot(context.Background(), store, out.SnapshotID); err != nil { - t.Fatalf("AbortSnapshot: %v", err) + t.Fatalf("abortPendingSnapshot: %v", err) } - // AbortSnapshot flips status=aborted (finishReason still empty); the + // abortPendingSnapshot flips status=aborted (finishReason still empty); the // finalizer then annotates the row with finishReason=aborted. Wait // for that second write rather than the bare status flip. snap := waitForSnapshot(t, store, out.SnapshotID, 2*time.Second, func(s *SessionSnapshot[testState]) bool { diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index b2311e3823..4f8760eccd 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -194,6 +194,86 @@ func cloneArtifacts(arts []*Artifact) []*Artifact { // The [Agent] retains the returned actions (an absent one is nil) and // registers them alongside its run action; see [Agent.Register], // [Agent.GetSnapshotAction], and [Agent.AbortSnapshotAction]. +// readSnapshot resolves a snapshot by ID, or by the session's latest when +// snapshotID is empty, and returns a normalized copy shaped for a client: +// the documented defaults are applied (empty status means completed, zero +// UpdatedAt means CreatedAt), a pending row whose heartbeat has gone stale is +// surfaced as [SnapshotStatusExpired] (computed on read, never written back), +// and transform shapes the outbound state. It backs both the getSnapshot +// companion action and the typed [Agent.GetSnapshot] / [Agent.GetLatestSnapshot], +// so Go callers, the Dev UI, and non-Go clients all observe identically shaped +// snapshots. At least one of snapshotID / sessionID must be non-empty; callers +// validate that before calling. +func readSnapshot[State any]( + ctx context.Context, + store SnapshotReader[State], + transform StateTransform[State], + snapshotID, sessionID string, +) (*SessionSnapshot[State], error) { + // Resolve the snapshot. A snapshot ID fetches that exact row; a session ID + // alone fetches the session's latest row (whatever its status). When both + // are present the snapshot ID picks the row and the session ID asserts it + // belongs to that session, mirroring AgentInit's combined-ID check. + var ( + snap *SessionSnapshot[State] + err error + ) + if snapshotID != "" { + snap, err = store.GetSnapshot(ctx, snapshotID) + if err != nil { + return nil, core.NewError(core.INTERNAL, "getSnapshot: %v", err) + } + if snap == nil { + return nil, core.NewError(core.NOT_FOUND, "getSnapshot: snapshot %q not found", snapshotID) + } + if sessionID != "" && snap.SessionID != sessionID { + return nil, core.NewError(core.INVALID_ARGUMENT, + "getSnapshot: snapshot %q does not belong to session %q (snapshot's session: %q)", snapshotID, sessionID, snap.SessionID) + } + } else { + snap, err = store.GetLatestSnapshot(ctx, sessionID) + if err != nil { + return nil, core.NewError(core.INTERNAL, "getSnapshot: %v", err) + } + if snap == nil { + return nil, core.NewError(core.NOT_FOUND, "getSnapshot: no snapshot found for session %q", sessionID) + } + } + + // Return a normalized copy: the documented defaults (empty status means + // completed, zero UpdatedAt means CreatedAt) are resolved here so every + // caller sees the same shaping, and the state transform shapes what leaves + // the server. A failed snapshot's state is its last-good state, so it is + // returned like any other. + resp := *snap + // Surface a pending snapshot whose heartbeat has gone stale as expired: its + // detached background worker is presumed dead, so report the orphan rather + // than leaving it pending forever. Computed on read only, never written back + // to the store, so the raw row stays pending. Checked before the empty-status + // default below, which applies only to a row carrying no status at all. + if isHeartbeatExpired(snap, defaultHeartbeatTimeout) { + resp.Status = SnapshotStatusExpired + } + if resp.Status == "" { + resp.Status = SnapshotStatusCompleted + } + if resp.UpdatedAt.IsZero() { + resp.UpdatedAt = resp.CreatedAt + } + // Clone before transforming: the [StateTransform] contract promises a fresh + // deep copy the transform may mutate in place, and the store's row may share + // memory with its internal copy, which neither the transform nor the SessionID + // re-stamp below may write into. + resp.State = applyTransform(ctx, transform, jsonClone(snap.State)) + if resp.State != nil { + // SessionID is framework identity, not user data: re-stamp it from the + // row after the transform so outbound state always agrees with the + // snapshot it came from. + resp.State.SessionID = resp.SessionID + } + return &resp, nil +} + func newSnapshotActions[State any]( agentName string, store SessionStore[State], @@ -208,72 +288,7 @@ func newSnapshotActions[State any]( return nil, core.NewError(core.INVALID_ARGUMENT, "getSnapshot: snapshotId or sessionId is required") } - // Resolve the snapshot. A snapshot ID fetches that exact row; a - // session ID alone fetches the session's latest row (whatever - // its status). When both are present the snapshot ID picks the - // row and the session ID asserts it belongs to that session, - // mirroring AgentInit's combined-ID check. - var ( - snap *SessionSnapshot[State] - err error - ) - if req.SnapshotID != "" { - snap, err = store.GetSnapshot(ctx, req.SnapshotID) - if err != nil { - return nil, core.NewError(core.INTERNAL, "getSnapshot: %v", err) - } - if snap == nil { - return nil, core.NewError(core.NOT_FOUND, "getSnapshot: snapshot %q not found", req.SnapshotID) - } - if req.SessionID != "" && snap.SessionID != req.SessionID { - return nil, core.NewError(core.INVALID_ARGUMENT, - "getSnapshot: snapshot %q does not belong to session %q (snapshot's session: %q)", req.SnapshotID, req.SessionID, snap.SessionID) - } - } else { - snap, err = store.GetLatestSnapshot(ctx, req.SessionID) - if err != nil { - return nil, core.NewError(core.INTERNAL, "getSnapshot: %v", err) - } - if snap == nil { - return nil, core.NewError(core.NOT_FOUND, "getSnapshot: no snapshot found for session %q", req.SessionID) - } - } - - // Return a normalized copy: the documented defaults (empty - // status means completed, zero UpdatedAt means CreatedAt) are - // resolved server-side so remote clients don't reimplement - // them, and the state transform shapes what leaves the server. - // A failed snapshot's state is its last-good state, so it is - // returned like any other. - resp := *snap - // Surface a pending snapshot whose heartbeat has gone stale as - // expired: its detached background worker is presumed dead, so - // report the orphan rather than leaving it pending forever. This is - // computed on read only, never written back to the store, so the - // raw row stays pending. Checked before the empty-status default - // below, which applies only to a row carrying no status at all. - if isHeartbeatExpired(snap, defaultHeartbeatTimeout) { - resp.Status = SnapshotStatusExpired - } - if resp.Status == "" { - resp.Status = SnapshotStatusCompleted - } - if resp.UpdatedAt.IsZero() { - resp.UpdatedAt = resp.CreatedAt - } - // Clone before transforming: the [StateTransform] contract - // promises a fresh deep copy the transform may mutate in - // place, and the store's row may share memory with its - // internal copy, which neither the transform nor the - // SessionID re-stamp below may write into. - resp.State = applyTransform(ctx, transform, jsonClone(snap.State)) - if resp.State != nil { - // SessionID is framework identity, not user data: re-stamp - // it from the row after the transform so outbound state - // always agrees with the snapshot it came from. - resp.State.SessionID = resp.SessionID - } - return &resp, nil + return readSnapshot(ctx, store, transform, req.SnapshotID, req.SessionID) }) if _, ok := store.(SnapshotSubscriber); !ok { From 0e0f41b8a962c503379469e0cb489b56a53670f1 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 22 Jun 2026 14:23:12 -0700 Subject: [PATCH 124/141] fix(go/core): validate and normalize bidi JSON init like action input decodeInit decoded the JSON init payload straight into the typed Init, so any field the InitSchema did not declare was silently dropped before validation ever saw it: an agent could be started with a misspelled or bogus init field and get a fresh session with no error. Route init through base.UnmarshalAndNormalize against the resolved InitSchema, the same pipeline the unary input path uses, so the raw payload is validated (additionalProperties rejects unknown fields) and normalized (e.g. integer widening) before the typed decode. Add coverage for unknown-field rejection on ConnectJSON/RunBidiJSON at both the core BidiAction layer and the agent layer (real inferred AgentInit schema), plus a guard that init is normalized, not just validated. --- go/ai/exp/agent_test.go | 58 ++++++++++++++++++++ go/core/bidi.go | 24 +++++++-- go/core/bidi_test.go | 114 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 192 insertions(+), 4 deletions(-) diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index adff34a3b8..e698ffa42c 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -18,6 +18,7 @@ package exp import ( "context" + "encoding/json" "errors" "fmt" "slices" @@ -1136,6 +1137,63 @@ func TestAgent_InitFailure_FailsActionWithStatus(t *testing.T) { } } +// TestAgent_JSONInitRejectsUnknownFields verifies that the agent's JSON init +// paths reject a payload carrying a field the inferred AgentInit schema does not +// declare, surfacing INVALID_ARGUMENT rather than silently dropping it and +// starting a fresh session as if nothing were wrong. This exercises the real +// inferred init schema end to end (the core mechanism is covered by +// TestBidiJSONInitRejectsUnknownFields). +func TestAgent_JSONInitRejectsUnknownFields(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + echo := func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + return nil, nil + }) + } + af := DefineCustomAgent(reg, "jsonInitUnknownFields", echo) + + assertRejected := func(t *testing.T, err error) { + t.Helper() + if err == nil { + t.Fatal("expected INVALID_ARGUMENT for unknown init field, got nil") + } + if ge := core.AsGenkitError(err); ge.Status != core.INVALID_ARGUMENT { + t.Errorf("status = %q, want %q (err: %v)", ge.Status, core.INVALID_ARGUMENT, err) + } + } + + badInit := json.RawMessage(`{"bogus":true}`) + + t.Run("ConnectJSON", func(t *testing.T) { + _, err := af.ConnectJSON(ctx, &api.BidiJSONOptions{Init: badInit}) + assertRejected(t, err) + }) + + t.Run("RunBidiJSON", func(t *testing.T) { + input, err := json.Marshal(&AgentInput{Message: ai.NewUserTextMessage("hi")}) + if err != nil { + t.Fatalf("marshal input: %v", err) + } + _, err = af.RunBidiJSON(ctx, input, nil, &api.BidiJSONOptions{Init: badInit}) + assertRejected(t, err) + }) + + // An empty init object declares no fields, so it passes validation and + // starts a fresh session. + t.Run("empty init accepted", func(t *testing.T) { + conn, err := af.ConnectJSON(ctx, &api.BidiJSONOptions{Init: json.RawMessage(`{}`)}) + if err != nil { + t.Fatalf("ConnectJSON with empty init: %v", err) + } + conn.Close() + if _, err := conn.Output(); err != nil { + t.Fatalf("Output: %v", err) + } + }) +} + func TestAgent_SetMessages(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) diff --git a/go/core/bidi.go b/go/core/bidi.go index 798f548358..2100aed3cc 100644 --- a/go/core/bidi.go +++ b/go/core/bidi.go @@ -293,15 +293,31 @@ func (b *BidiAction[In, Out, Stream, Init]) ConnectJSON(ctx context.Context, opt } // decodeInit decodes the JSON init payload from opts into the action's Init -// type. Returns hasInit=false when opts is nil or the payload is empty or -// JSON null, so transports can pass the request's init field through -// unconditionally. +// type through the same validate-and-normalize pipeline the unary input path +// uses (base.UnmarshalAndNormalize against the resolved schema): the raw +// payload is validated before it is unmarshaled, and JSON values are normalized +// to the schema (e.g. number widening) exactly as input chunks are. Validating +// the raw payload is what rejects a field the InitSchema does not declare; +// decoding straight into a struct Init would drop it silently, leaving the +// typed-value validateInit that follows nothing to catch (by then the unknown +// field is already gone). +// +// Returns hasInit=false when opts is nil or the payload is empty or JSON null, +// so transports can pass the request's init field through unconditionally. An +// action with no InitSchema (e.g. a struct{} Init) resolves to a nil schema, +// which UnmarshalAndNormalize accepts and normalizes structurally, matching the +// input path's handling of a schemaless action. func (b *BidiAction[In, Out, Stream, Init]) decodeInit(opts *api.BidiJSONOptions) (Init, bool, error) { var init Init if opts == nil || !base.HasJSONValue(opts.Init) { return init, false, nil } - if err := json.Unmarshal(opts.Init, &init); err != nil { + schema, err := ResolveSchema(b.registry, b.desc.InitSchema) + if err != nil { + return init, false, NewError(INVALID_ARGUMENT, "invalid init schema for action %q: %v", b.desc.Key, err) + } + init, err = base.UnmarshalAndNormalize[Init](opts.Init, schema) + if err != nil { return init, false, NewError(INVALID_ARGUMENT, "invalid init for action %q: %v", b.desc.Key, err) } return init, true, nil diff --git a/go/core/bidi_test.go b/go/core/bidi_test.go index c3c579b4c9..4b721cb3e8 100644 --- a/go/core/bidi_test.go +++ b/go/core/bidi_test.go @@ -419,6 +419,120 @@ func TestInitSchemaValidationAcceptsGoodInit(t *testing.T) { } } +// TestBidiJSONInitRejectsUnknownFields verifies that the JSON init paths +// (ConnectJSON, RunBidiJSON) validate the raw init payload against the action's +// InitSchema before unmarshaling, so a field the schema does not declare is +// rejected as INVALID_ARGUMENT. Decoding straight into a struct Init would drop +// the stray field, leaving the post-decode validateInit nothing to catch. +func TestBidiJSONInitRejectsUnknownFields(t *testing.T) { + ctx := context.Background() + + type Config struct { + Prefix string `json:"prefix,omitempty"` + } + + // additionalProperties:false is what makes a stray field a violation; it is + // also what the inferred schema for a struct Init (e.g. an agent's + // *AgentInit) carries by default. + initSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "prefix": map[string]any{"type": "string"}, + }, + "additionalProperties": false, + } + + action := NewBidiAction( + "json-init-unknown-fields", api.ActionTypeCustom, + &BidiActionOptions{InitSchema: initSchema}, + func(ctx context.Context, cfg *Config, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return "done", nil + }, + ) + + assertRejected := func(t *testing.T, err error) { + t.Helper() + if err == nil { + t.Fatal("expected INVALID_ARGUMENT for unknown init field, got nil") + } + var gerr *GenkitError + if !errors.As(err, &gerr) || gerr.Status != INVALID_ARGUMENT { + t.Fatalf("err = %v, want INVALID_ARGUMENT GenkitError", err) + } + } + + badInit := json.RawMessage(`{"prefix":">> ","bogus":true}`) + + t.Run("ConnectJSON", func(t *testing.T) { + _, err := action.ConnectJSON(ctx, &api.BidiJSONOptions{Init: badInit}) + assertRejected(t, err) + }) + + t.Run("RunBidiJSON", func(t *testing.T) { + _, err := action.RunBidiJSON(ctx, json.RawMessage(`"hello"`), nil, + &api.BidiJSONOptions{Init: badInit}) + assertRejected(t, err) + }) + + // A payload with only declared fields still starts the session. + t.Run("known fields accepted", func(t *testing.T) { + conn, err := action.ConnectJSON(ctx, &api.BidiJSONOptions{ + Init: json.RawMessage(`{"prefix":">> "}`)}) + if err != nil { + t.Fatalf("ConnectJSON: %v", err) + } + conn.Close() + if _, err := conn.Output(); err != nil { + t.Fatalf("Output: %v", err) + } + }) +} + +// TestBidiJSONInitNormalizedLikeInput verifies that the JSON init path runs the +// same normalization as the input path: a JSON number for an integer-typed +// field is widened to int64 rather than left as the float64 a plain decode into +// an any value would produce. This pins init handling to the input pipeline +// (base.UnmarshalAndNormalize), not just schema validation. +func TestBidiJSONInitNormalizedLikeInput(t *testing.T) { + ctx := context.Background() + + initSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{"type": "integer"}, + }, + } + + gotCount := make(chan any, 1) + action := NewBidiAction( + "json-init-normalized", api.ActionTypeCustom, + &BidiActionOptions{InitSchema: initSchema}, + func(ctx context.Context, cfg any, inCh <-chan string, outCh chan<- string) (string, error) { + m, _ := cfg.(map[string]any) + gotCount <- m["count"] + for range inCh { + } + return "done", nil + }, + ) + + conn, err := action.ConnectJSON(ctx, &api.BidiJSONOptions{ + Init: json.RawMessage(`{"count":42}`)}) + if err != nil { + t.Fatalf("ConnectJSON: %v", err) + } + conn.Close() + if _, err := conn.Output(); err != nil { + t.Fatalf("Output: %v", err) + } + + if got := <-gotCount; got != int64(42) { + t.Errorf("normalized init count = %T (%v), want int64(42)", got, got) + } +} + // TestBidiNilInitSkipsValidation verifies that a nil init (the zero value of // a pointer Init type) bypasses init schema validation on every no-init path. // The inferred init schema describes the object form, which JSON null can From 1e9654cfbff78229cb18213f9ecc6e9b1a21e859 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 22 Jun 2026 15:02:16 -0700 Subject: [PATCH 125/141] feat(go/exp): tag agent traces with session ID, match JS turn spans Stamp the agent's root action span with the session ID under "genkit:metadata:agent:sessionId" so traces from the same conversation can be correlated. This mirrors the JS agent (PR #5251), which sets the attribute once on the action span via setCustomMetadataAttributes; Go has no such helper, so the genkit:metadata: prefix is inlined. Align the per-turn span shape with JS's run() too: name "runTurn-N" (1-indexed) with type flowStep and no subtype, replacing "agent/turn/N". --- go/ai/exp/agent.go | 26 ++++++++- go/ai/exp/agent_test.go | 118 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 3 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 0ecf36905b..209c0426e0 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -39,6 +39,8 @@ import ( "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" "github.com/google/uuid" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) // --- Heartbeat --- @@ -189,9 +191,11 @@ func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont s.onStartTurn() } spanMeta := &tracing.SpanMetadata{ - Name: fmt.Sprintf("agent/turn/%d", s.turnIndex), - Type: "flowStep", - Subtype: "flowStep", + // Match the JS agent's turn span so cross-language traces line up: + // name "runTurn-N" (1-indexed) and type flowStep with no subtype + // (JS's run() sets only genkit:type, no genkit:metadata:subtype). + Name: fmt.Sprintf("runTurn-%d", s.turnIndex+1), + Type: "flowStep", } _, err := tracing.RunInNewSpan(ctx, spanMeta, input, func(ctx context.Context, input *AgentInput) (any, error) { @@ -806,6 +810,13 @@ type fnDoneResult[State any] struct { err error } +// sessionIDSpanAttrKey is the full span-attribute key under which an agent's +// root action span records its session ID. It is the "genkit:metadata:"-prefixed +// form of the "agent:sessionId" custom-metadata key the JS agent sets via +// setCustomMetadataAttributes; the prefix is inlined here because Go's tracing +// package exposes no setCustomMetadataAttributes helper. +const sessionIDSpanAttrKey = "genkit:metadata:agent:sessionId" + func newAgentRuntime[State any]( ctx context.Context, name string, @@ -853,6 +864,15 @@ func newAgentRuntime[State any]( session.state.SessionID = uuid.New().String() } + // Tag the agent's root action span (the current span here, before any turn + // span is opened) with the session ID so traces from the same conversation + // can be correlated. Mirrors the JS agent, which calls + // setCustomMetadataAttributes({'agent:sessionId': ...}) once at the start of + // the action body. trace.SpanFromContext never returns nil (it yields a + // no-op span when none is active), so the SetAttributes is always safe. + trace.SpanFromContext(ctx).SetAttributes( + attribute.String(sessionIDSpanAttrKey, session.state.SessionID)) + rt := &agentRuntime[State]{ name: name, cfg: cfg, diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index e698ffa42c..83d7daeee8 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -31,7 +31,9 @@ import ( "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal/registry" + sdktrace "go.opentelemetry.io/otel/sdk/trace" ) type testState struct { @@ -259,6 +261,122 @@ func TestAgent_WithSessionStore(t *testing.T) { } } +// spanCollector is a minimal in-memory sdktrace.SpanExporter that records +// finished spans so a test can assert on their attributes. +type spanCollector struct { + mu sync.Mutex + spans []sdktrace.ReadOnlySpan +} + +func (c *spanCollector) ExportSpans(_ context.Context, spans []sdktrace.ReadOnlySpan) error { + c.mu.Lock() + defer c.mu.Unlock() + c.spans = append(c.spans, spans...) + return nil +} + +func (c *spanCollector) Shutdown(context.Context) error { return nil } + +// byName returns the first recorded span with the given name, or nil. +func (c *spanCollector) byName(name string) sdktrace.ReadOnlySpan { + c.mu.Lock() + defer c.mu.Unlock() + for _, s := range c.spans { + if s.Name() == name { + return s + } + } + return nil +} + +// collectSpans registers an in-memory exporter on the global tracer provider +// (the one tracing.RunInNewSpan writes through) for the duration of the test. +// The SimpleSpanProcessor exports each span synchronously as it ends, so by +// the time a turn completes its span is already recorded. +func collectSpans(t *testing.T) *spanCollector { + t.Helper() + c := &spanCollector{} + sp := sdktrace.NewSimpleSpanProcessor(c) + tp := tracing.TracerProvider() + tp.RegisterSpanProcessor(sp) + t.Cleanup(func() { tp.UnregisterSpanProcessor(sp) }) + return c +} + +// spanAttr returns the string value of the named span attribute, if present. +func spanAttr(span sdktrace.ReadOnlySpan, key string) (string, bool) { + for _, kv := range span.Attributes() { + if string(kv.Key) == key { + return kv.Value.AsString(), true + } + } + return "", false +} + +// TestAgent_RootSpanCarriesSessionID verifies the agent's root action span is +// stamped with the invocation's session ID under +// "genkit:metadata:agent:sessionId", and that the per-turn spans are left +// untagged. This matches the JS agent, which tags the action span once via +// setCustomMetadataAttributes({'agent:sessionId': ...}) rather than each turn. +func TestAgent_RootSpanCarriesSessionID(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + spans := collectSpans(t) + + const agentName = "tracedCounterFlow" + af := defineCounterAgent(reg, agentName) + + conn, err := af.Connect(ctx) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + sendTurn(t, conn, "turn1") + sendTurn(t, conn, "turn2") + conn.Close() + + out, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + if out.SessionID == "" { + t.Fatal("expected a non-empty session ID on the output") + } + + const attrKey = "genkit:metadata:agent:sessionId" + + // The root action span (named for the agent) carries the session ID. + root := spans.byName(agentName) + if root == nil { + t.Fatalf("missing root action span %q", agentName) + } + got, ok := spanAttr(root, attrKey) + if !ok { + t.Fatalf("root span %q: missing attribute %q", agentName, attrKey) + } + if got != out.SessionID { + t.Errorf("root span %q: %s = %q, want %q", agentName, attrKey, got, out.SessionID) + } + + // The per-turn spans are named "runTurn-N" (1-indexed) with type flowStep + // and no subtype, matching JS's run(); the session ID lives on the root + // span only. + for _, turn := range []string{"runTurn-1", "runTurn-2"} { + span := spans.byName(turn) + if span == nil { + t.Fatalf("missing span %q", turn) + } + if v, ok := spanAttr(span, attrKey); ok { + t.Errorf("turn span %q: unexpected attribute %q = %q (want it on the root span only)", turn, attrKey, v) + } + if v, _ := spanAttr(span, "genkit:type"); v != "flowStep" { + t.Errorf("turn span %q: genkit:type = %q, want %q", turn, v, "flowStep") + } + if v, ok := spanAttr(span, "genkit:metadata:subtype"); ok { + t.Errorf("turn span %q: unexpected genkit:metadata:subtype = %q (JS's run() sets no subtype)", turn, v) + } + } +} + func TestAgent_ResumeFromSnapshot(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) From 0bd7227cef261c8dbe9fbc30e018468c878c855b Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 22 Jun 2026 21:19:16 -0700 Subject: [PATCH 126/141] feat(go/exp): add pruning and path-prefix options to file session store Add two construction options to localstore.FileSessionStore, mirroring the JS FileSessionStore (PR #5251): - WithMaxPersistedChainLength(n): on each save, walk the new snapshot's parentId chain and unlink rows past the newest n, capping per-conversation disk use. n must be >= 1; 0 or negative is rejected. - WithSnapshotPathPrefix(fn): derive a per-call subdirectory from context for tenant isolation. The prefix may nest via "/" and is sanitized against directory escape. Snapshots remain flat by id within the prefix (//.json), so resume-by-snapshot, heartbeat, finalize, abort, and status subscription stay O(1) direct opens, while GetLatestSnapshot scans the prefix directory. With no options configured the on-disk layout is unchanged. Options follow the ai/option.go interface pattern (in a new option.go) and error if set more than once. --- go/ai/exp/localstore/file.go | 241 ++++++++++++++++++++++-------- go/ai/exp/localstore/file_test.go | 241 ++++++++++++++++++++++++++++++ go/ai/exp/localstore/option.go | 92 ++++++++++++ 3 files changed, 515 insertions(+), 59 deletions(-) create mode 100644 go/ai/exp/localstore/option.go diff --git a/go/ai/exp/localstore/file.go b/go/ai/exp/localstore/file.go index ebee7fd1a2..28236c0c5f 100644 --- a/go/ai/exp/localstore/file.go +++ b/go/ai/exp/localstore/file.go @@ -32,9 +32,20 @@ import ( "github.com/google/uuid" ) -// FileSessionStore is a snapshot store that persists snapshots as JSON files -// on the local filesystem. Each snapshot is written to its own file named -// ".json" in the configured directory. +// FileSessionStore is a snapshot store that persists snapshots as JSON files on +// the local filesystem. Each snapshot is written to its own file named +// ".json", under an optional per-call subdirectory ("prefix"): +// +// //.json +// +// The snapshot ID is the primary key: GetSnapshot, the by-ID SaveSnapshot +// (heartbeat, abort, finalize), and OnSnapshotStatusChange all open that file +// directly. GetLatestSnapshot, the only by-session lookup, scans the prefix +// directory and selects the most-recently-created row for the session. The +// prefix is derived from each call's context (see [WithSnapshotPathPrefix]), so +// it is always known on a by-ID call - unlike the session ID, which a by-ID +// caller does not have. That is why snapshots are grouped by prefix and kept +// flat within it rather than nested under a per-session directory. // // The store is safe for concurrent use within a single process, but does NOT // coordinate with other processes sharing the directory: the last successful @@ -46,42 +57,66 @@ type FileSessionStore[State any] struct { // File I/O happens under the lock; this matches the simplicity of // [InMemorySessionStore] and is adequate when writes are infrequent // (typically once per turn). - mu sync.Mutex - dir string - subs map[string][]chan exp.SnapshotStatus + mu sync.Mutex + dir string + // maxChain, when > 0, bounds how many snapshots a single conversation chain + // retains on disk; see [WithMaxPersistedChainLength]. + maxChain int + // prefixFn, when set, derives the per-call subdirectory from context; see + // [WithSnapshotPathPrefix]. + prefixFn func(context.Context) string + subs map[string][]chan exp.SnapshotStatus } // NewFileSessionStore creates a file-based snapshot store rooted at dir. // The directory is created (mode 0o700) if it does not already exist. -// Returns an error if dir is empty or cannot be created. -func NewFileSessionStore[State any](dir string) (*FileSessionStore[State], error) { +// Returns an error if dir is empty, cannot be created, or an option is set +// more than once. See [WithMaxPersistedChainLength] and +// [WithSnapshotPathPrefix]. +func NewFileSessionStore[State any](dir string, opts ...FileStoreOption) (*FileSessionStore[State], error) { if dir == "" { return nil, errors.New("FileSessionStore: dir is required") } if err := os.MkdirAll(dir, 0o700); err != nil { return nil, fmt.Errorf("FileSessionStore: create dir %q: %w", dir, err) } + var resolved fileStoreOptions + for _, o := range opts { + if err := o.applyFileStore(&resolved); err != nil { + return nil, err + } + } + maxChain := 0 + if resolved.maxChain != nil { + maxChain = *resolved.maxChain + } return &FileSessionStore[State]{ - dir: dir, - subs: make(map[string][]chan exp.SnapshotStatus), + dir: dir, + maxChain: maxChain, + prefixFn: resolved.prefixFn, + subs: make(map[string][]chan exp.SnapshotStatus), }, nil } // GetSnapshot retrieves a snapshot by ID. Returns nil if not found. -func (s *FileSessionStore[State]) GetSnapshot(_ context.Context, snapshotID string) (*exp.SessionSnapshot[State], error) { +func (s *FileSessionStore[State]) GetSnapshot(ctx context.Context, snapshotID string) (*exp.SessionSnapshot[State], error) { if err := validateSnapshotID(snapshotID); err != nil { return nil, err } + prefix, err := s.derivePrefix(ctx) + if err != nil { + return nil, err + } s.mu.Lock() defer s.mu.Unlock() - return s.readLocked(snapshotID) + return s.readAt(s.pathFor(prefix, snapshotID)) } // SaveSnapshot atomically reads, applies fn, and persists. See // [exp.SnapshotWriter] for the full contract; this implementation calls fn // exactly once per call. func (s *FileSessionStore[State]) SaveSnapshot( - _ context.Context, + ctx context.Context, id string, fn func(existing *exp.SessionSnapshot[State]) (*exp.SessionSnapshot[State], error), ) (*exp.SessionSnapshot[State], error) { @@ -90,11 +125,15 @@ func (s *FileSessionStore[State]) SaveSnapshot( } else if err := validateSnapshotID(id); err != nil { return nil, err } + prefix, err := s.derivePrefix(ctx) + if err != nil { + return nil, err + } s.mu.Lock() defer s.mu.Unlock() - existing, err := s.readLocked(id) + existing, err := s.readAt(s.pathFor(prefix, id)) if err != nil { return nil, err } @@ -117,12 +156,15 @@ func (s *FileSessionStore[State]) SaveSnapshot( next.Status = exp.SnapshotStatusCompleted } - if err := s.writeLocked(next); err != nil { + if err := s.writeAt(prefix, next); err != nil { return nil, err } if existing == nil || existing.Status != next.Status { s.notifyLocked(id, next.Status) } + if s.maxChain > 0 { + s.pruneLocked(prefix, next) + } return next, nil } @@ -136,18 +178,24 @@ type snapshotHeader struct { // GetLatestSnapshot returns the session's most recently created snapshot // regardless of status, per the [exp.SnapshotReader.GetLatestSnapshot] -// contract. +// contract. It scans the call's prefix directory (see [WithSnapshotPathPrefix]), +// so a session is only resolvable under the prefix it was written with. // // Recency is judged by the [exp.SessionSnapshot.CreatedAt] field (not file // mtime), so a later rewrite of an older row - which preserves CreatedAt - does // not move it ahead of a newer-created sibling. Ties are broken by snapshot ID. // A file that fails to parse or vanishes mid-scan is skipped, so one corrupted // row cannot hide every other session. -func (s *FileSessionStore[State]) GetLatestSnapshot(_ context.Context, sessionID string) (*exp.SessionSnapshot[State], error) { +func (s *FileSessionStore[State]) GetLatestSnapshot(ctx context.Context, sessionID string) (*exp.SessionSnapshot[State], error) { if sessionID == "" { return nil, errors.New("FileSessionStore: session ID is empty") } - names, err := s.snapshotFileNames() + prefix, err := s.derivePrefix(ctx) + if err != nil { + return nil, err + } + dir := s.prefixDir(prefix) + names, err := s.snapshotFileNames(dir) if err != nil { return nil, err } @@ -158,7 +206,7 @@ func (s *FileSessionStore[State]) GetLatestSnapshot(_ context.Context, sessionID ) for _, name := range names { s.mu.Lock() - data, err := os.ReadFile(filepath.Join(s.dir, name)) + data, err := os.ReadFile(filepath.Join(dir, name)) s.mu.Unlock() if err != nil { continue @@ -182,27 +230,24 @@ func (s *FileSessionStore[State]) GetLatestSnapshot(_ context.Context, sessionID } // Fully decode only the winner. CreatedAt is preserved across rewrites, so a // concurrent rewrite of this row between scan and read still yields the - // right snapshot (with possibly fresher state). + // right snapshot (with possibly fresher state). A parse failure here is + // treated like a vanished row: report no tip rather than erroring. s.mu.Lock() - data, err := os.ReadFile(filepath.Join(s.dir, bestName)) + snap, _ := s.readAt(filepath.Join(dir, bestName)) s.mu.Unlock() - if err != nil { + if snap == nil { return nil, nil } - var snap exp.SessionSnapshot[State] - if err := json.Unmarshal(data, &snap); err != nil { - return nil, nil - } - return &snap, nil + return snap, nil } -// snapshotFileNames returns the names of the directory's snapshot files -// (non-directory *.json entries; writeLocked's ".*.tmp" temp files never -// match). Returns nil if the directory does not exist. The listing is not -// atomic with respect to concurrent writes; a snapshot that appears or -// disappears mid-scan may or may not be observed. -func (s *FileSessionStore[State]) snapshotFileNames() ([]string, error) { - entries, err := os.ReadDir(s.dir) +// snapshotFileNames returns the names of dir's snapshot files (non-directory +// *.json entries; writeAt's ".*.tmp" temp files never match). Returns nil if +// the directory does not exist. The listing is not atomic with respect to +// concurrent writes; a snapshot that appears or disappears mid-scan may or may +// not be observed. +func (s *FileSessionStore[State]) snapshotFileNames(dir string) ([]string, error) { + entries, err := os.ReadDir(dir) if err != nil { if errors.Is(err, os.ErrNotExist) { return nil, nil @@ -230,9 +275,14 @@ func (s *FileSessionStore[State]) OnSnapshotStatusChange(ctx context.Context, sn close(ch) return ch } + prefix, err := s.derivePrefix(ctx) + if err != nil { + close(ch) + return ch + } s.mu.Lock() - snap, err := s.readLocked(snapshotID) + snap, err := s.readAt(s.pathFor(prefix, snapshotID)) if err != nil || snap == nil { s.mu.Unlock() close(ch) @@ -246,31 +296,46 @@ func (s *FileSessionStore[State]) OnSnapshotStatusChange(ctx context.Context, sn return ch } -// readLocked reads and parses the snapshot file. Returns (nil, nil) if the +// derivePrefix resolves the per-call subdirectory snapshots live under by +// invoking the configured prefix function (if any) and sanitizing its result. +// Returns "" (the store root) when no function is configured or it yields an +// empty value. +func (s *FileSessionStore[State]) derivePrefix(ctx context.Context) (string, error) { + if s.prefixFn == nil { + return "", nil + } + return sanitizePrefix(s.prefixFn(ctx)) +} + +// readAt reads and parses the snapshot file at path. Returns (nil, nil) if the // file does not exist. Caller must hold s.mu. -func (s *FileSessionStore[State]) readLocked(snapshotID string) (*exp.SessionSnapshot[State], error) { - data, err := os.ReadFile(s.path(snapshotID)) +func (s *FileSessionStore[State]) readAt(path string) (*exp.SessionSnapshot[State], error) { + data, err := os.ReadFile(path) if err != nil { if errors.Is(err, os.ErrNotExist) { return nil, nil } - return nil, fmt.Errorf("FileSessionStore: read %s: %w", snapshotID, err) + return nil, fmt.Errorf("FileSessionStore: read %s: %w", path, err) } var snap exp.SessionSnapshot[State] if err := json.Unmarshal(data, &snap); err != nil { - return nil, fmt.Errorf("FileSessionStore: unmarshal %s: %w", snapshotID, err) + return nil, fmt.Errorf("FileSessionStore: unmarshal %s: %w", path, err) } return &snap, nil } -// writeLocked atomically writes the snapshot to disk via a temp file + -// rename. Caller must hold s.mu. -func (s *FileSessionStore[State]) writeLocked(snap *exp.SessionSnapshot[State]) error { +// writeAt atomically writes snap to /.json via a temp file + +// rename, creating the prefix directory as needed. Caller must hold s.mu. +func (s *FileSessionStore[State]) writeAt(prefix string, snap *exp.SessionSnapshot[State]) error { + dir := s.prefixDir(prefix) + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("FileSessionStore: create dir: %w", err) + } data, err := json.MarshalIndent(snap, "", " ") if err != nil { return fmt.Errorf("FileSessionStore: marshal: %w", err) } - f, err := os.CreateTemp(s.dir, snap.SnapshotID+".*.tmp") + f, err := os.CreateTemp(dir, snap.SnapshotID+".*.tmp") if err != nil { return fmt.Errorf("FileSessionStore: create temp: %w", err) } @@ -290,16 +355,45 @@ func (s *FileSessionStore[State]) writeLocked(snap *exp.SessionSnapshot[State]) if err := f.Close(); err != nil { return fmt.Errorf("FileSessionStore: close: %w", err) } - if err := os.Rename(tmpName, s.path(snap.SnapshotID)); err != nil { + if err := os.Rename(tmpName, s.pathFor(prefix, snap.SnapshotID)); err != nil { return fmt.Errorf("FileSessionStore: rename: %w", err) } return nil } -// path returns the on-disk path for a snapshot ID. The ID is assumed to have -// been validated by validateSnapshotID. -func (s *FileSessionStore[State]) path(snapshotID string) string { - return filepath.Join(s.dir, snapshotID+".json") +// pruneLocked enforces maxChain by walking the parentId chain back from the +// just-written snapshot and unlinking every row past the newest maxChain +// entries. Parents live in the same prefix directory, addressed directly by ID, +// so no scan is needed. A broken or cyclic chain stops the walk early (a visited +// set guards against cycles). Deletion is best-effort: a failed unlink leaves a +// stale row rather than failing the already-committed save. Caller holds s.mu +// and has checked maxChain > 0. +func (s *FileSessionStore[State]) pruneLocked(prefix string, tip *exp.SessionSnapshot[State]) { + chain := []string{tip.SnapshotID} + seen := map[string]bool{tip.SnapshotID: true} + for cur := tip; cur.ParentID != "" && !seen[cur.ParentID]; { + parent, err := s.readAt(s.pathFor(prefix, cur.ParentID)) + if err != nil || parent == nil { + break + } + seen[cur.ParentID] = true + chain = append(chain, cur.ParentID) + cur = parent + } + for _, oldID := range chain[min(s.maxChain, len(chain)):] { + _ = os.Remove(s.pathFor(prefix, oldID)) + } +} + +// prefixDir returns the on-disk directory snapshots under prefix are stored in. +func (s *FileSessionStore[State]) prefixDir(prefix string) string { + return filepath.Join(s.dir, prefix) +} + +// pathFor returns the on-disk path for a snapshot ID under prefix. Both are +// assumed validated (by validateSnapshotID and sanitizePrefix). +func (s *FileSessionStore[State]) pathFor(prefix, snapshotID string) string { + return filepath.Join(s.dir, prefix, snapshotID+".json") } // removeSub detaches a subscriber and closes its channel. @@ -329,23 +423,52 @@ func (s *FileSessionStore[State]) notifyLocked(snapshotID string, status exp.Sna } } +// sanitizePrefix turns a raw prefix (which may contain "/" to nest +// subdirectories) into a cleaned relative path under the store root, rejecting +// any value that could escape it. Empty and separator-only inputs yield "". +func sanitizePrefix(raw string) (string, error) { + if strings.Contains(raw, `\`) { + return "", fmt.Errorf("FileSessionStore: path prefix %q must use '/' separators", raw) + } + var segs []string + for _, seg := range strings.Split(raw, "/") { + if seg == "" { + continue // collapse empty segments (leading/trailing/double slash) + } + if err := validatePathSegment(seg); err != nil { + return "", fmt.Errorf("FileSessionStore: invalid path prefix %q: %w", raw, err) + } + segs = append(segs, seg) + } + return filepath.Join(segs...), nil +} + // validateSnapshotID rejects IDs that would escape the store directory or -// collide with reserved filenames. UUIDs (the default produced by an empty -// id) pass trivially. +// collide with reserved filenames; the ID is used directly as a file name. +// UUIDs (the default produced by an empty id) pass trivially. func validateSnapshotID(id string) error { - if id == "" { - return errors.New("FileSessionStore: snapshot ID is empty") + if err := validatePathSegment(id); err != nil { + return fmt.Errorf("FileSessionStore: invalid snapshot ID %q: %w", id, err) + } + return nil +} + +// validatePathSegment rejects a value that cannot serve as a single on-disk path +// component without risking directory escape or hidden/reserved-name collisions. +func validatePathSegment(s string) error { + if s == "" { + return errors.New("empty") } - if strings.ContainsAny(id, `/\`) || strings.Contains(id, "..") { - return fmt.Errorf("FileSessionStore: snapshot ID %q contains path separators", id) + if strings.ContainsAny(s, `/\`) || strings.Contains(s, "..") { + return errors.New("contains path separators") } - if strings.HasPrefix(id, ".") { - return fmt.Errorf("FileSessionStore: snapshot ID %q must not start with '.'", id) + if strings.HasPrefix(s, ".") { + return errors.New("must not start with '.'") } // Disallow NUL and control characters that some filesystems reject. - for _, r := range id { + for _, r := range s { if r < 0x20 { - return fmt.Errorf("FileSessionStore: snapshot ID %q contains control characters", id) + return errors.New("contains control characters") } } return nil diff --git a/go/ai/exp/localstore/file_test.go b/go/ai/exp/localstore/file_test.go index c356f96951..74d1380fb3 100644 --- a/go/ai/exp/localstore/file_test.go +++ b/go/ai/exp/localstore/file_test.go @@ -480,3 +480,244 @@ func TestFileSessionStore_GetLatestSnapshot_SkipsUnparseableFiles(t *testing.T) t.Errorf("expected the healthy row as tip, got %+v", tip) } } + +// TestFileSessionStore_MaxPersistedChainLength verifies that, with a retention +// window of n, each save unlinks the rows past the newest n in the snapshot's +// parentId chain while leaving the window (and session resolution) intact. +func TestFileSessionStore_MaxPersistedChainLength(t *testing.T) { + dir := t.TempDir() + store, err := NewFileSessionStore[testState](dir, WithMaxPersistedChainLength(2)) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + ctx := context.Background() + base := time.Now() + + // A linear chain s0 <- s1 <- s2 <- s3 <- s4, each created strictly after + // the last so recency is unambiguous. + ids := []string{"s0", "s1", "s2", "s3", "s4"} + parent := "" + for i, id := range ids { + createdAt := base.Add(time.Duration(i) * time.Second) + parentID := parent + if _, err := store.SaveSnapshot(ctx, id, + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{ + SessionID: "sess", + ParentID: parentID, + Status: exp.SnapshotStatusCompleted, + State: &exp.SessionState[testState]{Custom: testState{Counter: i}}, + CreatedAt: createdAt, + UpdatedAt: createdAt, + }, nil + }); err != nil { + t.Fatalf("SaveSnapshot(%q): %v", id, err) + } + parent = id + } + + // Only the two newest rows survive; everything older is pruned. + for _, gone := range []string{"s0", "s1", "s2"} { + snap, err := store.GetSnapshot(ctx, gone) + if err != nil { + t.Fatalf("GetSnapshot(%q): %v", gone, err) + } + if snap != nil { + t.Errorf("expected %q pruned, but it is still present", gone) + } + } + for _, kept := range []string{"s3", "s4"} { + snap, err := store.GetSnapshot(ctx, kept) + if err != nil { + t.Fatalf("GetSnapshot(%q): %v", kept, err) + } + if snap == nil { + t.Errorf("expected %q retained, got nil", kept) + } + } + + // Session resolution still finds the newest survivor. + latest, err := store.GetLatestSnapshot(ctx, "sess") + if err != nil { + t.Fatalf("GetLatestSnapshot: %v", err) + } + if latest == nil || latest.SnapshotID != "s4" { + t.Errorf("latest = %+v, want s4", latest) + } +} + +type prefixCtxKey struct{} + +func ctxWithPrefix(prefix string) context.Context { + return context.WithValue(context.Background(), prefixCtxKey{}, prefix) +} + +func prefixFromCtx(ctx context.Context) string { + v, _ := ctx.Value(prefixCtxKey{}).(string) + return v +} + +// TestFileSessionStore_PathPrefix verifies that a context-derived prefix scopes +// both writes and reads: a snapshot lands under the tenant subdirectory and is +// invisible to a different tenant, for both by-ID and by-session lookups. +func TestFileSessionStore_PathPrefix(t *testing.T) { + dir := t.TempDir() + store, err := NewFileSessionStore[testState](dir, WithSnapshotPathPrefix(prefixFromCtx)) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + ctxA := ctxWithPrefix("tenant-a") + ctxB := ctxWithPrefix("tenant-b") + + if _, err := store.SaveSnapshot(ctxA, "s1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{SessionID: "sess", Status: exp.SnapshotStatusCompleted}, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + + // Written under the tenant subdirectory, not the store root. + if _, err := os.Stat(filepath.Join(dir, "tenant-a", "s1.json")); err != nil { + t.Errorf("expected tenant-a/s1.json on disk: %v", err) + } + if _, err := os.Stat(filepath.Join(dir, "s1.json")); !os.IsNotExist(err) { + t.Errorf("snapshot must not land in store root, stat err = %v", err) + } + + // Visible under the writing prefix. + got, err := store.GetSnapshot(ctxA, "s1") + if err != nil || got == nil { + t.Fatalf("GetSnapshot(ctxA): got=%v err=%v", got, err) + } + latestA, err := store.GetLatestSnapshot(ctxA, "sess") + if err != nil || latestA == nil || latestA.SnapshotID != "s1" { + t.Fatalf("GetLatestSnapshot(ctxA): got=%+v err=%v", latestA, err) + } + + // Isolated from a different prefix, by ID and by session. + other, err := store.GetSnapshot(ctxB, "s1") + if err != nil { + t.Fatalf("GetSnapshot(ctxB): %v", err) + } + if other != nil { + t.Errorf("tenant-b must not see tenant-a's snapshot, got %+v", other) + } + latestB, err := store.GetLatestSnapshot(ctxB, "sess") + if err != nil { + t.Fatalf("GetLatestSnapshot(ctxB): %v", err) + } + if latestB != nil { + t.Errorf("expected nil latest for tenant-b, got %+v", latestB) + } +} + +// TestFileSessionStore_PathPrefix_Nested verifies a prefix may nest multiple +// subdirectories via "/". +func TestFileSessionStore_PathPrefix_Nested(t *testing.T) { + dir := t.TempDir() + store, err := NewFileSessionStore[testState](dir, + WithSnapshotPathPrefix(func(context.Context) string { return "org-42/user-7" })) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + ctx := context.Background() + if _, err := store.SaveSnapshot(ctx, "s1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{SessionID: "sess", Status: exp.SnapshotStatusCompleted}, nil + }); err != nil { + t.Fatalf("SaveSnapshot: %v", err) + } + if _, err := os.Stat(filepath.Join(dir, "org-42", "user-7", "s1.json")); err != nil { + t.Errorf("expected org-42/user-7/s1.json on disk: %v", err) + } + got, err := store.GetSnapshot(ctx, "s1") + if err != nil || got == nil { + t.Errorf("GetSnapshot through nested prefix: got=%v err=%v", got, err) + } +} + +// TestFileSessionStore_PathPrefix_Rejected verifies a prefix that would escape +// the store directory is rejected at call time rather than silently writing +// outside it. +func TestFileSessionStore_PathPrefix_Rejected(t *testing.T) { + for _, bad := range []string{"../escape", `a\b`, ".hidden", "ok/../escape"} { + t.Run(bad, func(t *testing.T) { + dir := t.TempDir() + store, err := NewFileSessionStore[testState](dir, + WithSnapshotPathPrefix(func(context.Context) string { return bad })) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + ctx := context.Background() + if _, err := store.SaveSnapshot(ctx, "s1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{SessionID: "sess"}, nil + }); err == nil { + t.Error("SaveSnapshot: expected error for escaping prefix, got nil") + } + if _, err := store.GetSnapshot(ctx, "s1"); err == nil { + t.Error("GetSnapshot: expected error for escaping prefix, got nil") + } + if _, err := store.GetLatestSnapshot(ctx, "sess"); err == nil { + t.Error("GetLatestSnapshot: expected error for escaping prefix, got nil") + } + }) + } +} + +// TestFileSessionStore_OptionSetTwice verifies the construction options reject +// being set more than once, surfacing the error from NewFileSessionStore. +func TestFileSessionStore_OptionSetTwice(t *testing.T) { + dir := t.TempDir() + if _, err := NewFileSessionStore[testState](dir, + WithMaxPersistedChainLength(2), WithMaxPersistedChainLength(3)); err == nil { + t.Error("expected error setting max persisted chain length twice, got nil") + } + if _, err := NewFileSessionStore[testState](dir, + WithSnapshotPathPrefix(prefixFromCtx), WithSnapshotPathPrefix(prefixFromCtx)); err == nil { + t.Error("expected error setting snapshot path prefix twice, got nil") + } +} + +// TestFileSessionStore_MaxPersistedChainLength_Invalid verifies a retention +// window of 0 or less is rejected at construction rather than silently +// disabling pruning. +func TestFileSessionStore_MaxPersistedChainLength_Invalid(t *testing.T) { + dir := t.TempDir() + for _, n := range []int{0, -1} { + if _, err := NewFileSessionStore[testState](dir, WithMaxPersistedChainLength(n)); err == nil { + t.Errorf("WithMaxPersistedChainLength(%d): expected error, got nil", n) + } + } +} + +// TestFileSessionStore_MaxPersistedChainLength_One verifies a window of 1 is +// accepted and keeps only the latest snapshot, pruning each predecessor. +func TestFileSessionStore_MaxPersistedChainLength_One(t *testing.T) { + dir := t.TempDir() + store, err := NewFileSessionStore[testState](dir, WithMaxPersistedChainLength(1)) + if err != nil { + t.Fatalf("NewFileSessionStore: %v", err) + } + ctx := context.Background() + now := time.Now() + if _, err := store.SaveSnapshot(ctx, "s0", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{SessionID: "sess", Status: exp.SnapshotStatusCompleted, CreatedAt: now, UpdatedAt: now}, nil + }); err != nil { + t.Fatalf("SaveSnapshot(s0): %v", err) + } + later := now.Add(time.Second) + if _, err := store.SaveSnapshot(ctx, "s1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{SessionID: "sess", ParentID: "s0", Status: exp.SnapshotStatusCompleted, CreatedAt: later, UpdatedAt: later}, nil + }); err != nil { + t.Fatalf("SaveSnapshot(s1): %v", err) + } + if snap, _ := store.GetSnapshot(ctx, "s0"); snap != nil { + t.Error("expected predecessor s0 pruned with window 1") + } + if snap, _ := store.GetSnapshot(ctx, "s1"); snap == nil { + t.Error("expected tip s1 retained with window 1") + } +} diff --git a/go/ai/exp/localstore/option.go b/go/ai/exp/localstore/option.go new file mode 100644 index 0000000000..a99e03fec1 --- /dev/null +++ b/go/ai/exp/localstore/option.go @@ -0,0 +1,92 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package localstore + +import ( + "context" + "errors" +) + +// fileStoreOptions holds the optional settings for a [FileSessionStore]. +type fileStoreOptions struct { + maxChain *int // Retention window (>= 1) when set; nil means unset. + prefixFn func(context.Context) string // Derives the per-call subdirectory. +} + +// FileStoreOption configures a [FileSessionStore] at construction. +// It applies only to [NewFileSessionStore]. +type FileStoreOption interface { + applyFileStore(*fileStoreOptions) error +} + +// applyFileStore merges o into opts, rejecting an option set more than once. +func (o *fileStoreOptions) applyFileStore(opts *fileStoreOptions) error { + if o.maxChain != nil { + if *o.maxChain < 1 { + return errors.New("max persisted chain length must be at least 1 (WithMaxPersistedChainLength)") + } + if opts.maxChain != nil { + return errors.New("cannot set max persisted chain length more than once (WithMaxPersistedChainLength)") + } + opts.maxChain = o.maxChain + } + + if o.prefixFn != nil { + if opts.prefixFn != nil { + return errors.New("cannot set snapshot path prefix more than once (WithSnapshotPathPrefix)") + } + opts.prefixFn = o.prefixFn + } + + return nil +} + +// WithMaxPersistedChainLength bounds how many snapshots a single conversation +// chain keeps on disk. Each save walks the new snapshot's parentId chain and +// unlinks every row older than the newest n, capping per-conversation disk use. +// +// n must be at least 1; [NewFileSessionStore] rejects 0 or a negative value. A +// window of 1 retains only the latest snapshot (each save prunes its +// predecessor). Omitting the option entirely (the default) leaves pruning +// disabled, and chains grow without bound. +// +// Snapshots are self-contained (each carries the full session state), so +// dropping an old ancestor only removes it as a resume point; every surviving +// row remains fully loadable. Retention follows parentId links, whereas +// [FileSessionStore.GetLatestSnapshot] resolves recency by CreatedAt, so pruning +// only ever removes rows reachable from the saved snapshot's own chain; a +// sibling branch (e.g. after a regenerate) is pruned independently when it is +// itself extended. +func WithMaxPersistedChainLength(n int) FileStoreOption { + return &fileStoreOptions{maxChain: &n} +} + +// WithSnapshotPathPrefix derives a per-call subdirectory from the operation's +// context, isolating snapshots by tenant: a snapshot written under one prefix is +// visible only to calls that derive the same prefix. A typical fn pulls a stable +// identity (e.g. an authenticated user or org ID) out of ctx. +// +// The returned value may contain "/" to nest several levels +// (e.g. "org-42/user-7"); empty and separator-only results place snapshots +// directly under the store root. The value must be stable for a given +// snapshot's lifetime, since every read recomputes it - derive it from stable +// identity, not from per-request state. A value that would escape the store +// directory (contains "..", a backslash, or a segment starting with ".") is +// rejected at call time. +func WithSnapshotPathPrefix(fn func(ctx context.Context) string) FileStoreOption { + return &fileStoreOptions{prefixFn: fn} +} From e5707bf69dfbbcfadb4f2cb5cd742f22ff6f2ac5 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 22 Jun 2026 21:50:15 -0700 Subject: [PATCH 127/141] feat(go/exp): observe cross-process snapshot status changes in file store FileSessionStore.OnSnapshotStatusChange previously reflected only status changes written through the same store instance, so a detached turn running in one process could not be aborted by another process sharing the directory. Add an internal poller that re-reads subscribed snapshot files on an interval (default 1s, configurable via WithPollInterval; <=0 disables it) and delivers any change. The instant in-process fast path is kept for same-process delivery; both paths funnel through a single per-snapshot dedup gate, so a change is delivered exactly once regardless of which observes it first. The poller is started lazily on the first subscriber and stopped when the last unsubscribes, so a store with no subscriptions pays nothing. Polling (rather than fsnotify) matches the SnapshotSubscriber contract, adds no dependency, and is immune to the temp-file+rename inode swap. --- go/ai/exp/localstore/file.go | 183 +++++++++++++++++++++++++----- go/ai/exp/localstore/file_test.go | 94 +++++++++++++++ go/ai/exp/localstore/option.go | 23 ++++ 3 files changed, 272 insertions(+), 28 deletions(-) diff --git a/go/ai/exp/localstore/file.go b/go/ai/exp/localstore/file.go index 28236c0c5f..bc529a8964 100644 --- a/go/ai/exp/localstore/file.go +++ b/go/ai/exp/localstore/file.go @@ -47,11 +47,14 @@ import ( // caller does not have. That is why snapshots are grouped by prefix and kept // flat within it rather than nested under a per-session directory. // -// The store is safe for concurrent use within a single process, but does NOT -// coordinate with other processes sharing the directory: the last successful -// rename wins, and a reader may briefly observe a snapshot another process is -// still writing. [FileSessionStore.OnSnapshotStatusChange] likewise reflects -// only status changes made through this instance. +// The store is safe for concurrent use, and [FileSessionStore.OnSnapshotStatusChange] +// surfaces status changes written by other processes (or other store instances) +// sharing the directory by polling the snapshot files on an interval (see +// [WithPollInterval]); that cross-process visibility is what lets one process +// abort a detached turn another process is running. The store still does not +// provide cross-process transactions: the last successful rename wins. Each +// write is atomic (temp file + rename), so a concurrent reader sees either the +// old file or the new one, never a torn write. type FileSessionStore[State any] struct { // mu serializes the read-modify-write paths and the subscriber bookkeeping. // File I/O happens under the lock; this matches the simplicity of @@ -65,9 +68,36 @@ type FileSessionStore[State any] struct { // prefixFn, when set, derives the per-call subdirectory from context; see // [WithSnapshotPathPrefix]. prefixFn func(context.Context) string - subs map[string][]chan exp.SnapshotStatus + // poll is the interval at which the background poller re-reads subscribed + // snapshot files to detect cross-process status changes; <= 0 disables it. + poll time.Duration + subs map[string]*snapshotSubs + // pollCancel stops the background poller. It is non-nil exactly while the + // poller runs: while at least one subscription is active and poll > 0. + pollCancel context.CancelFunc } +// snapshotSubs holds the live subscribers to one snapshot plus the state the +// poller needs to surface its status changes. +type snapshotSubs struct { + chans []chan exp.SnapshotStatus + // path is the file the poller re-reads, captured at subscription time + // because the snapshot's prefix is derived from the subscriber's context + // and is not otherwise known on a by-poll re-read. + path string + // last is the status most recently delivered to chans (seeded at the first + // subscription). It is the single dedup gate shared by the in-process write + // path and the poller, so a change is delivered once regardless of which + // observes it first. + last exp.SnapshotStatus +} + +// defaultPollInterval is how often the poller re-reads subscribed snapshot +// files when no interval is configured. It sits well below the agent heartbeat +// interval so an operator-driven abort propagates promptly while the idle I/O +// cost stays negligible. +const defaultPollInterval = time.Second * 2 + // NewFileSessionStore creates a file-based snapshot store rooted at dir. // The directory is created (mode 0o700) if it does not already exist. // Returns an error if dir is empty, cannot be created, or an option is set @@ -90,11 +120,16 @@ func NewFileSessionStore[State any](dir string, opts ...FileStoreOption) (*FileS if resolved.maxChain != nil { maxChain = *resolved.maxChain } + poll := defaultPollInterval + if resolved.poll != nil { + poll = *resolved.poll + } return &FileSessionStore[State]{ dir: dir, maxChain: maxChain, prefixFn: resolved.prefixFn, - subs: make(map[string][]chan exp.SnapshotStatus), + poll: poll, + subs: make(map[string]*snapshotSubs), }, nil } @@ -159,9 +194,7 @@ func (s *FileSessionStore[State]) SaveSnapshot( if err := s.writeAt(prefix, next); err != nil { return nil, err } - if existing == nil || existing.Status != next.Status { - s.notifyLocked(id, next.Status) - } + s.maybeNotifyLocked(id, next.Status) if s.maxChain > 0 { s.pruneLocked(prefix, next) } @@ -265,10 +298,18 @@ func (s *FileSessionStore[State]) snapshotFileNames(dir string) ([]string, error } // OnSnapshotStatusChange subscribes to status changes for a snapshot. The -// returned channel yields the current status (if any) and any subsequent -// changes triggered by calls on this store instance, until ctx is cancelled. -// Changes made by other processes writing to the same directory are not -// observed. +// returned channel yields the status at subscription time and every subsequent +// change until ctx is cancelled. A change is surfaced immediately when written +// through this store instance, and within one poll interval when written by +// another process sharing the directory (see [WithPollInterval]); polling +// re-reads the file the snapshot was found under at subscription time. If the +// snapshot does not exist at subscription time, the channel is closed without +// yielding a value. +// +// Values are level-triggered: the latest status is always delivered, but a slow +// reader may skip intermediate values, and a subscriber that joins concurrently +// with a change may observe that status twice. Treat a received value as "the +// status is now X", not "X just happened once". func (s *FileSessionStore[State]) OnSnapshotStatusChange(ctx context.Context, snapshotID string) <-chan exp.SnapshotStatus { ch := make(chan exp.SnapshotStatus, 1) if err := validateSnapshotID(snapshotID); err != nil { @@ -280,16 +321,25 @@ func (s *FileSessionStore[State]) OnSnapshotStatusChange(ctx context.Context, sn close(ch) return ch } + path := s.pathFor(prefix, snapshotID) s.mu.Lock() - snap, err := s.readAt(s.pathFor(prefix, snapshotID)) + snap, err := s.readAt(path) if err != nil || snap == nil { s.mu.Unlock() close(ch) return ch } ch <- snap.Status - s.subs[snapshotID] = append(s.subs[snapshotID], ch) + sub := s.subs[snapshotID] + if sub == nil { + // First subscriber: seed the dedup baseline with the current status and + // remember the path so the poller can re-read it. + sub = &snapshotSubs{path: path, last: snap.Status} + s.subs[snapshotID] = sub + } + sub.chans = append(sub.chans, ch) + s.startPollerLocked() s.mu.Unlock() context.AfterFunc(ctx, func() { s.removeSub(snapshotID, ch) }) @@ -396,33 +446,110 @@ func (s *FileSessionStore[State]) pathFor(prefix, snapshotID string) string { return filepath.Join(s.dir, prefix, snapshotID+".json") } -// removeSub detaches a subscriber and closes its channel. +// removeSub detaches a subscriber and closes its channel, dropping the +// snapshot's bookkeeping (and stopping the poller) once no subscribers remain. func (s *FileSessionStore[State]) removeSub(snapshotID string, ch chan exp.SnapshotStatus) { s.mu.Lock() defer s.mu.Unlock() - subs := s.subs[snapshotID] - i := slices.Index(subs, ch) + sub := s.subs[snapshotID] + if sub == nil { + return + } + i := slices.Index(sub.chans, ch) if i < 0 { return } - subs = slices.Delete(subs, i, i+1) - if len(subs) == 0 { + sub.chans = slices.Delete(sub.chans, i, i+1) + if len(sub.chans) == 0 { delete(s.subs, snapshotID) - } else { - s.subs[snapshotID] = subs + } + if len(s.subs) == 0 { + s.stopPollerLocked() } close(ch) } -// notifyLocked publishes status to all live subscribers of snapshotID. -// Caller must hold s.mu. A slow subscriber may miss intermediate values, but -// the latest value is always delivered (see [coalesceSend]). -func (s *FileSessionStore[State]) notifyLocked(snapshotID string, status exp.SnapshotStatus) { - for _, ch := range s.subs[snapshotID] { +// maybeNotifyLocked fans status out to snapshotID's subscribers, but only when +// it differs from the value they last saw. It is the shared dedup gate for both +// the in-process write path ([SaveSnapshot]) and the cross-process poller, so +// whichever observes a change first delivers it once and the other is a no-op. +// Caller must hold s.mu. A slow subscriber may miss intermediate values, but the +// latest is always delivered (see [coalesceSend]). +func (s *FileSessionStore[State]) maybeNotifyLocked(snapshotID string, status exp.SnapshotStatus) { + sub := s.subs[snapshotID] + if sub == nil || sub.last == status { + return + } + sub.last = status + for _, ch := range sub.chans { coalesceSend(ch, status) } } +// startPollerLocked launches the background poller if it is not already running +// and polling is enabled. Caller must hold s.mu. +func (s *FileSessionStore[State]) startPollerLocked() { + if s.pollCancel != nil || s.poll <= 0 { + return + } + ctx, cancel := context.WithCancel(context.Background()) + s.pollCancel = cancel + go s.pollLoop(ctx, s.poll) +} + +// stopPollerLocked signals the poller to exit. Caller must hold s.mu. It does +// not wait: the goroutine observes the cancellation and returns promptly, and +// the shared dedup gate keeps a briefly-overlapping successor poller correct. +func (s *FileSessionStore[State]) stopPollerLocked() { + if s.pollCancel == nil { + return + } + s.pollCancel() + s.pollCancel = nil +} + +// pollLoop re-reads every subscribed snapshot on each tick until ctx is +// cancelled, delivering status changes written by other processes (or other +// store instances) sharing the directory - the only way such changes reach +// subscribers, since a cross-process write never runs this instance's +// in-process notification. +func (s *FileSessionStore[State]) pollLoop(ctx context.Context, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.pollOnce() + } + } +} + +// pollOnce re-reads each subscribed snapshot's file and delivers any status +// change through the shared dedup gate. It snapshots the subscription set under +// the lock, then reads each file under its own lock acquisition, so a tick never +// blocks a write for longer than a single file read and the read pairs +// atomically with the dedup check (ruling out delivering a status the file no +// longer holds). A read error or vanished file is skipped. +func (s *FileSessionStore[State]) pollOnce() { + type target struct{ id, path string } + s.mu.Lock() + targets := make([]target, 0, len(s.subs)) + for id, sub := range s.subs { + targets = append(targets, target{id: id, path: sub.path}) + } + s.mu.Unlock() + + for _, t := range targets { + s.mu.Lock() + if snap, err := s.readAt(t.path); err == nil && snap != nil { + s.maybeNotifyLocked(t.id, snap.Status) + } + s.mu.Unlock() + } +} + // sanitizePrefix turns a raw prefix (which may contain "/" to nest // subdirectories) into a cleaned relative path under the store root, rejecting // any value that could escape it. Empty and separator-only inputs yield "". diff --git a/go/ai/exp/localstore/file_test.go b/go/ai/exp/localstore/file_test.go index 74d1380fb3..5b9697c4ce 100644 --- a/go/ai/exp/localstore/file_test.go +++ b/go/ai/exp/localstore/file_test.go @@ -392,6 +392,100 @@ func TestFileSessionStore(t *testing.T) { }) } +// recvStatus waits up to timeout for the next status on ch, failing the test on +// a timeout or an unexpectedly closed channel. +func recvStatus(t *testing.T, ch <-chan exp.SnapshotStatus, timeout time.Duration) exp.SnapshotStatus { + t.Helper() + select { + case s, ok := <-ch: + if !ok { + t.Fatal("status channel closed unexpectedly") + } + return s + case <-time.After(timeout): + t.Fatal("timeout waiting for status") + return "" + } +} + +// TestFileSessionStore_CrossProcessStatusChange verifies that a status change +// written through one store instance is observed by a subscriber on a separate +// instance over the same directory - the cross-process case that backs aborting +// a detached turn from a different process. A second *FileSessionStore stands in +// for the other process. +func TestFileSessionStore_CrossProcessStatusChange(t *testing.T) { + dir := t.TempDir() + writer, err := NewFileSessionStore[testState](dir) + if err != nil { + t.Fatalf("NewFileSessionStore (writer): %v", err) + } + // A short poll interval keeps the test fast without changing the behavior. + watcher, err := NewFileSessionStore[testState](dir, WithPollInterval(5*time.Millisecond)) + if err != nil { + t.Fatalf("NewFileSessionStore (watcher): %v", err) + } + + if _, err := writer.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusPending}, nil + }); err != nil { + t.Fatalf("seed pending: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ch := watcher.OnSnapshotStatusChange(ctx, "snap-1") + + if got := recvStatus(t, ch, time.Second); got != exp.SnapshotStatusPending { + t.Fatalf("initial status = %q, want %q", got, exp.SnapshotStatusPending) + } + + // Abort through the writer instance; the watcher must see it via polling. + if status := abortViaSave(t, writer, "snap-1"); status != exp.SnapshotStatusAborted { + t.Fatalf("abort via writer: status = %q, want %q", status, exp.SnapshotStatusAborted) + } + if got := recvStatus(t, ch, 2*time.Second); got != exp.SnapshotStatusAborted { + t.Fatalf("cross-process status = %q, want %q", got, exp.SnapshotStatusAborted) + } +} + +// TestFileSessionStore_PollIntervalDisabled verifies that WithPollInterval(0) +// turns off cross-process polling: the subscriber still gets the seed value but +// never sees a change written through another instance. +func TestFileSessionStore_PollIntervalDisabled(t *testing.T) { + dir := t.TempDir() + writer, err := NewFileSessionStore[testState](dir) + if err != nil { + t.Fatalf("NewFileSessionStore (writer): %v", err) + } + watcher, err := NewFileSessionStore[testState](dir, WithPollInterval(0)) + if err != nil { + t.Fatalf("NewFileSessionStore (watcher): %v", err) + } + + if _, err := writer.SaveSnapshot(context.Background(), "snap-1", + func(_ *exp.SessionSnapshot[testState]) (*exp.SessionSnapshot[testState], error) { + return &exp.SessionSnapshot[testState]{Status: exp.SnapshotStatusPending}, nil + }); err != nil { + t.Fatalf("seed pending: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ch := watcher.OnSnapshotStatusChange(ctx, "snap-1") + + if got := recvStatus(t, ch, time.Second); got != exp.SnapshotStatusPending { + t.Fatalf("initial status = %q, want %q", got, exp.SnapshotStatusPending) + } + + abortViaSave(t, writer, "snap-1") + select { + case got := <-ch: + t.Fatalf("with polling disabled, unexpectedly observed status %q", got) + case <-time.After(150 * time.Millisecond): + } +} + // TestFileSessionStore_FinishReasonPersistsAcrossReopen verifies that a // snapshot's finish reason survives the disk round-trip: a second store // opened on the same directory (as after a process restart) reads it back. diff --git a/go/ai/exp/localstore/option.go b/go/ai/exp/localstore/option.go index a99e03fec1..2bb41da2e0 100644 --- a/go/ai/exp/localstore/option.go +++ b/go/ai/exp/localstore/option.go @@ -19,12 +19,14 @@ package localstore import ( "context" "errors" + "time" ) // fileStoreOptions holds the optional settings for a [FileSessionStore]. type fileStoreOptions struct { maxChain *int // Retention window (>= 1) when set; nil means unset. prefixFn func(context.Context) string // Derives the per-call subdirectory. + poll *time.Duration // Cross-process poll interval when set; nil means unset. } // FileStoreOption configures a [FileSessionStore] at construction. @@ -52,6 +54,13 @@ func (o *fileStoreOptions) applyFileStore(opts *fileStoreOptions) error { opts.prefixFn = o.prefixFn } + if o.poll != nil { + if opts.poll != nil { + return errors.New("cannot set poll interval more than once (WithPollInterval)") + } + opts.poll = o.poll + } + return nil } @@ -90,3 +99,17 @@ func WithMaxPersistedChainLength(n int) FileStoreOption { func WithSnapshotPathPrefix(fn func(ctx context.Context) string) FileStoreOption { return &fileStoreOptions{prefixFn: fn} } + +// WithPollInterval sets how often the store re-reads subscribed snapshot files +// to surface status changes that other processes (or other store instances) +// sharing the directory write through [FileSessionStore.OnSnapshotStatusChange]. +// That cross-process visibility is what lets one process observe an abort (or +// any status change) another process commits, for example to stop a detached +// turn it is running. +// +// The default is one second. A value <= 0 disables cross-process polling: +// subscriptions then observe only changes written through this same store +// instance. +func WithPollInterval(d time.Duration) FileStoreOption { + return &fileStoreOptions{poll: &d} +} From 4c1e89e8d038708aa71e44fd516ad830d3327f98 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 22 Jun 2026 22:39:09 -0700 Subject: [PATCH 128/141] Update README.md --- go/README.md | 224 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 224 insertions(+) diff --git a/go/README.md b/go/README.md index 7884360937..356cfd409d 100644 --- a/go/README.md +++ b/go/README.md @@ -70,6 +70,228 @@ go run main.go --- +## Agents + +Agents are Genkit's primitive for multi-turn, stateful conversations. An agent owns the per-turn loop (render the prompt, append history, call the model, stream the reply) and the conversation's session state, so your code sends messages and reads results instead of re-threading history on every call. + +Beyond a plain chat loop, agents give you: + +- **Managed session state** that persists across turns, with typed custom state of your own. +- **Snapshots** written at the end of every successful turn, so a conversation can be resumed later by session or snapshot ID. +- **Background execution** via `Detach`: hand a long-running turn to the server, walk away, and poll, resume, or abort it later. +- **One definition, many transports**: the same agent runs in-process (`RunText`, `Connect`) or over HTTP, one turn per request. + +The agent API is experimental: it lives in `github.com/firebase/genkit/go/ai/exp` (aliased `aix` in the snippets below) and may change in a minor release. + +### Define an Agent + +The shortest path is a prompt-backed agent with an inline prompt and a session store. `aix.FromInline` declares the prompt right next to the agent; the store persists each turn so the conversation can resume later: + +```go +import ( + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/ai/exp/localstore" + "github.com/firebase/genkit/go/genkit" +) + +chatAgent := genkit.DefineAgent(g, "chat", + aix.FromInline( + ai.WithModelName("googleai/gemini-flash-latest"), + ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), + ), + aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), +) + +// Single turn: RunText drives the whole connection lifecycle for you. +out, _ := chatAgent.RunText(ctx, "What's the best way to learn Go?") +fmt.Println(out.Message.Text()) +``` + +The `State` type parameter is inferred from the typed options (`aix.WithSessionStore`, `aix.WithStateTransform`), so the explicit `DefineAgent[State]` is only needed when no typed option is supplied. + +[See full example](samples/basic-agents) + +### Multi-Turn Conversations + +`Connect` opens a streaming session you drive turn by turn: send a message, iterate chunks until `TurnEnd`, then send the next one. The agent carries the history between turns. `Output` ends the conversation and returns the final result: + +```go +conn, _ := chatAgent.Connect(ctx) + +conn.SendText("What is Go's concurrency model?") +for chunk, err := range conn.Receive() { + if err != nil { + log.Fatal(err) + } + if chunk.ModelChunk != nil { + fmt.Print(chunk.ModelChunk.Text()) // stream tokens as they arrive + } + if chunk.TurnEnd != nil { + break // turn complete, ready for the next input + } +} + +conn.SendText("Show me an example with goroutines.") +// ... iterate conn.Receive() again ... + +out, _ := conn.Output() // closes input, drains, returns the final AgentOutput +fmt.Println(out.Message.Text()) +``` + +[See full example](samples/basic-agents) + +### Load the Prompt from a File + +`aix.FromPrompt` backs the agent with a prompt already in the registry, including one loaded from a `.prompt` file. Prompt authors can tune the model, config, and template without touching the Go wiring. The agent and its prompt share a name: + +```yaml +# prompts/chat.prompt +--- +model: googleai/gemini-flash-latest +input: + schema: ChatInput +--- +{{role "system"}} +You are {{personality}}. Keep responses concise. +``` + +```go +type ChatInput struct { + Personality string `json:"personality"` +} + +// Register the schema so the .prompt file can reference it by name. +genkit.DefineSchemaFor[ChatInput](g) + +// FromPrompt's argument is the default input rendered on every turn. +chatAgent := genkit.DefineAgent(g, "chat", + aix.FromPrompt(ChatInput{Personality: "a Michelin-starred chef"}), + aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), +) +``` + +[See full example](samples/basic-agents) + +### Custom Turn Loops + +When the prompt-backed loop isn't enough (custom models per turn, pre/post processing, bespoke tool plumbing), `DefineCustomAgent` hands you the turn body. You still get managed session state, snapshots, and the detach lifecycle for free. A typed `State` parameter carries structured state across turns, and mutating it with `UpdateCustom` streams the delta to the client automatically: + +```go +type ChatState struct { + TopicsDiscussed []string `json:"topicsDiscussed"` +} + +chatAgent := genkit.DefineCustomAgent(g, "chat", + func(ctx context.Context, resp aix.Responder, sess *aix.SessionRunner[ChatState]) (*aix.AgentResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) (*aix.TurnResult, error) { + for chunk, err := range genkit.GenerateStream(ctx, g, + ai.WithModelName("googleai/gemini-flash-latest"), + ai.WithMessages(sess.Messages()...), // the history is yours to manage + ) { + if err != nil { + return nil, err + } + if chunk.Done { + sess.AddMessages(chunk.Response.Message) + if input.Message != nil { + sess.UpdateCustom(func(s ChatState) ChatState { + s.TopicsDiscussed = append(s.TopicsDiscussed, input.Message.Text()) + return s + }) + } + // Report how the turn ended so the framework can forward it + // on the TurnEnd chunk and persist it on the snapshot. + return &aix.TurnResult{ + FinishReason: aix.AgentFinishReason(chunk.Response.FinishReason), + }, nil + } + resp.SendModelChunk(chunk.Chunk) // stream tokens to the client + } + return nil, nil + }) + if err != nil { + return nil, err + } + return sess.Result(), nil + }, + aix.WithSessionStore(localstore.NewInMemorySessionStore[ChatState]()), +) +``` + +[See full example](samples/basic-agents) + +### Persist and Resume + +With a session store configured, every successful turn writes a snapshot. The caller only needs the `SessionID` from a previous result to pick the conversation back up: + +```go +first, _ := chatAgent.RunText(ctx, "My name is Alex and I'm planning a trip to Japan.") + +// Later, in another request or process: resume from the latest snapshot. +second, _ := chatAgent.RunText(ctx, "What is my name?", + aix.WithSessionID[any](first.SessionID)) +fmt.Println(second.Message.Text()) // "Your name is Alex." +``` + +Resume from one specific point in history with `aix.WithSnapshotID`, or skip the server store entirely and round-trip the state yourself with `aix.WithState` (the conversation's identity travels inside the state object). + +[See full example](samples/basic-agents) + +### Background Agents + +`Detach` hands the rest of the work to the server and closes the connection promptly with a pending snapshot ID. The agent keeps processing in the background on a context decoupled from the client's, so a long task survives the caller walking away: + +```go +conn, _ := chatAgent.Connect(ctx) +conn.SendText("Draft a detailed two-week Japan itinerary.") +conn.Detach() // server takes ownership of the remaining work + +out, _ := conn.Output() // returns immediately; FinishReason is "detached" +snapshotID := out.SnapshotID + +// Later: poll the snapshot, then resume once it has finalized. +snap, _ := chatAgent.GetSnapshot(ctx, snapshotID) +switch snap.Status { +case aix.SnapshotStatusPending: // still working +case aix.SnapshotStatusCompleted: // snap.State holds the final state; resume it +case aix.SnapshotStatusFailed: // snap.Error holds the structured failure +} + +// Or stop it early; the runtime observes the abort and cancels the work. +chatAgent.AbortSnapshot(ctx, snapshotID) +``` + +Detach requires a store that implements `SnapshotSubscriber` (both bundled local stores do). A detached turn refreshes a heartbeat while it runs, so a crashed worker surfaces as `expired` instead of orphaning the conversation forever. + +[See full example](samples/basic-agents) + +### Serve Agents over HTTP + +An `Agent` is an `api.BidiAction`, so it serves over HTTP one turn per request. The `genkit/exp` package lays out a default route surface for every registered agent, including the snapshot companion endpoints for store-backed agents: + +```go +import ( + genkitx "github.com/firebase/genkit/go/genkit/exp" + "github.com/firebase/genkit/go/plugins/server" +) + +mux := http.NewServeMux() +for _, r := range genkitx.AllAgentRoutes(g) { + mux.HandleFunc(r.Pattern(), r.Handler()) +} +// POST /agents/chat one turn per request (?stream=true for SSE) +// POST /agents/chat/getSnapshot read a snapshot by ID +// POST /agents/chat/abortSnapshot abort background work +log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) +``` + +A client starts a conversation by POSTing a turn, then continues it by sending the returned `sessionId` in the request's `init` field. Agents with no store return the full state instead and the client round-trips it, so stateless and store-backed agents deploy side by side. + +[See full example](samples/basic-agents-server) + +--- + ## Features Genkit Go gives you everything you need to build AI applications with confidence. @@ -687,6 +909,8 @@ Explore working examples to see Genkit in action: | [basic](samples/basic) | Simple text generation with streaming | | [basic-structured](samples/basic-structured) | Typed JSON output with `GenerateData` and `GenerateDataStream` | | [basic-prompts](samples/basic-prompts) | Prompt templates with Handlebars and `.prompt` files | +| [basic-agents](samples/basic-agents) | Multi-turn agents (inline, prompt-file, and custom-loop) with snapshots and background detach | +| [basic-agents-server](samples/basic-agents-server) | Serving store-backed and stateless agents over HTTP | | [intermediate-interrupts](samples/intermediate-interrupts) | Human-in-the-loop with tool interrupts | | [basic-middleware/retry-fallback](samples/basic-middleware/retry-fallback) | Composing `Retry` and `Fallback` middleware | | [basic-middleware/filesystem](samples/basic-middleware/filesystem) | Scoped filesystem tools for the model | From 926875918d9dfabbbabf858744f4bda42b11eb1d Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 23 Jun 2026 08:50:24 -0700 Subject: [PATCH 129/141] feat(go/exp): add session-flow middleware package and context seeding - ai/exp: add WithContextFunc option, an ArtifactStore context accessor, and ref.go - genkit: seed the Genkit instance into every agent turn's context so middleware can resolve actions via FromContext - plugins/middleware/exp: new experimental middleware package (agents, artifacts, inject, plugin) with tests - samples/basic-agents: split agents into per-file definitions and add an orchestrator that delegates to sub-agents through the agents middleware - go/.gitignore: ignore the compiled basic-agents sample binary --- go/.gitignore | 3 + go/ai/exp/agent.go | 6 + go/ai/exp/option.go | 26 + go/ai/exp/ref.go | 49 ++ go/ai/exp/session.go | 23 + go/genkit/genkit.go | 16 +- go/plugins/middleware/exp/agents.go | 504 +++++++++++++++++ go/plugins/middleware/exp/agents_test.go | 574 ++++++++++++++++++++ go/plugins/middleware/exp/artifacts.go | 209 +++++++ go/plugins/middleware/exp/artifacts_test.go | 219 ++++++++ go/plugins/middleware/exp/helpers_test.go | 43 ++ go/plugins/middleware/exp/inject.go | 90 +++ go/plugins/middleware/exp/plugin.go | 49 ++ go/samples/basic-agents/chef.go | 45 ++ go/samples/basic-agents/coder.go | 67 +++ go/samples/basic-agents/main.go | 140 +---- go/samples/basic-agents/orchestrator.go | 90 +++ go/samples/basic-agents/pirate.go | 39 ++ 18 files changed, 2079 insertions(+), 113 deletions(-) create mode 100644 go/.gitignore create mode 100644 go/ai/exp/ref.go create mode 100644 go/plugins/middleware/exp/agents.go create mode 100644 go/plugins/middleware/exp/agents_test.go create mode 100644 go/plugins/middleware/exp/artifacts.go create mode 100644 go/plugins/middleware/exp/artifacts_test.go create mode 100644 go/plugins/middleware/exp/helpers_test.go create mode 100644 go/plugins/middleware/exp/inject.go create mode 100644 go/plugins/middleware/exp/plugin.go create mode 100644 go/samples/basic-agents/chef.go create mode 100644 go/samples/basic-agents/coder.go create mode 100644 go/samples/basic-agents/orchestrator.go create mode 100644 go/samples/basic-agents/pirate.go diff --git a/go/.gitignore b/go/.gitignore new file mode 100644 index 0000000000..f32f2e12ef --- /dev/null +++ b/go/.gitignore @@ -0,0 +1,3 @@ +# Compiled sample binary produced by `go build ./samples/basic-agents` +# from this directory. It is a build artifact, not source. +/basic-agents diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 6ff830fa67..e07c9b039d 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -603,6 +603,12 @@ func NewCustomAgent[State any]( outCh chan<- *AgentStreamChunk, ) (*AgentOutput[State], error) { ctx = core.WithFlowContext(ctx, name) + // Apply any context decorators (e.g. the genkit package seeding its + // instance) before the runtime derives the per-turn work context, so + // the decorated values reach each turn's prompt, tools, and middleware. + if cfg.contextFunc != nil { + ctx = cfg.contextFunc(ctx) + } rt, err := newAgentRuntime(ctx, name, cfg, in, inCh, outCh) if err != nil { // Init failures (a rejected init payload, a failed diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index 617f856221..0855dcb73b 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -46,6 +46,7 @@ type agentOptions[State any] struct { store SessionStore[State] transform StateTransform[State] description string + contextFunc func(context.Context) context.Context } func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { @@ -67,6 +68,18 @@ func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { } opts.description = o.description } + if o.contextFunc != nil { + // Context decorators compose rather than conflict: each WithContextFunc + // adds a layer applied in registration order. This lets the genkit + // package seed its instance (see genkit.DefineAgent) while applications + // add their own decorators on the same agent. + if prev := opts.contextFunc; prev != nil { + next := o.contextFunc + opts.contextFunc = func(ctx context.Context) context.Context { return next(prev(ctx)) } + } else { + opts.contextFunc = o.contextFunc + } + } return nil } @@ -91,6 +104,19 @@ func WithDescription[State any](description string) AgentOption[State] { return &agentOptions[State]{description: description} } +// WithContextFunc registers a function that decorates the context for each +// agent invocation, applied once before the turn loop runs so the returned +// context flows to the prompt, tools, and middleware of every turn. +// +// The genkit package uses it to seed the [genkit.Genkit] instance (retrievable +// with genkit.FromContext) so middleware can resolve and run other actions +// without direct registry access. Applications may also use it to attach +// invocation-scoped values (e.g. request identity). Multiple decorators compose +// in registration order. +func WithContextFunc[State any](fn func(context.Context) context.Context) AgentOption[State] { + return &agentOptions[State]{contextFunc: fn} +} + // --- InvocationOption --- // InvocationOption configures an agent invocation (StreamBidi, Run, or RunText). diff --git a/go/ai/exp/ref.go b/go/ai/exp/ref.go new file mode 100644 index 0000000000..d037c50e04 --- /dev/null +++ b/go/ai/exp/ref.go @@ -0,0 +1,49 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +// AgentRef refers to an agent by name, optionally carrying a description. It is +// the agent analog of [ai.ModelRef] / [ai.ToolRef]: a small, JSON-serializable +// value that names an agent for resolution against a registry. Like those, it +// resolves by name (the path the Dev UI, HTTP serving, and ListAgents all use), +// so the referenced agent must be registered wherever the ref is consumed. +// +// Build one by name with a struct literal, or derive it from an agent value +// with [Agent.Ref], which fills in the name and description so callers need not +// restate either: +// +// aix.AgentRef{Name: "researcher"} +// coderAgent.Ref() +type AgentRef struct { + // Name identifies the agent, resolved as /agent/. Required. + Name string `json:"name"` + // Description is a human-readable description used by consumers that list + // agents (e.g. the agents middleware's system prompt). [Agent.Ref] fills it + // from the agent's descriptor. Optional. + Description string `json:"description,omitempty"` +} + +// Ref returns an [AgentRef] for this agent, capturing its name and description +// so callers can reference it without restating either, and without a name +// string that can drift from the agent. Resolution remains by name, so the +// agent must be registered (as [DefineAgent] does) wherever the ref is used. +func (a *Agent[State]) Ref() AgentRef { + return AgentRef{ + Name: a.Name(), + Description: a.Desc().Description, + } +} diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 44bd98783c..ca670d8f74 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -475,3 +475,26 @@ func SessionFromContext[State any](ctx context.Context) *Session[State] { session, _ := sessionCtxKey.FromContext(ctx).(*Session[State]) return session } + +// ArtifactStore is the State-agnostic view of a session's artifact collection. +// Every [Session] satisfies it regardless of its State type, since artifact +// operations do not touch custom state. Middleware and tools that work with +// artifacts without knowing the agent's State type use it via +// [ArtifactStoreFromContext], where [SessionFromContext] cannot help because it +// requires the concrete State. +type ArtifactStore interface { + // Artifacts returns a snapshot of the session's current artifacts. + Artifacts() []*Artifact + // AddArtifacts adds artifacts, replacing any existing artifact of the same + // name. + AddArtifacts(artifacts ...*Artifact) +} + +// ArtifactStoreFromContext returns the active session's artifacts as a +// State-agnostic [ArtifactStore], or nil if there is no active session in ctx. +// Unlike [SessionFromContext] it does not require knowing the session's State +// type, so it is the accessor for middleware and tools. +func ArtifactStoreFromContext(ctx context.Context) ArtifactStore { + store, _ := sessionCtxKey.FromContext(ctx).(ArtifactStore) + return store +} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index b541a1c91a..310753565a 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -41,13 +41,25 @@ import ( var genkitCtxKey = base.NewContextKey[*Genkit]() // FromContext returns the [*Genkit] instance stored in the context. -// This is set automatically by [Generate] and related functions. +// This is set automatically by [Generate] and related functions, and by +// agents defined via [DefineAgent] / [DefineCustomAgent] for each turn. // Middleware implementations can use this to access the Genkit instance // during generation. func FromContext(ctx context.Context) *Genkit { return genkitCtxKey.FromContext(ctx) } +// seedGenkitContext returns an agent option that seeds g into each agent +// invocation's context, so middleware and other code can retrieve it via +// [FromContext] during the agent's turns, just as [Generate] seeds it. Agents +// run below the genkit layer (on the registry alone), so without this the +// instance would be absent from an agent's turn context. +func seedGenkitContext[State any](g *Genkit) aix.AgentOption[State] { + return aix.WithContextFunc[State](func(ctx context.Context) context.Context { + return genkitCtxKey.NewContext(ctx, g) + }) +} + // Genkit encapsulates a Genkit instance, providing access to its registry, // configuration, and core functionalities. It serves as the central hub for // defining and managing Genkit resources like flows, models, tools, and prompts. @@ -469,6 +481,7 @@ func DefineAgent[State any]( source aix.AgentSource, opts ...aix.AgentOption[State], ) *aix.Agent[State] { + opts = append(opts, seedGenkitContext[State](g)) return aix.DefineAgent(g.reg, name, source, opts...) } @@ -538,6 +551,7 @@ func DefineCustomAgent[State any]( fn aix.AgentFunc[State], opts ...aix.AgentOption[State], ) *aix.Agent[State] { + opts = append(opts, seedGenkitContext[State](g)) return aix.DefineCustomAgent(g.reg, name, fn, opts...) } diff --git a/go/plugins/middleware/exp/agents.go b/go/plugins/middleware/exp/agents.go new file mode 100644 index 0000000000..889b53250a --- /dev/null +++ b/go/plugins/middleware/exp/agents.go @@ -0,0 +1,504 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/genkit" +) + +// agentsMarker tags the system prompt part injected by this middleware. The +// listing is constant for a given configuration, so it is injected once and +// matched (no-op) on later tool-loop iterations. +const agentsMarker = "agents-instructions" + +// defaultToolPrefix is the prefix applied to generated delegation tool names +// when [Agents.ToolPrefix] is unset (tools become delegate_to_). +const defaultToolPrefix = "delegate_to" + +// ArtifactStrategy controls how a sub-agent's artifacts are surfaced back to the +// orchestrator by the [Agents] middleware. +type ArtifactStrategy string + +const ( + // ArtifactStrategyInline includes artifact content in the delegation tool + // result so the orchestrator model can see it, and also merges artifacts + // into the parent session. This is the default. + ArtifactStrategyInline ArtifactStrategy = "inline" + // ArtifactStrategySession merges artifacts into the parent session only; the + // tool result names the artifacts but omits their content. Pair it with the + // [Artifacts] middleware so the model can read/write session artifacts. + ArtifactStrategySession ArtifactStrategy = "session" +) + +// resolveAgent looks the agent up by name through g. Resolution goes through +// the Genkit instance (the sanctioned path for third-party middleware) rather +// than the registry directly. +func resolveAgent(g *genkit.Genkit, ref aix.AgentRef) (api.BidiAction, error) { + if g == nil { + return nil, fmt.Errorf("no Genkit instance on the context (the agents middleware must run within genkit.Generate or a genkit-defined agent)") + } + action := genkit.LookupAction(g, "/agent/"+ref.Name) + if action == nil { + return nil, fmt.Errorf("agent %q not found in registry", ref.Name) + } + agent, ok := action.(api.BidiAction) + if !ok { + return nil, fmt.Errorf("%q is registered but is not an agent", ref.Name) + } + return agent, nil +} + +// Agents is a middleware that enables sub-agent delegation. +// +// For every configured agent it injects a dedicated delegation tool (e.g. +// delegate_to_researcher) whose description is the agent's configured +// description or, in the system prompt, the description auto-discovered from the +// registry. A block listing the available agents is appended to the +// system prompt. +// +// When the model calls a delegation tool the middleware resolves the target +// agent from the registry (via the [github.com/firebase/genkit/go/genkit.Genkit] +// instance carried on the context), optionally forwards recent conversation +// history, runs the sub-agent with the task, and returns its response as the +// tool result. +// +// Artifact handling follows [Agents.ArtifactStrategy]: ArtifactStrategyInline +// (default) returns artifact content in the tool result and merges artifacts +// into the parent session; ArtifactStrategySession merges into the session only +// and returns names. Merged artifacts are namespaced by an invocation ID +// (_/) and tagged with the source agent. +// +// If a sub-agent interrupts (e.g. for human input) it is reported back to the +// orchestrator as a normal tool response, not propagated as an interrupt: there +// is no stateful sub-agent runtime to resume into, so interactive sub-agent +// interaction is a future feature. +// +// The middleware resolves agents through genkit.FromContext, which is seeded by +// genkit.Generate and by agents defined via genkit.DefineAgent. It is therefore +// typically attached to an orchestrator agent (or a genkit.Generate call). +// +// Usage: +// +// orchestrator := genkit.DefineAgent(g, "orchestrator", +// aix.FromInline( +// ai.WithModelName("googleai/gemini-flash-latest"), +// ai.WithSystem("You are a helpful project assistant."), +// ai.WithUse( +// &middleware.Agents{ +// Agents: []aix.AgentRef{ +// {Name: "researcher"}, // by name +// coderAgent.Ref(), // by instance (carries its description) +// }, +// MaxDelegations: 5, +// HistoryLength: 4, +// ArtifactStrategy: middleware.ArtifactStrategySession, +// }, +// &middleware.Artifacts{}, +// ), +// ), +// ) +type Agents struct { + // Agents lists the sub-agents available for delegation: by name + // (aix.AgentRef{Name: ...}) or as a captured instance (agentValue.Ref()). + // At least one is required. + Agents []aix.AgentRef `json:"agents,omitempty"` + // ToolPrefix is the prefix for generated delegation tool names. A nil value + // defaults to "delegate_to" (tools become delegate_to_); a pointer to + // the empty string uses bare agent names. + ToolPrefix *string `json:"toolPrefix,omitempty"` + // MaxDelegations caps the number of sub-agent delegations per generate call, + // preventing runaway delegation loops. 0 means unlimited. + MaxDelegations int `json:"maxDelegations,omitempty"` + // HistoryLength is the number of recent user/model messages forwarded to a + // sub-agent as context. 0 means only the task description is sent. History is + // forwarded only to client-managed sub-agents (those without a session + // store); server-managed sub-agents receive only the task. + HistoryLength int `json:"historyLength,omitempty"` + // ArtifactStrategy controls how sub-agent artifacts are surfaced. Defaults to + // ArtifactStrategyInline. + ArtifactStrategy ArtifactStrategy `json:"artifactStrategy,omitempty"` +} + +func (a *Agents) Name() string { return provider + "/agents" } + +// agentsState is the per-generate mutable state shared by the delegation tools +// and the generate hook. New allocates a fresh one per call, and a mutex guards +// it because delegation tools can run concurrently (parallel tool calls). +type agentsState struct { + mu sync.Mutex + // delegations counts delegations made so far, enforcing MaxDelegations and + // providing the per-invocation number used to namespace artifacts. + delegations int + // conversation is the latest request message list, captured each turn for + // optional history forwarding. + conversation []*ai.Message +} + +// New validates the configuration and returns the hooks: a delegation tool per +// agent plus a generate hook that injects the system prompt and +// captures conversation history. +func (a *Agents) New(ctx context.Context) (*ai.Hooks, error) { + if len(a.Agents) == 0 { + return nil, fmt.Errorf("agents middleware requires at least one agent in the \"agents\" option") + } + for _, ref := range a.Agents { + if ref.Name == "" { + return nil, fmt.Errorf("agents middleware: every agent reference must have a name") + } + } + + prefix := a.prefix() + st := &agentsState{} + + tools := make([]ai.Tool, 0, len(a.Agents)) + for _, ref := range a.Agents { + desc := ref.Description + if desc == "" { + desc = fmt.Sprintf("Delegates a task to the %q sub-agent.", ref.Name) + } + tools = append(tools, ai.NewTool(makeToolName(prefix, ref.Name), desc, a.delegate(ref, st))) + } + + wrapGenerate := func(ctx context.Context, params *ai.GenerateParams, next ai.GenerateNext) (*ai.ModelResponse, error) { + // Capture the latest messages for optional history forwarding. The + // delegation count is intentionally not reset here: this hook runs on + // every tool-loop turn, but the count must accumulate across the whole + // generate call (it starts at 0 when New allocates st). + st.mu.Lock() + st.conversation = params.Request.Messages + st.mu.Unlock() + + instructions := buildAgentsInstructions(genkit.FromContext(ctx), a.Agents, prefix) + params.Request = injectSystemText(params.Request, agentsMarker, instructions) + return next(ctx, params) + } + + return &ai.Hooks{ + Tools: tools, + WrapGenerate: wrapGenerate, + }, nil +} + +// delegateInput is the input schema for a delegation tool. +type delegateInput struct { + Task string `json:"task" jsonschema:"description=A clear, self-contained description of the task to delegate."` +} + +// delegationResult is the output of a delegation tool. +type delegationResult struct { + // Response is the sub-agent's text response. + Response string `json:"response"` + // Artifacts are the sub-agent's artifacts. Content is populated only under + // ArtifactStrategyInline. + Artifacts []delegatedArtifact `json:"artifacts,omitempty"` +} + +type delegatedArtifact struct { + Name string `json:"name,omitempty"` + Content string `json:"content,omitempty"` +} + +// delegate builds the delegation tool function for one sub-agent. +func (a *Agents) delegate(ref aix.AgentRef, st *agentsState) func(*ai.ToolContext, delegateInput) (delegationResult, error) { + return func(tc *ai.ToolContext, in delegateInput) (delegationResult, error) { + // Guard rail: enforce the delegation cap and reserve this delegation's + // number, atomically, before doing any work. + st.mu.Lock() + if a.MaxDelegations > 0 && st.delegations >= a.MaxDelegations { + st.mu.Unlock() + return delegationResult{Response: fmt.Sprintf( + "Delegation limit reached (%d). Complete the task using information already gathered.", + a.MaxDelegations)}, nil + } + st.delegations++ + invocationNum := st.delegations + history := recentTextHistory(st.conversation, a.HistoryLength) + st.mu.Unlock() + + agent, err := resolveAgent(genkit.FromContext(tc), ref) + if err != nil { + return delegationResult{Response: "Error: " + err.Error()}, nil + } + + // History rides in client-managed init state, which server-managed + // agents reject; forward it only to client-managed sub-agents. + if len(history) > 0 && !isClientManaged(agent) { + history = nil + } + + out, err := runSubAgent(tc, agent, in.Task, history) + if err != nil { + // The agent runtime resolves failures and interrupts gracefully (see + // below), so this only fires for exceptions outside that handling + // (e.g. a rejected init payload). Surface it as tool output. + return delegationResult{Response: fmt.Sprintf("Error calling agent %q: %v", ref.Name, err)}, nil + } + + switch out.FinishReason { + case aix.AgentFinishReasonInterrupted: + // Reported as text, not propagated: there is no stateful sub-agent + // runtime to resume into, so the orchestrator could never satisfy it. + return delegationResult{Response: fmt.Sprintf( + "Sub-agent %q interrupted for additional input and could not complete the "+ + "task. Interactive sub-agent interrupts are not currently supported; try "+ + "delegating a more self-contained task.", ref.Name)}, nil + case aix.AgentFinishReasonFailed: + msg := "Unknown sub-agent failure." + if out.Error != nil && out.Error.Message != "" { + msg = out.Error.Message + } + return delegationResult{Response: fmt.Sprintf("Error calling agent %q: %s", ref.Name, msg)}, nil + } + + result := delegationResult{Response: messageText(out.Message)} + if result.Response == "" { + result.Response = "(no response)" + } + + subArtifacts := namedArtifacts(out.Artifacts) + if len(subArtifacts) > 0 { + invocationID := fmt.Sprintf("%s_%d", ref.Name, invocationNum) + // Merge into the parent session under both strategies (no-op if there + // is no active session, e.g. a plain genkit.Generate call). + mergeArtifacts(tc, ref.Name, invocationID, subArtifacts) + result.Artifacts = delegatedArtifacts(invocationID, subArtifacts, a.strategy()) + } + return result, nil + } +} + +// runSubAgent runs the agent one-shot with the task. Agents are bidi actions, +// so this always goes through RunBidiJSON: with no history the init is empty (a +// fresh one-shot session); with history it carries the messages as client- +// managed init state, which callers forward only to client-managed agents. The +// output is decoded with json.RawMessage as the custom-state type since the +// sub-agent's State is unknown here. +func runSubAgent(ctx context.Context, agent api.BidiAction, task string, history []*ai.Message) (*aix.AgentOutput[json.RawMessage], error) { + inputJSON, err := json.Marshal(&aix.AgentInput{Message: ai.NewUserTextMessage(task)}) + if err != nil { + return nil, err + } + + var initJSON json.RawMessage + if len(history) > 0 { + initJSON, err = json.Marshal(aix.AgentInit[json.RawMessage]{ + State: &aix.SessionState[json.RawMessage]{Messages: history}, + }) + if err != nil { + return nil, err + } + } + + res, err := agent.RunBidiJSON(ctx, inputJSON, nil, &api.BidiSessionOptions{Init: initJSON}) + if err != nil { + return nil, err + } + + var out aix.AgentOutput[json.RawMessage] + if err := json.Unmarshal(res.Result, &out); err != nil { + return nil, err + } + return &out, nil +} + +// isClientManaged reports whether the agent owns its state on the client (no +// session store), which is the only case that accepts seeded init state. +func isClientManaged(agent api.BidiAction) bool { + meta := agent.Desc().Metadata + if meta == nil { + return false + } + switch m := meta["agent"].(type) { + case aix.AgentMetadata: + return m.StateManagement == aix.AgentStateManagementClient + case *aix.AgentMetadata: + return m.StateManagement == aix.AgentStateManagementClient + case map[string]any: + s, _ := m["stateManagement"].(string) + return aix.AgentStateManagement(s) == aix.AgentStateManagementClient + default: + return false + } +} + +// mergeArtifacts namespaces the sub-agent's artifacts by invocation ID, tags +// them with their source, and merges them into the active session. It is a no-op +// when there is no active session. +func mergeArtifacts(ctx context.Context, source, invocationID string, arts []*aix.Artifact) { + store := aix.ArtifactStoreFromContext(ctx) + if store == nil { + return + } + namespaced := make([]*aix.Artifact, 0, len(arts)) + for _, a := range arts { + md := make(map[string]any, len(a.Metadata)+2) + for k, v := range a.Metadata { + md[k] = v + } + md["source"] = source + md["invocationId"] = invocationID + namespaced = append(namespaced, &aix.Artifact{ + Name: invocationID + "/" + a.Name, + Parts: a.Parts, + Metadata: md, + }) + } + store.AddArtifacts(namespaced...) +} + +// delegatedArtifacts builds the tool-result artifact list, including content +// only under the inline strategy. +func delegatedArtifacts(invocationID string, arts []*aix.Artifact, strategy ArtifactStrategy) []delegatedArtifact { + out := make([]delegatedArtifact, 0, len(arts)) + for _, a := range arts { + da := delegatedArtifact{Name: invocationID + "/" + a.Name} + if strategy == ArtifactStrategyInline { + da.Content = artifactText(a) + } + out = append(out, da) + } + return out +} + +// prefix resolves the delegation tool-name prefix, defaulting to "delegate_to". +func (a *Agents) prefix() string { + if a.ToolPrefix == nil { + return defaultToolPrefix + } + return *a.ToolPrefix +} + +// strategy resolves the artifact strategy, defaulting to inline. +func (a *Agents) strategy() ArtifactStrategy { + if a.ArtifactStrategy == ArtifactStrategySession { + return ArtifactStrategySession + } + return ArtifactStrategyInline +} + +// makeToolName builds a delegation tool name from the prefix and agent name. An +// empty prefix yields the bare agent name. +func makeToolName(prefix, agentName string) string { + if prefix == "" { + return agentName + } + return prefix + "_" + agentName +} + +// buildAgentsInstructions renders the system prompt block. g may be +// nil (e.g. outside an agent/Generate context), in which case only configured +// descriptions are used. +func buildAgentsInstructions(g *genkit.Genkit, refs []aix.AgentRef, prefix string) string { + var b strings.Builder + b.WriteString("\n") + b.WriteString("You can delegate tasks to specialized sub-agents using their delegation tools:\n") + for _, ref := range refs { + desc := ref.Description + if desc == "" && g != nil { + desc = discoverDescription(g, ref.Name) + } + if desc == "" { + desc = "No description available." + } + fmt.Fprintf(&b, " - %s: %s\n", makeToolName(prefix, ref.Name), desc) + } + b.WriteString("\n") + b.WriteString("When a task is better handled by a specialized agent, delegate it using the ") + b.WriteString("appropriate tool. Provide a clear, self-contained task description.\n") + b.WriteString("") + return b.String() +} + +// discoverDescription returns the agent's description from its action +// descriptor, falling back to the backing prompt's description, or "" if none. +func discoverDescription(g *genkit.Genkit, name string) string { + for _, key := range []string{"/agent/" + name, "/prompt/" + name} { + if action := genkit.LookupAction(g, key); action != nil { + if d := action.Desc().Description; d != "" { + return d + } + } + } + return "" +} + +// recentTextHistory returns up to n of the most recent user/model messages, +// each reduced to its non-empty text parts. Tool and tool-request parts are +// dropped: a model message mid-tool-loop can carry a toolRequest part with no +// matching response, which would confuse the sub-agent model. Returns nil when +// n <= 0. +func recentTextHistory(msgs []*ai.Message, n int) []*ai.Message { + if n <= 0 { + return nil + } + var filtered []*ai.Message + for _, m := range msgs { + if m == nil || (m.Role != ai.RoleUser && m.Role != ai.RoleModel) { + continue + } + var parts []*ai.Part + for _, p := range m.Content { + if p != nil && p.IsText() && p.Text != "" { + parts = append(parts, ai.NewTextPart(p.Text)) + } + } + if len(parts) > 0 { + filtered = append(filtered, &ai.Message{Role: m.Role, Content: parts}) + } + } + if len(filtered) > n { + filtered = filtered[len(filtered)-n:] + } + return filtered +} + +// namedArtifacts returns the artifacts that have a non-empty name. +func namedArtifacts(arts []*aix.Artifact) []*aix.Artifact { + out := make([]*aix.Artifact, 0, len(arts)) + for _, a := range arts { + if a != nil && a.Name != "" { + out = append(out, a) + } + } + return out +} + +// messageText joins a message's non-empty text parts with newlines. +func messageText(m *ai.Message) string { + if m == nil { + return "" + } + var b strings.Builder + for _, p := range m.Content { + if p != nil && p.IsText() && p.Text != "" { + if b.Len() > 0 { + b.WriteByte('\n') + } + b.WriteString(p.Text) + } + } + return b.String() +} diff --git a/go/plugins/middleware/exp/agents_test.go b/go/plugins/middleware/exp/agents_test.go new file mode 100644 index 0000000000..76150fb84b --- /dev/null +++ b/go/plugins/middleware/exp/agents_test.go @@ -0,0 +1,574 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "context" + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/genkit" +) + +// toolModel defines a model with full tool/multiturn support backed by fn. +func toolModel(t *testing.T, g *genkit.Genkit, name string, fn ai.ModelFunc) ai.Model { + t.Helper() + return genkit.DefineModel(g, name, &ai.ModelOptions{ + Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true, Tools: true}, + }, fn) +} + +// textResp is a model response carrying a single model text message. +func textResp(req *ai.ModelRequest, text string) *ai.ModelResponse { + return &ai.ModelResponse{Request: req, Message: ai.NewModelTextMessage(text)} +} + +// toolReqResp is a model response that issues the given tool calls. +func toolReqResp(req *ai.ModelRequest, calls ...*ai.ToolRequest) *ai.ModelResponse { + parts := make([]*ai.Part, 0, len(calls)) + for _, c := range calls { + parts = append(parts, ai.NewToolRequestPart(c)) + } + return &ai.ModelResponse{Request: req, Message: &ai.Message{Role: ai.RoleModel, Content: parts}} +} + +// systemText concatenates the text parts of a system message. +func systemText(m *ai.Message) string { + var b strings.Builder + for _, p := range m.Content { + if p != nil && p.IsText() { + b.WriteString(p.Text) + b.WriteByte('\n') + } + } + return b.String() +} + +// hasToolResponse reports whether any message carries a tool response. +func hasToolResponse(msgs []*ai.Message) bool { + for _, m := range msgs { + for _, p := range m.Content { + if p.IsToolResponse() { + return true + } + } + } + return false +} + +// delegateOnceModel calls toolName once with the given task, then returns +// "done" after it sees any tool response. +func delegateOnceModel(t *testing.T, g *genkit.Genkit, name, toolName, task string) ai.Model { + return toolModel(t, g, name, func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + if hasToolResponse(req.Messages) { + return textResp(req, "done"), nil + } + return toolReqResp(req, &ai.ToolRequest{Name: toolName, Input: map[string]any{"task": task}}), nil + }) +} + +// decodeDelegation re-decodes a tool response output into a delegationResult, +// tolerating either the raw struct or a JSON-normalized map. +func decodeDelegation(t *testing.T, v any) delegationResult { + t.Helper() + b, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshal tool output: %v", err) + } + var dr delegationResult + if err := json.Unmarshal(b, &dr); err != nil { + t.Fatalf("unmarshal delegationResult: %v", err) + } + return dr +} + +// delegationResponses collects every delegation tool response for toolName. +func delegationResponses(t *testing.T, msgs []*ai.Message, toolName string) []delegationResult { + t.Helper() + var out []delegationResult + for _, m := range msgs { + for _, p := range m.Content { + if p.IsToolResponse() && p.ToolResponse != nil && p.ToolResponse.Name == toolName { + out = append(out, decodeDelegation(t, p.ToolResponse.Output)) + } + } + } + return out +} + +func TestAgentsValidation(t *testing.T) { + if _, err := (&Agents{}).New(ctx); err == nil { + t.Error("expected error when no agents are configured") + } + if _, err := (&Agents{Agents: []aix.AgentRef{{Name: ""}}}).New(ctx); err == nil { + t.Error("expected error when an agent reference has no name") + } + if _, err := (&Agents{Agents: []aix.AgentRef{{Name: "ok"}}}).New(ctx); err != nil { + t.Errorf("unexpected error for a valid config: %v", err) + } +} + +func TestAgentsInjectsSystemPrompt(t *testing.T) { + g := newTestGenkit(t) + + // researcher's description is auto-discovered from its action descriptor. + genkit.DefineAgent[any](g, "researcher", + aix.FromInline(ai.WithModel(toolModel(t, g, "test/researcher", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return textResp(req, "researched"), nil + }))), + aix.WithDescription[any]("Searches the web and summarizes findings."), + ) + + var captured []*ai.Message + orch := toolModel(t, g, "test/orch", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + captured = req.Messages + return textResp(req, "ok"), nil + }) + + mw := &Agents{Agents: []aix.AgentRef{ + {Name: "researcher"}, // discovered description + {Name: "coder", Description: "Writes Go code."}, // explicit override (agent need not exist for the listing) + }} + if _, err := genkit.Generate(ctx, g, ai.WithModel(orch), ai.WithPrompt("hi"), ai.WithUse(mw)); err != nil { + t.Fatal(err) + } + + sys := findSystem(captured) + if sys == nil { + t.Fatalf("expected a system message; got %v", captured) + } + text := systemText(sys) + for _, want := range []string{ + "delegate_to_researcher: Searches the web and summarizes findings.", + "delegate_to_coder: Writes Go code.", + "", + } { + if !strings.Contains(text, want) { + t.Errorf("system prompt missing %q; got:\n%s", want, text) + } + } +} + +func TestAgentsDelegationRunsSubAgent(t *testing.T) { + g := newTestGenkit(t) + + genkit.DefineAgent[any](g, "researcher", + aix.FromInline(ai.WithModel(toolModel(t, g, "test/researcher", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return textResp(req, "research complete"), nil + }))), + ) + + orch := delegateOnceModel(t, g, "test/orch", "delegate_to_researcher", "look into X") + mw := &Agents{Agents: []aix.AgentRef{{Name: "researcher"}}} + + resp, err := genkit.Generate(ctx, g, ai.WithModel(orch), ai.WithPrompt("research X"), ai.WithUse(mw)) + if err != nil { + t.Fatal(err) + } + + got := delegationResponses(t, resp.History(), "delegate_to_researcher") + if len(got) != 1 { + t.Fatalf("expected 1 delegation response, got %d", len(got)) + } + if got[0].Response != "research complete" { + t.Errorf("delegation response = %q, want %q", got[0].Response, "research complete") + } +} + +func TestAgentsUnknownAgentReportsError(t *testing.T) { + g := newTestGenkit(t) + + orch := delegateOnceModel(t, g, "test/orch", "delegate_to_ghost", "do it") + mw := &Agents{Agents: []aix.AgentRef{{Name: "ghost"}}} // never defined + + resp, err := genkit.Generate(ctx, g, ai.WithModel(orch), ai.WithPrompt("go"), ai.WithUse(mw)) + if err != nil { + t.Fatal(err) + } + got := delegationResponses(t, resp.History(), "delegate_to_ghost") + if len(got) != 1 || !strings.Contains(got[0].Response, "not found") { + t.Fatalf("expected a 'not found' delegation response, got %+v", got) + } +} + +func TestAgentsToolPrefix(t *testing.T) { + bare := "" + custom := "ask" + cases := []struct { + name string + prefix *string + want string + }{ + {"default", nil, "delegate_to_researcher"}, + {"custom", &custom, "ask_researcher"}, + {"bare", &bare, "researcher"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + g := newTestGenkit(t) + genkit.DefineAgent[any](g, "researcher", + aix.FromInline(ai.WithModel(toolModel(t, g, "test/sub-"+tc.name, func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return textResp(req, "ok"), nil + }))), + ) + orch := delegateOnceModel(t, g, "test/orch-"+tc.name, tc.want, "task") + mw := &Agents{Agents: []aix.AgentRef{{Name: "researcher"}}, ToolPrefix: tc.prefix} + + resp, err := genkit.Generate(ctx, g, ai.WithModel(orch), ai.WithPrompt("go"), ai.WithUse(mw)) + if err != nil { + t.Fatal(err) + } + if got := delegationResponses(t, resp.History(), tc.want); len(got) != 1 { + t.Fatalf("expected delegation via tool %q, got %d responses", tc.want, len(got)) + } + }) + } +} + +func TestAgentsMaxDelegations(t *testing.T) { + g := newTestGenkit(t) + + genkit.DefineAgent[any](g, "researcher", + aix.FromInline(ai.WithModel(toolModel(t, g, "test/researcher", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return textResp(req, "did work"), nil + }))), + ) + + // Issue two delegations in a single turn; with MaxDelegations=1 exactly one + // must be refused. + orch := toolModel(t, g, "test/orch", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + if hasToolResponse(req.Messages) { + return textResp(req, "done"), nil + } + return toolReqResp(req, + &ai.ToolRequest{Name: "delegate_to_researcher", Input: map[string]any{"task": "a"}}, + &ai.ToolRequest{Name: "delegate_to_researcher", Input: map[string]any{"task": "b"}}, + ), nil + }) + + mw := &Agents{Agents: []aix.AgentRef{{Name: "researcher"}}, MaxDelegations: 1} + resp, err := genkit.Generate(ctx, g, ai.WithModel(orch), ai.WithPrompt("go"), ai.WithUse(mw)) + if err != nil { + t.Fatal(err) + } + + got := delegationResponses(t, resp.History(), "delegate_to_researcher") + if len(got) != 2 { + t.Fatalf("expected 2 delegation responses, got %d", len(got)) + } + var real, limited int + for _, r := range got { + switch { + case r.Response == "did work": + real++ + case strings.Contains(r.Response, "Delegation limit reached"): + limited++ + default: + t.Errorf("unexpected delegation response: %q", r.Response) + } + } + if real != 1 || limited != 1 { + t.Errorf("got real=%d limited=%d, want 1 and 1", real, limited) + } +} + +func TestAgentsForwardsHistory(t *testing.T) { + g := newTestGenkit(t) + + // The sub-agent records the messages its model receives. + var subMessages []*ai.Message + genkit.DefineAgent[any](g, "researcher", + aix.FromInline(ai.WithModel(toolModel(t, g, "test/researcher", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + subMessages = req.Messages + return textResp(req, "noted"), nil + }))), + ) + + orch := delegateOnceModel(t, g, "test/orch", "delegate_to_researcher", "summarize the discussion") + mw := &Agents{Agents: []aix.AgentRef{{Name: "researcher"}}, HistoryLength: 4} + + _, err := genkit.Generate(ctx, g, + ai.WithModel(orch), + ai.WithMessages( + ai.NewUserTextMessage("the secret code is platypus"), + ai.NewModelTextMessage("understood"), + ), + ai.WithPrompt("now delegate"), + ai.WithUse(mw), + ) + if err != nil { + t.Fatal(err) + } + + var joined strings.Builder + for _, m := range subMessages { + joined.WriteString(messageText(m)) + joined.WriteByte('\n') + } + if !strings.Contains(joined.String(), "platypus") { + t.Errorf("sub-agent did not receive forwarded history; saw:\n%s", joined.String()) + } +} + +func TestAgentsSubAgentFailureReported(t *testing.T) { + g := newTestGenkit(t) + + // A custom sub-agent whose turn fails. + genkit.DefineCustomAgent[any](g, "researcher", + func(ctx context.Context, resp aix.Responder, sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) (*aix.TurnResult, error) { + return nil, errors.New("kaboom") + }) + if err != nil { + return nil, err + } + return &aix.AgentResult{}, nil + }, + ) + + orch := delegateOnceModel(t, g, "test/orch", "delegate_to_researcher", "go") + mw := &Agents{Agents: []aix.AgentRef{{Name: "researcher"}}} + + resp, err := genkit.Generate(ctx, g, ai.WithModel(orch), ai.WithPrompt("go"), ai.WithUse(mw)) + if err != nil { + t.Fatal(err) + } + got := delegationResponses(t, resp.History(), "delegate_to_researcher") + if len(got) != 1 || !strings.Contains(got[0].Response, "Error calling agent") { + t.Fatalf("expected an error delegation response, got %+v", got) + } +} + +// TestAgentsArtifactStrategies verifies that, run inside an orchestrator agent +// (so a session exists), sub-agent artifacts are merged into the parent session +// under both strategies and that inline includes content while session does not. +func TestAgentsArtifactStrategies(t *testing.T) { + for _, strategy := range []ArtifactStrategy{ArtifactStrategyInline, ArtifactStrategySession} { + t.Run(string(strategy), func(t *testing.T) { + g := newTestGenkit(t) + + // A custom sub-agent that produces an artifact. + genkit.DefineCustomAgent[any](g, "writer", + func(ctx context.Context, resp aix.Responder, sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) (*aix.TurnResult, error) { + resp.SendArtifact(&aix.Artifact{ + Name: "report.md", + Parts: []*ai.Part{ai.NewTextPart("the report body")}, + }) + sess.AddMessages(ai.NewModelTextMessage("wrote the report")) + return &aix.TurnResult{FinishReason: aix.AgentFinishReasonStop}, nil + }) + if err != nil { + return nil, err + } + return &aix.AgentResult{ + Message: ai.NewModelTextMessage("wrote the report"), + Artifacts: sess.Artifacts(), + }, nil + }, + ) + + delegating := delegateOnceModel(t, g, "test/orch-model-"+string(strategy), "delegate_to_writer", "write a report") + + // The orchestrator is itself an agent, so the delegation runs within + // a session that artifacts can merge into. Capture the inner generate + // history to inspect the delegation tool response. + var innerHistory []*ai.Message + orchestrator := genkit.DefineCustomAgent[any](g, "orchestrator", + func(ctx context.Context, resp aix.Responder, sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { + var last *ai.Message + err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) (*aix.TurnResult, error) { + r, err := genkit.Generate(ctx, g, + ai.WithModel(delegating), + ai.WithMessages(input.Message), + ai.WithUse(&Agents{ + Agents: []aix.AgentRef{{Name: "writer"}}, + ArtifactStrategy: strategy, + }), + ) + if err != nil { + return nil, err + } + innerHistory = r.History() + last = r.Message + return &aix.TurnResult{FinishReason: aix.AgentFinishReasonStop}, nil + }) + if err != nil { + return nil, err + } + return &aix.AgentResult{Message: last, Artifacts: sess.Artifacts()}, nil + }, + ) + + out, err := orchestrator.RunText(ctx, "please produce a report") + if err != nil { + t.Fatal(err) + } + + // The sub-agent artifact is merged into the parent session, namespaced. + if !hasArtifactNamed(out.Artifacts, "writer_1/report.md") { + t.Errorf("expected merged artifact %q in parent session; got %v", "writer_1/report.md", artifactNames(out.Artifacts)) + } + + // Inline carries content in the tool result; session does not. + got := delegationResponses(t, innerHistory, "delegate_to_writer") + if len(got) != 1 || len(got[0].Artifacts) != 1 { + t.Fatalf("expected 1 delegation response with 1 artifact, got %+v", got) + } + content := got[0].Artifacts[0].Content + if strategy == ArtifactStrategyInline && !strings.Contains(content, "the report body") { + t.Errorf("inline strategy should include artifact content, got %q", content) + } + if strategy == ArtifactStrategySession && content != "" { + t.Errorf("session strategy should omit artifact content, got %q", content) + } + }) + } +} + +func TestAgentRefCapturesNameAndDescription(t *testing.T) { + g := newTestGenkit(t) + a := genkit.DefineAgent[any](g, "writer", + aix.FromInline(ai.WithModel(toolModel(t, g, "test/writer", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return textResp(req, "x"), nil + }))), + aix.WithDescription[any]("Writes things."), + ) + + ref := a.Ref() + if ref.Name != "writer" { + t.Errorf("Name = %q, want %q", ref.Name, "writer") + } + if ref.Description != "Writes things." { + t.Errorf("Description = %q, want %q", ref.Description, "Writes things.") + } +} + +func TestAgentsDelegatesViaRef(t *testing.T) { + g := newTestGenkit(t) + researcher := genkit.DefineAgent[any](g, "researcher", + aix.FromInline(ai.WithModel(toolModel(t, g, "test/researcher", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return textResp(req, "ref result"), nil + }))), + ) + + orch := delegateOnceModel(t, g, "test/orch", "delegate_to_researcher", "go") + mw := &Agents{Agents: []aix.AgentRef{researcher.Ref()}} + + resp, err := genkit.Generate(ctx, g, ai.WithModel(orch), ai.WithPrompt("research"), ai.WithUse(mw)) + if err != nil { + t.Fatal(err) + } + got := delegationResponses(t, resp.History(), "delegate_to_researcher") + if len(got) != 1 || got[0].Response != "ref result" { + t.Fatalf("delegation via Ref failed: %+v", got) + } +} + +func TestAgentsRefDescriptionTakesPrecedence(t *testing.T) { + g := newTestGenkit(t) + a := genkit.DefineAgent[any](g, "writer", + aix.FromInline(ai.WithModel(toolModel(t, g, "test/writer", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + return textResp(req, "x"), nil + }))), + aix.WithDescription[any]("Original description."), + ) + + ref := a.Ref() + ref.Description = "Overridden in config." // user override on top of the instance + + var captured []*ai.Message + orch := toolModel(t, g, "test/orch", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + captured = req.Messages + return textResp(req, "ok"), nil + }) + + if _, err := genkit.Generate(ctx, g, ai.WithModel(orch), ai.WithPrompt("hi"), ai.WithUse(&Agents{Agents: []aix.AgentRef{ref}})); err != nil { + t.Fatal(err) + } + text := systemText(findSystem(captured)) + if !strings.Contains(text, "Overridden in config.") { + t.Errorf("system prompt missing the override; got:\n%s", text) + } + if strings.Contains(text, "Original description.") { + t.Errorf("override should replace the instance description; got:\n%s", text) + } +} + +// TestAgentsConfigSerialization guards the JSON-dispatch path used by the Dev +// UI: schema inference must not panic, and the config must round-trip. +func TestAgentsConfigSerialization(t *testing.T) { + _ = ai.NewMiddleware("agents", &Agents{}) // must not panic on schema inference + + prefix := "ask" + cfg := &Agents{ + Agents: []aix.AgentRef{{Name: "researcher"}, {Name: "coder", Description: "Writes Go."}}, + ToolPrefix: &prefix, + MaxDelegations: 3, + HistoryLength: 2, + ArtifactStrategy: ArtifactStrategySession, + } + b, err := json.Marshal(cfg) + if err != nil { + t.Fatal(err) + } + var got Agents + if err := json.Unmarshal(b, &got); err != nil { + t.Fatal(err) + } + if len(got.Agents) != 2 || got.Agents[0].Name != "researcher" || got.Agents[1].Description != "Writes Go." { + t.Errorf("agents lost in round trip: %+v", got.Agents) + } + if got.ToolPrefix == nil || *got.ToolPrefix != "ask" { + t.Errorf("toolPrefix lost in round trip: %v", got.ToolPrefix) + } + if got.ArtifactStrategy != ArtifactStrategySession { + t.Errorf("artifactStrategy lost in round trip: %q", got.ArtifactStrategy) + } +} + +// statemessages returns the conversation messages from a client-managed agent +// output's state. +func statemessages(out *aix.AgentOutput[any]) []*ai.Message { + if out == nil || out.State == nil { + return nil + } + return out.State.Messages +} + +func hasArtifactNamed(arts []*aix.Artifact, name string) bool { + for _, a := range arts { + if a != nil && a.Name == name { + return true + } + } + return false +} + +func artifactNames(arts []*aix.Artifact) []string { + names := make([]string, 0, len(arts)) + for _, a := range arts { + if a != nil { + names = append(names, a.Name) + } + } + return names +} diff --git a/go/plugins/middleware/exp/artifacts.go b/go/plugins/middleware/exp/artifacts.go new file mode 100644 index 0000000000..18c3bc443a --- /dev/null +++ b/go/plugins/middleware/exp/artifacts.go @@ -0,0 +1,209 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "context" + "fmt" + "strings" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" +) + +// artifactsMarker tags the system prompt part injected by this middleware so it +// can be refreshed each turn as the session's artifacts change. +const artifactsMarker = "artifacts-listing" + +// Artifacts is a middleware that gives the model tools to interact with session +// artifacts and injects a listing of available artifacts into the system prompt. +// +// It provides: +// +// - read_artifact: reads an artifact by name from the session and returns its +// text content. +// - write_artifact (unless Readonly): creates or updates a session artifact. +// Artifacts are deduplicated by name, so writing to an existing name +// replaces it. +// +// On every generate turn an block listing the names and sizes of +// the session's artifacts is injected into (or refreshed within) the system +// message, so the model knows what is available without spending context on the +// full content. +// +// This is useful standalone (e.g. a workspace-builder agent that creates files +// as artifacts) or combined with [Agents] using ArtifactStrategySession, where +// sub-agent artifacts are merged into the parent session and the model reaches +// them through these tools. +// +// Artifacts live on the active agent session, so this middleware only has an +// effect when generation runs inside an agent invocation (see +// [github.com/firebase/genkit/go/genkit.DefineAgent]). With no active session +// the tools report that gracefully and the listing is empty. +// +// Usage: +// +// builder := genkit.DefineAgent(g, "builder", +// aix.FromInline( +// ai.WithModelName("googleai/gemini-flash-latest"), +// ai.WithSystem("You are a code generator. Use write_artifact to create files."), +// ai.WithUse(&middleware.Artifacts{}), +// ), +// ) +type Artifacts struct { + // Readonly, when true, provides only the read_artifact tool; the model + // cannot create or update artifacts. Defaults to false. + Readonly bool `json:"readonly,omitempty"` +} + +func (a *Artifacts) Name() string { return provider + "/artifacts" } + +// New returns the hooks that register the artifact tools and inject the +// artifact listing into the system prompt on each turn. +func (a *Artifacts) New(ctx context.Context) (*ai.Hooks, error) { + tools := []ai.Tool{newReadArtifactTool()} + if !a.Readonly { + tools = append(tools, newWriteArtifactTool()) + } + + wrapGenerate := func(ctx context.Context, params *ai.GenerateParams, next ai.GenerateNext) (*ai.ModelResponse, error) { + var arts []*aix.Artifact + if store := aix.ArtifactStoreFromContext(ctx); store != nil { + arts = store.Artifacts() + } + params.Request = injectSystemText(params.Request, artifactsMarker, buildArtifactsListing(arts)) + return next(ctx, params) + } + + return &ai.Hooks{ + Tools: tools, + WrapGenerate: wrapGenerate, + }, nil +} + +type readArtifactInput struct { + Name string `json:"name" jsonschema:"description=The name of the artifact to read."` +} + +type readArtifactOutput struct { + Name string `json:"name"` + Content string `json:"content"` + Found bool `json:"found"` +} + +func newReadArtifactTool() ai.Tool { + return ai.NewTool("read_artifact", + "Reads the content of a named artifact from the session. Use this to "+ + "inspect artifacts produced by sub-agents or previously created artifacts.", + func(tc *ai.ToolContext, in readArtifactInput) (readArtifactOutput, error) { + store := aix.ArtifactStoreFromContext(tc) + if store == nil { + return readArtifactOutput{Name: in.Name, Content: "Error: no active session.", Found: false}, nil + } + for _, art := range store.Artifacts() { + if art.Name == in.Name { + return readArtifactOutput{Name: in.Name, Content: artifactText(art), Found: true}, nil + } + } + return readArtifactOutput{Name: in.Name, Content: fmt.Sprintf("Artifact %q not found.", in.Name), Found: false}, nil + }) +} + +type writeArtifactInput struct { + Name string `json:"name" jsonschema:"description=A unique name for the artifact (e.g. a filename like report.md)."` + Content string `json:"content" jsonschema:"description=The full text content of the artifact."` +} + +type writeArtifactOutput struct { + Status string `json:"status"` +} + +func newWriteArtifactTool() ai.Tool { + return ai.NewTool("write_artifact", + "Creates or updates a named artifact in the session. If an artifact with "+ + "the same name already exists, it is replaced. Use this to produce "+ + "files, reports, code, or other deliverables.", + func(tc *ai.ToolContext, in writeArtifactInput) (writeArtifactOutput, error) { + store := aix.ArtifactStoreFromContext(tc) + if store == nil { + return writeArtifactOutput{Status: "Error: no active session."}, nil + } + store.AddArtifacts(&aix.Artifact{ + Name: in.Name, + Parts: []*ai.Part{ai.NewTextPart(in.Content)}, + }) + return writeArtifactOutput{Status: fmt.Sprintf("Artifact %q saved successfully.", in.Name)}, nil + }) +} + +// artifactText joins an artifact's text parts with newlines, skipping +// non-text and empty parts. +func artifactText(a *aix.Artifact) string { + var b strings.Builder + for _, p := range a.Parts { + if p == nil || !p.IsText() || p.Text == "" { + continue + } + if b.Len() > 0 { + b.WriteByte('\n') + } + b.WriteString(p.Text) + } + return b.String() +} + +// buildArtifactsListing renders the system block listing the +// session's artifacts and their sizes. Artifacts are listed in session order, +// which is stable across turns so the injected text only changes when the set +// of artifacts does. +func buildArtifactsListing(arts []*aix.Artifact) string { + var b strings.Builder + b.WriteString("\n") + if len(arts) == 0 { + b.WriteString("No artifacts are currently available in the session.\n") + b.WriteString("") + return b.String() + } + b.WriteString("The following artifacts are available in the session. ") + b.WriteString("Use the read_artifact tool to view their content.\n") + for _, a := range arts { + name := a.Name + if name == "" { + name = "(unnamed)" + } + fmt.Fprintf(&b, " - %s", name) + if text := artifactText(a); len(text) > 0 { + fmt.Fprintf(&b, " (%d chars)", len(text)) + } + if src := artifactSource(a); src != "" { + fmt.Fprintf(&b, " [from: %s]", src) + } + b.WriteByte('\n') + } + b.WriteString("") + return b.String() +} + +// artifactSource returns the artifact's "source" metadata (set when an artifact +// originates from a sub-agent delegation), or "" if absent. +func artifactSource(a *aix.Artifact) string { + if a.Metadata == nil { + return "" + } + src, _ := a.Metadata["source"].(string) + return src +} diff --git a/go/plugins/middleware/exp/artifacts_test.go b/go/plugins/middleware/exp/artifacts_test.go new file mode 100644 index 0000000000..01565f3851 --- /dev/null +++ b/go/plugins/middleware/exp/artifacts_test.go @@ -0,0 +1,219 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "context" + "encoding/json" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/genkit" +) + +func toolResponseByName(t *testing.T, msgs []*ai.Message, name string) (any, bool) { + t.Helper() + for _, m := range msgs { + for _, p := range m.Content { + if p.IsToolResponse() && p.ToolResponse != nil && p.ToolResponse.Name == name { + return p.ToolResponse.Output, true + } + } + } + return nil, false +} + +func decodeReadArtifact(t *testing.T, v any) readArtifactOutput { + t.Helper() + b, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var o readArtifactOutput + if err := json.Unmarshal(b, &o); err != nil { + t.Fatalf("unmarshal readArtifactOutput: %v", err) + } + return o +} + +func TestArtifactsReadonlyOmitsWriteTool(t *testing.T) { + hooks, err := (&Artifacts{Readonly: true}).New(ctx) + if err != nil { + t.Fatal(err) + } + if len(hooks.Tools) != 1 || hooks.Tools[0].Name() != "read_artifact" { + t.Fatalf("readonly should expose only read_artifact, got %v", toolNames(hooks.Tools)) + } + + hooks, err = (&Artifacts{}).New(ctx) + if err != nil { + t.Fatal(err) + } + names := toolNames(hooks.Tools) + if len(names) != 2 || !contains(names, "read_artifact") || !contains(names, "write_artifact") { + t.Fatalf("default should expose read_artifact and write_artifact, got %v", names) + } +} + +func TestArtifactsWriteThenRead(t *testing.T) { + g := newTestGenkit(t) + + // The model writes an artifact, reads it back, then finishes. + model := toolModel(t, g, "test/artifact-model", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + var wrote, read bool + for _, m := range req.Messages { + for _, p := range m.Content { + if p.IsToolResponse() && p.ToolResponse != nil { + switch p.ToolResponse.Name { + case "write_artifact": + wrote = true + case "read_artifact": + read = true + } + } + } + } + switch { + case read: + return textResp(req, "done"), nil + case wrote: + return toolReqResp(req, &ai.ToolRequest{Name: "read_artifact", Input: map[string]any{"name": "report.md"}}), nil + default: + return toolReqResp(req, &ai.ToolRequest{Name: "write_artifact", Input: map[string]any{"name": "report.md", "content": "hello world"}}), nil + } + }) + + builder := genkit.DefineAgent[any](g, "builder", + aix.FromInline(ai.WithModel(model), ai.WithSystem("be a builder"), ai.WithUse(&Artifacts{})), + ) + + out, err := builder.RunText(ctx, "make a report") + if err != nil { + t.Fatal(err) + } + + // read_artifact returned the content written by write_artifact. + v, ok := toolResponseByName(t, statemessages(out), "read_artifact") + if !ok { + t.Fatal("no read_artifact tool response found") + } + read := decodeReadArtifact(t, v) + if !read.Found || read.Content != "hello world" { + t.Errorf("read_artifact = %+v, want found with content %q", read, "hello world") + } + + // The artifact persisted on the session. + if !hasArtifactNamed(out.Artifacts, "report.md") { + t.Errorf("expected artifact %q on session; got %v", "report.md", artifactNames(out.Artifacts)) + } +} + +func TestArtifactsSystemPromptListing(t *testing.T) { + g := newTestGenkit(t) + + var captured []*ai.Message + capture := toolModel(t, g, "test/capture", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + captured = req.Messages + return textResp(req, "ok"), nil + }) + + // A custom agent seeds an artifact, then generates with the Artifacts + // middleware so the listing reflects the seeded artifact. + lister := genkit.DefineCustomAgent[any](g, "lister", + func(ctx context.Context, resp aix.Responder, sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) (*aix.TurnResult, error) { + resp.SendArtifact(&aix.Artifact{ + Name: "notes.md", + Parts: []*ai.Part{ai.NewTextPart("some notes here")}, + }) + if _, err := genkit.Generate(ctx, g, + ai.WithModel(capture), + ai.WithMessages(input.Message), + ai.WithUse(&Artifacts{}), + ); err != nil { + return nil, err + } + return &aix.TurnResult{FinishReason: aix.AgentFinishReasonStop}, nil + }) + if err != nil { + return nil, err + } + return &aix.AgentResult{Message: ai.NewModelTextMessage("listed")}, nil + }, + ) + + if _, err := lister.RunText(ctx, "what artifacts are there?"); err != nil { + t.Fatal(err) + } + + sys := findSystem(captured) + if sys == nil { + t.Fatalf("expected a system message; got %v", captured) + } + text := systemText(sys) + if !strings.Contains(text, "") || !strings.Contains(text, "notes.md") { + t.Errorf("system prompt missing artifact listing for notes.md; got:\n%s", text) + } +} + +func TestArtifactsNoSession(t *testing.T) { + g := newTestGenkit(t) + + // With a plain Generate call there is no agent session. + model := toolModel(t, g, "test/no-session", func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + for _, m := range req.Messages { + for _, p := range m.Content { + if p.IsToolResponse() { + return textResp(req, "done"), nil + } + } + } + return toolReqResp(req, &ai.ToolRequest{Name: "read_artifact", Input: map[string]any{"name": "x"}}), nil + }) + + resp, err := genkit.Generate(ctx, g, ai.WithModel(model), ai.WithPrompt("read x"), ai.WithUse(&Artifacts{})) + if err != nil { + t.Fatal(err) + } + v, ok := toolResponseByName(t, resp.History(), "read_artifact") + if !ok { + t.Fatal("no read_artifact tool response found") + } + read := decodeReadArtifact(t, v) + if read.Found || !strings.Contains(read.Content, "no active session") { + t.Errorf("expected a no-active-session result, got %+v", read) + } +} + +func toolNames(tools []ai.Tool) []string { + names := make([]string, 0, len(tools)) + for _, tl := range tools { + names = append(names, tl.Name()) + } + return names +} + +func contains(ss []string, s string) bool { + for _, x := range ss { + if x == s { + return true + } + } + return false +} diff --git a/go/plugins/middleware/exp/helpers_test.go b/go/plugins/middleware/exp/helpers_test.go new file mode 100644 index 0000000000..7a785b04a6 --- /dev/null +++ b/go/plugins/middleware/exp/helpers_test.go @@ -0,0 +1,43 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "context" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" +) + +var ctx = context.Background() + +// newTestGenkit returns a fresh Genkit instance for a test. +func newTestGenkit(t *testing.T) *genkit.Genkit { + t.Helper() + return genkit.Init(context.Background()) +} + +// findSystem returns the first system message, or nil. +func findSystem(msgs []*ai.Message) *ai.Message { + for _, m := range msgs { + if m.Role == ai.RoleSystem { + return m + } + } + return nil +} diff --git a/go/plugins/middleware/exp/inject.go b/go/plugins/middleware/exp/inject.go new file mode 100644 index 0000000000..4db1038c5f --- /dev/null +++ b/go/plugins/middleware/exp/inject.go @@ -0,0 +1,90 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import "github.com/firebase/genkit/go/ai" + +// injectSystemText returns a copy of req with text placed in a system-message +// part tagged by marker. The marker lets a middleware find its own injected +// text on later tool-loop iterations: +// +// - If a tagged part already exists, it is refreshed in place when text +// changed, or left untouched when identical. Middleware whose text is +// constant (e.g. a fixed tool listing) is injected once; middleware whose +// text varies per turn (e.g. a live artifact listing) is refreshed. +// - Otherwise the text is appended to the first system message. +// - Otherwise a new system message carrying the text is prepended. +// +// The request and its messages are copied before mutation, so req is unchanged. +func injectSystemText(req *ai.ModelRequest, marker, text string) *ai.ModelRequest { + newReq := *req + newReq.Messages = append([]*ai.Message(nil), req.Messages...) + + // Refresh an existing tagged part in place. + for i, msg := range newReq.Messages { + if msg == nil { + continue + } + for j, part := range msg.Content { + if !hasMarker(part, marker) { + continue + } + if part.Text == text { + return &newReq + } + msgCopy := msg.Clone() + msgCopy.Content[j] = systemTextPart(marker, text) + newReq.Messages[i] = msgCopy + return &newReq + } + } + + // Append to an existing system message. + for i, msg := range newReq.Messages { + if msg == nil || msg.Role != ai.RoleSystem { + continue + } + msgCopy := msg.Clone() + msgCopy.Content = append(msgCopy.Content, systemTextPart(marker, text)) + newReq.Messages[i] = msgCopy + return &newReq + } + + // Otherwise prepend a fresh system message. + newReq.Messages = append( + []*ai.Message{ai.NewSystemMessage(systemTextPart(marker, text))}, + newReq.Messages..., + ) + return &newReq +} + +// systemTextPart builds the text part that carries middleware-injected system +// text, tagged with marker so later iterations can find and refresh it. +func systemTextPart(marker, text string) *ai.Part { + p := ai.NewTextPart(text) + p.Metadata = map[string]any{marker: true} + return p +} + +// hasMarker reports whether p is a text part tagged with marker. +func hasMarker(p *ai.Part, marker string) bool { + if p == nil || !p.IsText() || p.Metadata == nil { + return false + } + v, ok := p.Metadata[marker].(bool) + return ok && v +} diff --git a/go/plugins/middleware/exp/plugin.go b/go/plugins/middleware/exp/plugin.go new file mode 100644 index 0000000000..2340c72a39 --- /dev/null +++ b/go/plugins/middleware/exp/plugin.go @@ -0,0 +1,49 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package exp provides experimental middleware for the agent APIs in +// [github.com/firebase/genkit/go/ai/exp]: [Agents] for sub-agent delegation and +// [Artifacts] for session artifact access. These middlewares are experimental +// and may change in any minor release, tracking the agent APIs they build on. +package exp + +import ( + "context" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" +) + +// provider names the experimental middleware plugin and prefixes the +// registered middleware names (e.g. genkit-middleware-exp/agents). +const provider = "genkit-middleware-exp" + +// Middleware provides the experimental agent middleware ([Agents], [Artifacts]) +// as a Genkit plugin. Register it with [genkit.WithPlugins] during +// [genkit.Init] to make them resolvable by name (e.g. for the Dev UI). Using +// them directly via [ai.WithUse] does not require the plugin. +type Middleware struct{} + +func (p *Middleware) Name() string { return provider } + +func (p *Middleware) Init(ctx context.Context) []api.Action { return nil } + +func (p *Middleware) Middlewares(ctx context.Context) ([]*ai.MiddlewareDesc, error) { + return []*ai.MiddlewareDesc{ + ai.NewMiddleware("Delegate tasks to registered sub-agents via per-agent tools", &Agents{}), + ai.NewMiddleware("Provide read/write tools for session artifacts", &Artifacts{}), + }, nil +} diff --git a/go/samples/basic-agents/chef.go b/go/samples/basic-agents/chef.go new file mode 100644 index 0000000000..5797b9dcc7 --- /dev/null +++ b/go/samples/basic-agents/chef.go @@ -0,0 +1,45 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/genkit" +) + +// ChatPromptInput is the input schema referenced by ./prompts/chef.prompt. +// Registering it via DefineSchemaFor (in main) lets the .prompt file refer to +// it by name in its YAML frontmatter. +type ChatPromptInput struct { + Personality string `json:"personality"` +} + +// definePromptAgent demonstrates DefineAgent with aix.FromPrompt. The +// prompt is loaded from ./prompts/.prompt by genkit's prompt +// registry. Defining the prompt in a file lets you tune model, config, +// schema, and template independently of the Go code — useful when prompt +// authors are not the same people writing the agent wiring. +// +// FromPrompt's argument is the default input passed to the prompt's +// Render on every turn; the inline-prompt variant has no per-turn input +// of its own. +func definePromptAgent(g *genkit.Genkit) *aix.Agent[any] { + const name = "chef" + return genkit.DefineAgent(g, name, + aix.FromPrompt(ChatPromptInput{Personality: "a Michelin-starred chef who loves explaining technique"}), + aix.WithSessionStore(mustStore(name)), + aix.WithDescription[any]("Michelin-starred chef (prompt loaded from ./prompts/chef.prompt)"), + ) +} diff --git a/go/samples/basic-agents/coder.go b/go/samples/basic-agents/coder.go new file mode 100644 index 0000000000..04a9230135 --- /dev/null +++ b/go/samples/basic-agents/coder.go @@ -0,0 +1,67 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/genkit" +) + +// defineCustomAgent demonstrates DefineCustomAgent. The per-turn function +// is fully under your control: it picks the model, manages the message +// list, streams chunks back to the client, and decides what to put in the +// final result. Use this form when the prompt-backed agent loop doesn't +// fit (e.g. you want to pre/post-process every turn, swap models +// dynamically, or wire up custom tool plumbing). +// +// Even with full control over the loop, the framework still owns session +// state, snapshot writes, and the detach lifecycle. +func defineCustomAgent(g *genkit.Genkit) *aix.Agent[any] { + const name = "coder" + return genkit.DefineCustomAgent(g, name, + func(ctx context.Context, resp aix.Responder, sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) (*aix.TurnResult, error) { + for chunk, err := range genkit.GenerateStream(ctx, g, + ai.WithModel(flashModel), + ai.WithSystem("You are a senior software engineer. Answer briefly. Use fenced code blocks when showing code."), + ai.WithMessages(sess.Messages()...), + ) { + if err != nil { + return nil, err + } + if chunk.Done { + sess.AddMessages(chunk.Response.Message) + // Report how the turn ended so the framework can + // forward it on the TurnEnd chunk and persist it + // on the snapshot. + return &aix.TurnResult{ + FinishReason: aix.AgentFinishReason(chunk.Response.FinishReason), + }, nil + } + resp.SendModelChunk(chunk.Chunk) + } + return nil, nil + }); err != nil { + return nil, err + } + return sess.Result(), nil + }, + aix.WithSessionStore(mustStore(name)), + aix.WithDescription[any]("Concise code helper (custom per-turn loop)"), + ) +} diff --git a/go/samples/basic-agents/main.go b/go/samples/basic-agents/main.go index e38e505d4d..08ab486752 100644 --- a/go/samples/basic-agents/main.go +++ b/go/samples/basic-agents/main.go @@ -12,18 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This sample demonstrates Genkit's agent APIs by defining three agents in -// three different styles and exposing all of them through a single CLI: +// This sample demonstrates Genkit's agent APIs by defining four agents and +// exposing all of them through a single CLI. Each agent lives in its own file: // -// - "pirate" uses DefineAgent + aix.FromInline. The prompt is declared -// inline next to the agent. -// - "chef" uses DefineAgent + aix.FromPrompt. The prompt is loaded from -// ./prompts/chef.prompt by the agent's name. -// - "coder" uses DefineCustomAgent. The per-turn loop (model selection, -// history management, streaming) is wired by hand. +// - "pirate" (pirate.go) uses DefineAgent + aix.FromInline. The prompt is +// declared inline next to the agent. +// - "chef" (chef.go) uses DefineAgent + aix.FromPrompt. The prompt is loaded +// from ./prompts/chef.prompt by the agent's name. +// - "coder" (coder.go) uses DefineCustomAgent. The per-turn loop (model +// selection, history management, streaming) is wired by hand. +// - "orchestrator" (orchestrator.go) uses the experimental Agents middleware +// to delegate to specialized sub-agents. // -// All three agents persist their conversation state to a per-agent -// FileSessionStore under ./.genkit/snapshots//. +// The first three persist their conversation state to a per-agent +// FileSessionStore under ./.genkit/snapshots//; the orchestrator does +// too, while its sub-agents run statelessly per delegation. // // To run: // @@ -57,7 +60,6 @@ import ( "os/signal" "syscall" - "github.com/firebase/genkit/go/ai" aix "github.com/firebase/genkit/go/ai/exp" "github.com/firebase/genkit/go/ai/exp/localstore" "github.com/firebase/genkit/go/genkit" @@ -65,30 +67,24 @@ import ( "google.golang.org/genai" ) -// ChatPromptInput is the input schema referenced by ./prompts/chef.prompt. -// Registering it via DefineSchemaFor lets the .prompt file refer to it by -// name in its YAML frontmatter. -type ChatPromptInput struct { - Personality string `json:"personality"` -} - func main() { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - genkit.DefineSchemaFor[ChatPromptInput](g) + genkit.DefineSchemaFor[ChatPromptInput](g) // input schema for ./prompts/chef.prompt (see chef.go) - // Each define function registers an agent and returns it. The CLI - // drives all three through the same surface: a.Name() and - // a.Desc().Description for the list view, a.StreamBidi(...) to chat, - // and a.Store() for snapshot reads. Nothing the CLI does is tied to a - // concrete store type, so swapping in a different SessionStore would - // not touch a line of it. + // Each define function (in its own file) registers an agent and returns + // it. The CLI drives all of them through the same surface: a.Name() and + // a.Desc().Description for the list view, a.StreamBidi(...) to chat, and + // a.Store() for snapshot reads. Nothing the CLI does is tied to a concrete + // store type, so swapping in a different SessionStore would not touch a + // line of it. agents := []*aix.Agent[any]{ defineInlineAgent(g), definePromptAgent(g), defineCustomAgent(g), + defineOrchestratorAgent(g), } if err := runCLI(ctx, agents); err != nil { @@ -97,93 +93,13 @@ func main() { } } -// defineInlineAgent demonstrates DefineAgent with aix.FromInline. The -// prompt is declared right next to the agent definition; the registered -// prompt and the agent share a name. Each turn the framework renders the -// prompt, appends the conversation history, calls the model, and updates -// session state. This is the shortest path from "I want a chat agent" to -// a working one. -func defineInlineAgent(g *genkit.Genkit) *aix.Agent[any] { - const name = "pirate" - return genkit.DefineAgent(g, name, - aix.FromInline( - ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ - ThinkingConfig: &genai.ThinkingConfig{ - ThinkingBudget: genai.Ptr[int32](0), - }, - })), - ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), - ), - aix.WithSessionStore(mustStore(name)), - aix.WithDescription[any]("Sarcastic pirate (inline-defined prompt)"), - ) -} - -// definePromptAgent demonstrates DefineAgent with aix.FromPrompt. The -// prompt is loaded from ./prompts/.prompt by genkit's prompt -// registry. Defining the prompt in a file lets you tune model, config, -// schema, and template independently of the Go code — useful when prompt -// authors are not the same people writing the agent wiring. -// -// FromPrompt's argument is the default input passed to the prompt's -// Render on every turn; the inline-prompt variant has no per-turn input -// of its own. -func definePromptAgent(g *genkit.Genkit) *aix.Agent[any] { - const name = "chef" - return genkit.DefineAgent(g, name, - aix.FromPrompt(ChatPromptInput{Personality: "a Michelin-starred chef who loves explaining technique"}), - aix.WithSessionStore(mustStore(name)), - aix.WithDescription[any]("Michelin-starred chef (prompt loaded from ./prompts/chef.prompt)"), - ) -} - -// defineCustomAgent demonstrates DefineCustomAgent. The per-turn function -// is fully under your control: it picks the model, manages the message -// list, streams chunks back to the client, and decides what to put in the -// final result. Use this form when the prompt-backed agent loop doesn't -// fit (e.g. you want to pre/post-process every turn, swap models -// dynamically, or wire up custom tool plumbing). -// -// Even with full control over the loop, the framework still owns session -// state, snapshot writes, and the detach lifecycle. -func defineCustomAgent(g *genkit.Genkit) *aix.Agent[any] { - const name = "coder" - return genkit.DefineCustomAgent(g, name, - func(ctx context.Context, resp aix.Responder, sess *aix.SessionRunner[any]) (*aix.AgentResult, error) { - if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentInput) (*aix.TurnResult, error) { - for chunk, err := range genkit.GenerateStream(ctx, g, - ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ - ThinkingConfig: &genai.ThinkingConfig{ - ThinkingBudget: genai.Ptr[int32](0), - }, - })), - ai.WithSystem("You are a senior software engineer. Answer briefly. Use fenced code blocks when showing code."), - ai.WithMessages(sess.Messages()...), - ) { - if err != nil { - return nil, err - } - if chunk.Done { - sess.AddMessages(chunk.Response.Message) - // Report how the turn ended so the framework can - // forward it on the TurnEnd chunk and persist it - // on the snapshot. - return &aix.TurnResult{ - FinishReason: aix.AgentFinishReason(chunk.Response.FinishReason), - }, nil - } - resp.SendModelChunk(chunk.Chunk) - } - return nil, nil - }); err != nil { - return nil, err - } - return sess.Result(), nil - }, - aix.WithSessionStore(mustStore(name)), - aix.WithDescription[any]("Concise code helper (custom per-turn loop)"), - ) -} +// flashModel is the model shared by the agents in this sample: +// gemini-flash-latest with thinking disabled for snappy, low-cost turns. +// genkit copies request config per call rather than mutating it, so one shared +// reference is safe across all the agents. +var flashModel = googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ThinkingBudget: genai.Ptr[int32](0)}, +}) // mustStore creates a FileSessionStore rooted at the per-agent dir under // ./.genkit/snapshots/, or exits the process on failure. Used during diff --git a/go/samples/basic-agents/orchestrator.go b/go/samples/basic-agents/orchestrator.go new file mode 100644 index 0000000000..76d78bee8a --- /dev/null +++ b/go/samples/basic-agents/orchestrator.go @@ -0,0 +1,90 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/genkit" + middlewarex "github.com/firebase/genkit/go/plugins/middleware/exp" +) + +// defineOrchestratorAgent demonstrates the experimental Agents middleware: an +// orchestrator that delegates to specialized sub-agents through per-agent tools. +// +// The middleware injects one delegation tool per sub-agent (delegate_to_), +// lists the sub-agents and their descriptions in the system prompt, and runs the +// chosen sub-agent when the orchestrator model calls its tool. It mirrors the +// JS "orchestrator" sample. +// +// The two sub-agents (researcher, engineer) are client-managed (no session +// store): each delegation runs them one-shot and leaves no snapshots behind, so +// only the orchestrator appears in the CLI. Both use the Artifacts middleware so +// they can persist output as named session artifacts; with +// ArtifactStrategySession those artifacts are merged into the orchestrator's +// session, and Artifacts{Readonly: true} gives the orchestrator a read_artifact +// tool to review them before answering. +func defineOrchestratorAgent(g *genkit.Genkit) *aix.Agent[any] { + researcher := genkit.DefineAgent(g, "researcher", + aix.FromInline( + ai.WithModel(flashModel), + ai.WithSystem("You are a thorough research assistant. Answer the question "+ + "concisely, then call write_artifact to save your findings as a named "+ + "markdown artifact (for example \"findings.md\")."), + ai.WithUse(&middlewarex.Artifacts{}), + ), + aix.WithDescription[any]("Researches a topic and summarizes well-sourced findings."), + ) + + engineer := genkit.DefineAgent(g, "engineer", + aix.FromInline( + ai.WithModel(flashModel), + ai.WithSystem("You are an expert programmer. Write clean, well-commented code, "+ + "then call write_artifact to save it as a named file artifact (for "+ + "example \"main.go\")."), + ai.WithUse(&middlewarex.Artifacts{}), + ), + aix.WithDescription[any]("Writes and explains code, producing file artifacts."), + ) + + return genkit.DefineAgent(g, "orchestrator", + aix.FromInline( + ai.WithModel(flashModel), + ai.WithSystem("You are a project coordinator. Analyze the user's request and "+ + "delegate to the appropriate sub-agent using its delegation tool. If a "+ + "request needs both research and code, delegate to each in turn. After "+ + "the sub-agents respond you may call read_artifact to review their work, "+ + "then synthesize a final answer for the user."), + ai.WithUse( + // One delegation tool per sub-agent. Descriptions are + // auto-discovered from each agent (set via WithDescription and + // captured by Ref). historyLength forwards recent turns to the + // client-managed sub-agents; artifactStrategy "session" merges + // their artifacts into this session. + &middlewarex.Agents{ + Agents: []aix.AgentRef{researcher.Ref(), engineer.Ref()}, + MaxDelegations: 5, + HistoryLength: 4, + ArtifactStrategy: middlewarex.ArtifactStrategySession, + }, + // Read-only artifact access: the orchestrator reviews sub-agent + // artifacts but does not produce its own. + &middlewarex.Artifacts{Readonly: true}, + ), + ), + aix.WithSessionStore(mustStore("orchestrator")), + aix.WithDescription[any]("Coordinates research and coding sub-agents via the agents middleware"), + ) +} diff --git a/go/samples/basic-agents/pirate.go b/go/samples/basic-agents/pirate.go new file mode 100644 index 0000000000..9ec0b0dae9 --- /dev/null +++ b/go/samples/basic-agents/pirate.go @@ -0,0 +1,39 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/genkit" +) + +// defineInlineAgent demonstrates DefineAgent with aix.FromInline. The +// prompt is declared right next to the agent definition; the registered +// prompt and the agent share a name. Each turn the framework renders the +// prompt, appends the conversation history, calls the model, and updates +// session state. This is the shortest path from "I want a chat agent" to +// a working one. +func defineInlineAgent(g *genkit.Genkit) *aix.Agent[any] { + const name = "pirate" + return genkit.DefineAgent(g, name, + aix.FromInline( + ai.WithModel(flashModel), + ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), + ), + aix.WithSessionStore(mustStore(name)), + aix.WithDescription[any]("Sarcastic pirate (inline-defined prompt)"), + ) +} From 4e793026529b82dd8864b3843b922412ee7df776 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 23 Jun 2026 12:02:59 -0700 Subject: [PATCH 130/141] feat(go/ai/exp): rework agent sources into InlinePrompt/SameNamedPrompt/NamedPrompt Replace the FromInline/FromPrompt agent-source constructors with three clearer forms: - InlinePrompt(opts...) defines the prompt inline (was FromInline) - SameNamedPrompt() references the prompt named like the agent - NamedPrompt(name, input) references any registered prompt by name, rendered with an input supplied from code NamedPrompt decouples the prompt's lookup name from the agent's name, so a single prompt can back many agents with different inputs. The old FromPrompt(defaultInput...) variadic-of-one is gone; per-turn input now rides on NamedPrompt. Updates the wrapper docs, both samples (chef's personality moves to the .prompt frontmatter default), the README, and the tests, and adds TestPromptAgent_NamedPromptSharedAcrossAgents covering the shared-prompt path. Breaking change to the experimental ai/exp API. --- go/README.md | 26 ++++-- go/ai/exp/agent.go | 29 ++++--- go/ai/exp/agent_test.go | 91 ++++++++++++++++++--- go/ai/exp/source.go | 63 ++++++++------ go/genkit/exp/routes_test.go | 8 +- go/genkit/genkit.go | 25 +++--- go/genkit/servers_test.go | 6 +- go/samples/basic-agents-server/main.go | 4 +- go/samples/basic-agents/main.go | 26 +++--- go/samples/basic-agents/prompts/chef.prompt | 2 + 10 files changed, 193 insertions(+), 87 deletions(-) diff --git a/go/README.md b/go/README.md index 356cfd409d..ea0b1f1425 100644 --- a/go/README.md +++ b/go/README.md @@ -85,7 +85,7 @@ The agent API is experimental: it lives in `github.com/firebase/genkit/go/ai/exp ### Define an Agent -The shortest path is a prompt-backed agent with an inline prompt and a session store. `aix.FromInline` declares the prompt right next to the agent; the store persists each turn so the conversation can resume later: +The shortest path is a prompt-backed agent with an inline prompt and a session store. `aix.InlinePrompt` declares the prompt right next to the agent; the store persists each turn so the conversation can resume later: ```go import ( @@ -96,7 +96,7 @@ import ( ) chatAgent := genkit.DefineAgent(g, "chat", - aix.FromInline( + aix.InlinePrompt( ai.WithModelName("googleai/gemini-flash-latest"), ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), ), @@ -143,7 +143,7 @@ fmt.Println(out.Message.Text()) ### Load the Prompt from a File -`aix.FromPrompt` backs the agent with a prompt already in the registry, including one loaded from a `.prompt` file. Prompt authors can tune the model, config, and template without touching the Go wiring. The agent and its prompt share a name: +`aix.SameNamedPrompt` backs the agent with the prompt registered under the agent's name, including one loaded from a `.prompt` file. Prompt authors can tune the model, config, template, and default input without touching the Go wiring: ```yaml # prompts/chat.prompt @@ -151,6 +151,8 @@ fmt.Println(out.Message.Text()) model: googleai/gemini-flash-latest input: schema: ChatInput + default: + personality: a Michelin-starred chef --- {{role "system"}} You are {{personality}}. Keep responses concise. @@ -164,13 +166,27 @@ type ChatInput struct { // Register the schema so the .prompt file can reference it by name. genkit.DefineSchemaFor[ChatInput](g) -// FromPrompt's argument is the default input rendered on every turn. +// Agent "chat" renders ./prompts/chat.prompt every turn. chatAgent := genkit.DefineAgent(g, "chat", - aix.FromPrompt(ChatInput{Personality: "a Michelin-starred chef"}), + aix.SameNamedPrompt(), aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), ) ``` +To back several agents with one shared prompt, reference it by name with `aix.NamedPrompt` and give each its own input. The prompt name need not match the agent name: + +```go +for _, p := range []struct{ name, persona string }{ + {"pirate", "a sarcastic pirate"}, + {"chef", "a Michelin-starred chef"}, +} { + genkit.DefineAgent(g, p.name, + aix.NamedPrompt("chat", ChatInput{Personality: p.persona}), + aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), + ) +} +``` + [See full example](samples/basic-agents) ### Custom Turn Loops diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 209c0426e0..6ad84e04ba 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -618,11 +618,12 @@ func (a *Agent[State]) ConnectJSON(ctx context.Context, opts *api.BidiJSONOption // // source selects how the prompt is backed: // -// - [FromInline] defines the prompt inline from a set of +// - [InlinePrompt] defines the prompt inline from a set of // [ai.PromptOption] values; the prompt is registered under name. -// - [FromPrompt] references an existing prompt registered with the -// registry under name (e.g. one defined via [ai.DefinePrompt] or -// loaded from a .prompt file). +// - [SameNamedPrompt] references an existing prompt registered under name +// (e.g. one defined via [ai.DefinePrompt] or loaded from a .prompt file). +// - [NamedPrompt] references any registered prompt by name with an input +// supplied from code, so a single prompt can back many agents. // // State is inferred from the typed agent options (e.g. // [WithSessionStore], [WithStateTransform]); pass an explicit [State] only @@ -640,15 +641,19 @@ func DefineAgent[State any]( case inlineSource: prompt := ai.DefinePrompt(r, name, s.opts...) return DefineCustomAgent(r, name, agentLoop[State](r, prompt, nil), opts...) - case promptSource: - prompt := ai.LookupPrompt(r, name) + case existingSource: + promptName := s.name + if promptName == "" { + promptName = name // SameNamedPrompt: resolve by the agent's own name + } + prompt := ai.LookupPrompt(r, promptName) if prompt == nil { - panic(fmt.Sprintf("DefineAgent %q: prompt %q not found", name, name)) + panic(fmt.Sprintf("DefineAgent %q: prompt %q not found", name, promptName)) } - if _, err := prompt.Render(context.Background(), s.defaultInput); err != nil { - panic(fmt.Sprintf("DefineAgent %q: defaultInput does not satisfy prompt schema: %v", name, err)) + if _, err := prompt.Render(context.Background(), s.input); err != nil { + panic(fmt.Sprintf("DefineAgent %q: prompt input does not satisfy prompt schema: %v", name, err)) } - return DefineCustomAgent(r, name, agentLoop[State](r, prompt, s.defaultInput), opts...) + return DefineCustomAgent(r, name, agentLoop[State](r, prompt, s.input), opts...) default: panic(fmt.Sprintf("DefineAgent %q: unknown source type %T", name, source)) } @@ -2058,7 +2063,7 @@ func validateUserMessage(m *ai.Message) error { // interrupted call. The whole history is searched (every model message), not // just the last turn. On a violation it returns an INVALID_ARGUMENT error. // -// The prompt-backed agent loop ([FromPrompt]) calls this automatically. A +// The prompt-backed agent loop ([DefineAgent]) calls this automatically. A // custom agent ([DefineCustomAgent]) that accepts an [AgentInput.Resume] from // untrusted callers should call it before forwarding the payload to the model: // @@ -2145,7 +2150,7 @@ func toolRefSuffix(ref string) string { // with streaming, and updates the session. // // defaultInput is the prompt input passed to Render on every turn. It is -// nil for inline-defined prompts ([FromInline]), which take no per-turn +// nil for inline-defined prompts ([InlinePrompt]), which take no per-turn // input. func agentLoop[State any](r api.Registry, prompt ai.Prompt, defaultInput any) AgentFunc[State] { return func(ctx context.Context, resp Responder, sess *SessionRunner[State]) (*AgentResult, error) { diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 83d7daeee8..60f6a886bb 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -1516,6 +1516,75 @@ func setupPromptTestRegistry(t *testing.T) *registry.Registry { return reg } +// personaInput is the input schema for the shared prompt exercised by +// TestPromptAgent_NamedPromptSharedAcrossAgents. +type personaInput struct { + Personality string `json:"personality"` +} + +// TestPromptAgent_NamedPromptSharedAcrossAgents verifies that NamedPrompt +// backs several agents with a single prompt, rendering each with its own +// input, and that an agent's name is independent of the prompt it references. +func TestPromptAgent_NamedPromptSharedAcrossAgents(t *testing.T) { + ctx := context.Background() + reg := registry.New() + ai.ConfigureFormats(reg) + + var mu sync.Mutex + var renderedSystems []string + ai.DefineModel(reg, "test/capture", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + mu.Lock() + for _, m := range req.Messages { + if m.Role == ai.RoleSystem { + renderedSystems = append(renderedSystems, m.Text()) + } + } + mu.Unlock() + return &ai.ModelResponse{Request: req, Message: ai.NewModelTextMessage("ok")}, nil + }, + ) + ai.DefineGenerateAction(ctx, reg) + + // One shared prompt with a personality variable, registered under a name + // that matches no agent. + ai.DefinePrompt(reg, "sharedChat", + ai.WithModelName("test/capture"), + ai.WithInputType(personaInput{}), + ai.WithSystem("You are {{personality}}."), + ) + + // Two agents, different names, the one prompt, different inputs. + pirate := DefineAgent[testState](reg, "pirate", + NamedPrompt("sharedChat", personaInput{Personality: "a pirate"})) + chef := DefineAgent[testState](reg, "chef", + NamedPrompt("sharedChat", personaInput{Personality: "a chef"})) + + // The agents register under their own names, not the prompt's. + if pirate.Name() != "pirate" { + t.Errorf("pirate.Name() = %q, want %q", pirate.Name(), "pirate") + } + if chef.Name() != "chef" { + t.Errorf("chef.Name() = %q, want %q", chef.Name(), "chef") + } + + if _, err := pirate.RunText(ctx, "hi"); err != nil { + t.Fatalf("pirate RunText: %v", err) + } + if _, err := chef.RunText(ctx, "hi"); err != nil { + t.Fatalf("chef RunText: %v", err) + } + + // Each agent rendered the shared prompt with its own personality input. + joined := strings.Join(renderedSystems, " | ") + if !strings.Contains(joined, "You are a pirate.") { + t.Errorf("pirate personality not rendered; system prompts = %q", joined) + } + if !strings.Contains(joined, "You are a chef.") { + t.Errorf("chef personality not rendered; system prompts = %q", joined) + } +} + func TestPromptAgent_Basic(t *testing.T) { ctx := context.Background() reg := setupPromptTestRegistry(t) @@ -1525,7 +1594,7 @@ func TestPromptAgent_Basic(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefineAgent[testState](reg, "testPrompt", FromPrompt()) + af := DefineAgent[testState](reg, "testPrompt", SameNamedPrompt()) conn, err := af.Connect(ctx) if err != nil { @@ -1600,7 +1669,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { ai.WithSystem("system prompt"), ) - af := DefineAgent[testState](reg, "historyPrompt", FromPrompt()) + af := DefineAgent[testState](reg, "historyPrompt", SameNamedPrompt()) conn, err := af.Connect(ctx) if err != nil { @@ -1674,7 +1743,7 @@ func TestPromptAgent_SnapshotResumePreservesHistory(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefineAgent[testState](reg, "snapPrompt", FromPrompt(), + af := DefineAgent[testState](reg, "snapPrompt", SameNamedPrompt(), WithSessionStore(store), ) @@ -1799,7 +1868,7 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { ai.WithTools(ai.ToolName("greet"), ai.ToolName("farewell")), ) - af := DefineAgent[testState](reg, "toolPrompt", FromPrompt()) + af := DefineAgent[testState](reg, "toolPrompt", SameNamedPrompt()) conn, err := af.Connect(ctx) if err != nil { @@ -2004,7 +2073,7 @@ func TestPromptAgent_RunText(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefineAgent[testState](reg, "runTextPrompt", FromPrompt()) + af := DefineAgent[testState](reg, "runTextPrompt", SameNamedPrompt()) response, err := af.RunText(ctx, "hello") if err != nil { @@ -2029,7 +2098,7 @@ func TestPromptAgent_RejectsInvalidInputMessage(t *testing.T) { ctx := context.Background() reg := setupPromptTestRegistry(t) ai.DefinePrompt(reg, "rejectPrompt", ai.WithModelName("test/echo")) - af := DefineAgent[testState](reg, "rejectPrompt", FromPrompt()) + af := DefineAgent[testState](reg, "rejectPrompt", SameNamedPrompt()) tests := []struct { name string @@ -2188,7 +2257,7 @@ func TestPromptAgent_RejectsResumeForUnrequestedTool(t *testing.T) { ai.DefineGenerateAction(ctx, reg) ai.DefinePrompt(reg, "plainPrompt", ai.WithModelName("test/plain")) - af := DefineAgent[testState](reg, "plainPrompt", FromPrompt()) + af := DefineAgent[testState](reg, "plainPrompt", SameNamedPrompt()) conn, err := af.Connect(ctx) if err != nil { @@ -5042,7 +5111,7 @@ func TestPromptAgent_ForwardsFinishReason(t *testing.T) { ai.DefineGenerateAction(ctx, reg) ai.DefinePrompt(reg, "lengthPrompt", ai.WithModelName("test/length")) - af := DefineAgent[testState](reg, "lengthPrompt", FromPrompt()) + af := DefineAgent[testState](reg, "lengthPrompt", SameNamedPrompt()) conn, err := af.Connect(ctx) if err != nil { @@ -5445,7 +5514,7 @@ func TestPromptAgent_ForwardsInterruptedFinishReason(t *testing.T) { ai.WithTools(interruptTool), ) - af := DefineAgent[testState](reg, "interruptPrompt", FromPrompt()) + af := DefineAgent[testState](reg, "interruptPrompt", SameNamedPrompt()) conn, err := af.Connect(ctx) if err != nil { @@ -6413,7 +6482,7 @@ func TestPromptAgent_InlineMessages_DoesNotMutateSharedMetadata(t *testing.T) { shared := ai.NewModelTextMessage("inline context message") shared.Metadata = map[string]any{"origin": "config"} - af := DefineAgent[testState](reg, "inlineMetaPrompt", FromInline( + af := DefineAgent[testState](reg, "inlineMetaPrompt", InlinePrompt( ai.WithModelName("test/echo"), ai.WithMessages(shared), )) @@ -6446,7 +6515,7 @@ func TestPromptAgent_InlineMessages_ConcurrentInvocations(t *testing.T) { shared := ai.NewModelTextMessage("inline context message") shared.Metadata = map[string]any{"origin": "config"} - af := DefineAgent[testState](reg, "inlineConcurrentPrompt", FromInline( + af := DefineAgent[testState](reg, "inlineConcurrentPrompt", InlinePrompt( ai.WithModelName("test/echo"), ai.WithMessages(shared), )) diff --git a/go/ai/exp/source.go b/go/ai/exp/source.go index d30475f9ca..e6afb53127 100644 --- a/go/ai/exp/source.go +++ b/go/ai/exp/source.go @@ -19,19 +19,21 @@ package exp import "github.com/firebase/genkit/go/ai" // AgentSource selects the prompt backing a prompt-based agent. Pass an -// AgentSource as the third argument to [DefineAgent]. There are two +// AgentSource as the third argument to [DefineAgent]. There are three // forms: // -// - [FromInline] defines the prompt inline from a set of +// - [InlinePrompt] defines the prompt inline from a set of // [ai.PromptOption] values; the prompt is registered with the // registry under the agent's name. -// - [FromPrompt] references an existing prompt registered with the -// registry under the same name as the agent (e.g. one defined via -// [ai.DefinePrompt] or loaded from a .prompt file). +// - [SameNamedPrompt] references the prompt already registered under the +// agent's own name (e.g. one defined via [ai.DefinePrompt] or loaded +// from a .prompt file). +// - [NamedPrompt] references any registered prompt by name and renders it +// with an input supplied from code, so a single prompt can back many +// agents with different inputs. // -// The agent and its backing prompt always share a name; if you need -// the lookup name to differ from the agent name, define a custom agent -// via [DefineCustomAgent] instead. +// For full control over the per-turn loop, define a custom agent via +// [DefineCustomAgent] instead. type AgentSource interface { isAgentSource() } @@ -42,30 +44,37 @@ type inlineSource struct { func (inlineSource) isAgentSource() {} -// FromInline defines the agent's prompt inline from the given options. -// The prompt is registered with the registry under the agent's name. -func FromInline(opts ...ai.PromptOption) AgentSource { +// InlinePrompt defines the agent's prompt inline from the given options. +// The prompt is registered with the registry under the agent's name. To +// give the template a default render input, include [ai.WithInputType] +// among the options. +func InlinePrompt(opts ...ai.PromptOption) AgentSource { return inlineSource{opts: opts} } -type promptSource struct { - defaultInput any +type existingSource struct { + name string // "" => resolve by the agent's own name + input any } -func (promptSource) isAgentSource() {} +func (existingSource) isAgentSource() {} -// FromPrompt references an existing prompt registered with the -// registry under the same name as the agent (e.g. one defined via -// [ai.DefinePrompt] or loaded from a .prompt file). +// SameNamedPrompt references the prompt registered under the agent's own +// name (e.g. one defined via [ai.DefinePrompt] or loaded from a .prompt +// file). The prompt renders with its own default input each turn. It is +// shorthand for NamedPrompt(, nil). +func SameNamedPrompt() AgentSource { + return existingSource{} +} + +// NamedPrompt references the prompt registered under name, rendered with +// input on every turn (pass nil for the prompt's own default input). name +// need not match the agent's name, so a single prompt can back many agents +// with different inputs. // -// defaultInput, if provided, is the input passed to the prompt's -// Render on every turn. Call FromPrompt() with no arguments when the -// prompt takes no input. Only the first argument is used; any -// additional arguments are ignored. -func FromPrompt(defaultInput ...any) AgentSource { - var input any - if len(defaultInput) > 0 { - input = defaultInput[0] - } - return promptSource{defaultInput: input} +// input is rendered through the prompt once at definition time as a smoke +// check, so an input that fails the prompt's schema panics there rather +// than on the first invocation. +func NamedPrompt(name string, input any) AgentSource { + return existingSource{name: name, input: input} } diff --git a/go/genkit/exp/routes_test.go b/go/genkit/exp/routes_test.go index 37adea95df..c649b0fa11 100644 --- a/go/genkit/exp/routes_test.go +++ b/go/genkit/exp/routes_test.go @@ -65,10 +65,10 @@ func newRouteTestGenkit(t *testing.T) *genkit.Genkit { if err != nil { t.Fatal(err) } - genkit.DefineAgent(g, "serverChat", aix.FromInline(ai.WithModelName("test/echo")), + genkit.DefineAgent(g, "serverChat", aix.InlinePrompt(ai.WithModelName("test/echo")), aix.WithSessionStore(store), ) - genkit.DefineAgent[any](g, "clientChat", aix.FromInline(ai.WithModelName("test/echo"))) + genkit.DefineAgent[any](g, "clientChat", aix.InlinePrompt(ai.WithModelName("test/echo"))) genkit.DefineFlow(g, "greet", func(ctx context.Context, name string) (string, error) { return "hi " + name, nil }) @@ -112,8 +112,8 @@ func TestAgentRoutes_PicksOneAgentAndMirrorsCapabilities(t *testing.T) { if err != nil { t.Fatal(err) } - server := genkit.DefineAgent(g, "srv", aix.FromInline(ai.WithModelName("test/echo")), aix.WithSessionStore(store)) - client := genkit.DefineAgent[any](g, "cli", aix.FromInline(ai.WithModelName("test/echo"))) + server := genkit.DefineAgent(g, "srv", aix.InlinePrompt(ai.WithModelName("test/echo")), aix.WithSessionStore(store)) + client := genkit.DefineAgent[any](g, "cli", aix.InlinePrompt(ai.WithModelName("test/echo"))) if got, want := routeKeys(AgentRoutes(server)), []string{ "POST /agents/srv", diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index b541a1c91a..2ae9b642ca 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -420,11 +420,13 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn core.StreamingFunc[In // // source selects how the prompt is backed: // -// - [aix.FromInline] defines the prompt inline from a set of +// - [aix.InlinePrompt] defines the prompt inline from a set of // [ai.PromptOption] values; the prompt is registered under name. -// - [aix.FromPrompt] references an existing prompt registered with -// the registry under name (e.g. one defined via [DefinePrompt] or -// loaded from a .prompt file). +// - [aix.SameNamedPrompt] references an existing prompt registered under +// name (e.g. one defined via [DefinePrompt] or loaded from a .prompt +// file). +// - [aix.NamedPrompt] references any registered prompt by name with an +// input supplied from code, so a single prompt can back many agents. // // The State type parameter is inferred from the typed agent options // (e.g. [aix.WithSessionStore], [aix.WithStateTransform]); pass an explicit @@ -446,21 +448,21 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn core.StreamingFunc[In // Example (inline prompt): // // chatAgent := genkit.DefineAgent(g, "chat", -// aix.FromInline( -// ai.WithModelName("googleai/gemini-3-flash-preview"), +// aix.InlinePrompt( +// ai.WithModelName("googleai/gemini-flash-latest"), // ai.WithSystem("You are a helpful assistant."), // ), // aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), // ) // -// Example (existing prompt): +// Example (a shared .prompt file, parameterized per agent): // // type ChatInput struct { // Personality string `json:"personality"` // } // -// chatAgent := genkit.DefineAgent(g, "chat", -// aix.FromPrompt(ChatInput{Personality: "a sarcastic pirate"}), +// pirate := genkit.DefineAgent(g, "pirate", +// aix.NamedPrompt("chat", ChatInput{Personality: "a sarcastic pirate"}), // aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), // ) func DefineAgent[State any]( @@ -488,8 +490,9 @@ func DefineAgent[State any]( // via [Handler], with companion actions on [aix.Agent.GetSnapshotAction] // and [aix.Agent.AbortSnapshotAction]. // -// For agents backed by a prompt, use [DefineAgent] with [aix.FromInline] -// (inline prompt) or [aix.FromPrompt] (existing prompt) instead. +// For agents backed by a prompt, use [DefineAgent] with [aix.InlinePrompt] +// (inline prompt), [aix.SameNamedPrompt], or [aix.NamedPrompt] (existing +// prompt) instead. // // # Options // diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index 08521a0056..e6fa1e288a 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -758,13 +758,13 @@ func TestHandlerAgent(t *testing.T) { }, nil }) - DefineAgent[any](g, "agentClient", aix.FromInline(ai.WithModelName("test/echo"))) + DefineAgent[any](g, "agentClient", aix.InlinePrompt(ai.WithModelName("test/echo"))) store, err := localstore.NewFileSessionStore[any](t.TempDir()) if err != nil { t.Fatal(err) } - DefineAgent(g, "agentServer", aix.FromInline(ai.WithModelName("test/echo")), + DefineAgent(g, "agentServer", aix.InlinePrompt(ai.WithModelName("test/echo")), aix.WithSessionStore(store), ) @@ -969,7 +969,7 @@ func TestHandlerAgentRef(t *testing.T) { if err != nil { t.Fatal(err) } - agent := DefineAgent(g, "agentRef", aix.FromInline(ai.WithModelName("test/echo")), + agent := DefineAgent(g, "agentRef", aix.InlinePrompt(ai.WithModelName("test/echo")), aix.WithSessionStore(store), ) diff --git a/go/samples/basic-agents-server/main.go b/go/samples/basic-agents-server/main.go index e38d292c7f..6099b3f4d6 100644 --- a/go/samples/basic-agents-server/main.go +++ b/go/samples/basic-agents-server/main.go @@ -135,7 +135,7 @@ func main() { log.Fatalf("creating session store: %v", err) } genkit.DefineAgent(g, "chat", - aix.FromInline( + aix.InlinePrompt( ai.WithModel(model), ai.WithSystem("You are a helpful travel assistant. Keep responses to a couple of sentences."), ), @@ -146,7 +146,7 @@ func main() { // the full conversation state and the client round-trips it on the next // request. This suits deployments where the server must stay stateless. genkit.DefineAgent[any](g, "statelessChat", - aix.FromInline( + aix.InlinePrompt( ai.WithModel(model), ai.WithSystem("You are a helpful travel assistant. Keep responses to a couple of sentences."), ), diff --git a/go/samples/basic-agents/main.go b/go/samples/basic-agents/main.go index d3ffad4db2..55b3d5d8cd 100644 --- a/go/samples/basic-agents/main.go +++ b/go/samples/basic-agents/main.go @@ -15,10 +15,10 @@ // This sample demonstrates Genkit's agent APIs by defining three agents in // three different styles and exposing all of them through a single CLI: // -// - "pirate" uses DefineAgent + aix.FromInline. The prompt is declared +// - "pirate" uses DefineAgent + aix.InlinePrompt. The prompt is declared // inline next to the agent. -// - "chef" uses DefineAgent + aix.FromPrompt. The prompt is loaded from -// ./prompts/chef.prompt by the agent's name. +// - "chef" uses DefineAgent + aix.SameNamedPrompt. The prompt is loaded +// from ./prompts/chef.prompt by the agent's name. // - "coder" uses DefineCustomAgent. The per-turn loop (model selection, // history management, streaming) is wired by hand. // @@ -97,7 +97,7 @@ func main() { } } -// defineInlineAgent demonstrates DefineAgent with aix.FromInline. The +// defineInlineAgent demonstrates DefineAgent with aix.InlinePrompt. The // prompt is declared right next to the agent definition; the registered // prompt and the agent share a name. Each turn the framework renders the // prompt, appends the conversation history, calls the model, and updates @@ -106,7 +106,7 @@ func main() { func defineInlineAgent(g *genkit.Genkit) *aix.Agent[any] { const name = "pirate" return genkit.DefineAgent(g, name, - aix.FromInline( + aix.InlinePrompt( ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ ThinkingBudget: genai.Ptr[int32](0), @@ -119,19 +119,21 @@ func defineInlineAgent(g *genkit.Genkit) *aix.Agent[any] { ) } -// definePromptAgent demonstrates DefineAgent with aix.FromPrompt. The +// definePromptAgent demonstrates DefineAgent with aix.SameNamedPrompt. The // prompt is loaded from ./prompts/.prompt by genkit's prompt // registry. Defining the prompt in a file lets you tune model, config, -// schema, and template independently of the Go code — useful when prompt -// authors are not the same people writing the agent wiring. +// schema, template, and default input independently of the Go code, which +// is useful when prompt authors are not the people writing the agent wiring. // -// FromPrompt's argument is the default input passed to the prompt's -// Render on every turn; the inline-prompt variant has no per-turn input -// of its own. +// SameNamedPrompt references the prompt registered under the agent's own +// name and renders it with the prompt's own default input each turn (here, +// the personality set in chef.prompt's frontmatter). To supply an input +// from code, or to back several agents with one shared prompt, use +// aix.NamedPrompt(name, input) instead. func definePromptAgent(g *genkit.Genkit) *aix.Agent[any] { const name = "chef" return genkit.DefineAgent(g, name, - aix.FromPrompt(ChatPromptInput{Personality: "a Michelin-starred chef who loves explaining technique"}), + aix.SameNamedPrompt(), aix.WithSessionStore(mustStore(name)), aix.WithDescription[any]("Michelin-starred chef (prompt loaded from ./prompts/chef.prompt)"), ) diff --git a/go/samples/basic-agents/prompts/chef.prompt b/go/samples/basic-agents/prompts/chef.prompt index c40edfcbaa..283c4c68ab 100644 --- a/go/samples/basic-agents/prompts/chef.prompt +++ b/go/samples/basic-agents/prompts/chef.prompt @@ -5,5 +5,7 @@ config: thinkingBudget: 0 input: schema: ChatPromptInput + default: + personality: a Michelin-starred chef who loves explaining technique --- You are {{personality}}. Keep responses concise. From c15b1a61a052c23a921a08545e761935a9c55312 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 23 Jun 2026 14:17:49 -0700 Subject: [PATCH 131/141] feat(go/ai/exp): make DefineAgent inline-only, add DefinePromptAgent Split the prompt-backed agent constructors so each has one clear job: - DefineAgent(r, name, prompt, opts...) defines the prompt inline; the prompt is an InlinePrompt, a []ai.PromptOption slice passed positionally - DefinePromptAgent(r, name, opts...) sources a prompt from the registry, defaulting to the agent's own name DefinePromptAgent uses the same-named prompt by default; WithNamedPrompt(name, input) points it at a different registered prompt rendered with a code-supplied input, so one prompt can back many agents. The prompt source is split across a compile-time-validated option set mirroring ai/option.go: the shared options (WithSessionStore, WithStateTransform, WithDescription) are AgentOptions valid on every constructor, while WithNamedPrompt is a PromptAgentOption accepted only by DefinePromptAgent. Passing it to DefineAgent or DefineCustomAgent fails to compile. Making the inline prompt a required positional argument means an inline agent cannot be defined without one. Removes the AgentSource abstraction and the SameNamedPrompt/NamedPrompt sources. Updates the wrapper docs, both samples, and the tests. Breaking change to the experimental ai/exp API. --- go/ai/exp/agent.go | 115 +++++++++++++++------- go/ai/exp/agent_test.go | 131 +++++++++++++++++++++---- go/ai/exp/inline.go | 37 +++++++ go/ai/exp/option.go | 77 ++++++++++++++- go/ai/exp/source.go | 80 --------------- go/genkit/exp/routes_test.go | 8 +- go/genkit/genkit.go | 83 +++++++++++----- go/genkit/servers_test.go | 6 +- go/samples/basic-agents-server/main.go | 8 +- go/samples/basic-agents/main.go | 33 ++++--- 10 files changed, 390 insertions(+), 188 deletions(-) create mode 100644 go/ai/exp/inline.go delete mode 100644 go/ai/exp/source.go diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 6ad84e04ba..004a3fdcc9 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -612,51 +612,87 @@ func (a *Agent[State]) ConnectJSON(ctx context.Context, opts *api.BidiJSONOption return a.action.ConnectJSON(ctx, opts) } -// DefineAgent defines a prompt-backed agent and registers it. Each turn -// renders the agent's prompt, appends conversation history, calls the -// model with streaming, and updates session state. +// DefineAgent defines an agent backed by an inline prompt and registers it. The +// prompt is defined from prompt's [ai.PromptOption] values and registered under +// the agent's name; each turn renders it, appends conversation history, calls +// the model with streaming, and updates session state. // -// source selects how the prompt is backed: +// The prompt is an [InlinePrompt], a list of [ai.PromptOption] values: // -// - [InlinePrompt] defines the prompt inline from a set of -// [ai.PromptOption] values; the prompt is registered under name. -// - [SameNamedPrompt] references an existing prompt registered under name -// (e.g. one defined via [ai.DefinePrompt] or loaded from a .prompt file). -// - [NamedPrompt] references any registered prompt by name with an input -// supplied from code, so a single prompt can back many agents. +// agent := DefineAgent(r, "pirate", +// InlinePrompt{ +// ai.WithModelName("googleai/gemini-flash-latest"), +// ai.WithSystem("You are a sarcastic pirate."), +// }, +// WithSessionStore(store), +// ) // -// State is inferred from the typed agent options (e.g. -// [WithSessionStore], [WithStateTransform]); pass an explicit [State] only -// when no typed option is provided. A typed option that disagrees with -// the inferred State fails at compile time. +// State is inferred from the typed agent options (e.g. [WithSessionStore], +// [WithStateTransform]); pass an explicit [State] only when no typed option is +// provided. A typed option that disagrees with the inferred State fails at +// compile time. // -// For full control over the per-turn loop, use [DefineCustomAgent]. +// To back an agent with a prompt already in the registry (e.g. one from a +// .prompt file), use [DefinePromptAgent]. For full control over the per-turn +// loop, use [DefineCustomAgent]. func DefineAgent[State any]( r api.Registry, name string, - source AgentSource, + prompt InlinePrompt, opts ...AgentOption[State], ) *Agent[State] { - switch s := source.(type) { - case inlineSource: - prompt := ai.DefinePrompt(r, name, s.opts...) - return DefineCustomAgent(r, name, agentLoop[State](r, prompt, nil), opts...) - case existingSource: - promptName := s.name - if promptName == "" { - promptName = name // SameNamedPrompt: resolve by the agent's own name - } - prompt := ai.LookupPrompt(r, promptName) - if prompt == nil { - panic(fmt.Sprintf("DefineAgent %q: prompt %q not found", name, promptName)) - } - if _, err := prompt.Render(context.Background(), s.input); err != nil { - panic(fmt.Sprintf("DefineAgent %q: prompt input does not satisfy prompt schema: %v", name, err)) + p := ai.DefinePrompt(r, name, prompt...) + return DefineCustomAgent(r, name, agentLoop[State](r, p, nil), opts...) +} + +// DefinePromptAgent defines a prompt-backed agent and registers it, sourcing +// its prompt from the registry by name. Each turn renders the prompt, appends +// conversation history, calls the model with streaming, and updates session +// state, exactly like [DefineAgent]. +// +// By default the agent uses the prompt registered under its own name (e.g. one +// defined via [ai.DefinePrompt] or loaded from a .prompt file), so no source +// option is required. Pass [WithNamedPrompt] to reference a differently named +// prompt and supply its render input from code, so a single prompt can back +// many agents. +// +// It is the registry-backed counterpart of [DefineAgent]: where [DefineAgent] +// defines the prompt inline, DefinePromptAgent points at a prompt already in +// the registry. The prompt source is a typed option ([WithNamedPrompt]) rather +// than a positional argument, so it composes with the other agent options +// ([WithSessionStore], [WithStateTransform], [WithDescription]) in a single +// variadic. For full control over the per-turn loop, use [DefineCustomAgent]. +// +// State is inferred from the typed agent options; pass an explicit [State] only +// when no typed option provides it (e.g. only [WithNamedPrompt] and +// [WithDescription], whose State cannot be deduced from their arguments). +func DefinePromptAgent[State any]( + r api.Registry, + name string, + opts ...PromptAgentOption[State], +) *Agent[State] { + cfg := &promptAgentOptions[State]{} + for _, opt := range opts { + if err := opt.applyPromptAgent(cfg); err != nil { + panic(fmt.Errorf("DefinePromptAgent %q: %w", name, err)) } - return DefineCustomAgent(r, name, agentLoop[State](r, prompt, s.input), opts...) - default: - panic(fmt.Sprintf("DefineAgent %q: unknown source type %T", name, source)) } + + promptName := cfg.promptName + if promptName == "" { + promptName = name // default: the prompt registered under the agent's own name + } + prompt := ai.LookupPrompt(r, promptName) + if prompt == nil { + panic(fmt.Sprintf("DefinePromptAgent %q: prompt %q not found", name, promptName)) + } + if _, err := prompt.Render(context.Background(), cfg.promptInput); err != nil { + panic(fmt.Sprintf("DefinePromptAgent %q: prompt input does not satisfy prompt schema: %v", name, err)) + } + + a := newCustomAgent(name, agentLoop[State](r, prompt, cfg.promptInput), &cfg.agentOptions) + a.Register(r) + return a } // NewCustomAgent creates an agent with full control over the conversation @@ -686,7 +722,18 @@ func NewCustomAgent[State any]( panic(fmt.Errorf("NewCustomAgent %q: %w", name, err)) } } + return newCustomAgent(name, fn, cfg) +} +// newCustomAgent builds (without registering) an agent from already-applied +// base options. It is the shared core of [NewCustomAgent] and the prompt-backed +// [DefinePromptAgent], which resolve their prompt source into an agentLoop fn +// and reuse the same base option set. +func newCustomAgent[State any]( + name string, + fn AgentFunc[State], + cfg *agentOptions[State], +) *Agent[State] { // Typed under ActionTypeAgent so agents surface as their own action // kind rather than as flows (genkit.ListAgents vs ListFlows). Built on // NewBidiAction so the agent capability metadata is set at construction diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 60f6a886bb..57f0984ec5 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -1522,7 +1522,7 @@ type personaInput struct { Personality string `json:"personality"` } -// TestPromptAgent_NamedPromptSharedAcrossAgents verifies that NamedPrompt +// TestPromptAgent_NamedPromptSharedAcrossAgents verifies that WithNamedPrompt // backs several agents with a single prompt, rendering each with its own // input, and that an agent's name is independent of the prompt it references. func TestPromptAgent_NamedPromptSharedAcrossAgents(t *testing.T) { @@ -1555,10 +1555,10 @@ func TestPromptAgent_NamedPromptSharedAcrossAgents(t *testing.T) { ) // Two agents, different names, the one prompt, different inputs. - pirate := DefineAgent[testState](reg, "pirate", - NamedPrompt("sharedChat", personaInput{Personality: "a pirate"})) - chef := DefineAgent[testState](reg, "chef", - NamedPrompt("sharedChat", personaInput{Personality: "a chef"})) + pirate := DefinePromptAgent[testState](reg, "pirate", + WithNamedPrompt[testState]("sharedChat", personaInput{Personality: "a pirate"})) + chef := DefinePromptAgent[testState](reg, "chef", + WithNamedPrompt[testState]("sharedChat", personaInput{Personality: "a chef"})) // The agents register under their own names, not the prompt's. if pirate.Name() != "pirate" { @@ -1585,6 +1585,101 @@ func TestPromptAgent_NamedPromptSharedAcrossAgents(t *testing.T) { } } +// TestDefinePromptAgent_DefaultAndNamed exercises the option-configured prompt +// agent: the default same-named lookup (no source option), the WithNamedPrompt +// override sourcing a shared prompt with a code-supplied input, and that the +// shared agent options (WithDescription) compose alongside the prompt source. +func TestDefinePromptAgent_DefaultAndNamed(t *testing.T) { + ctx := context.Background() + reg := registry.New() + ai.ConfigureFormats(reg) + + var mu sync.Mutex + var renderedSystems []string + ai.DefineModel(reg, "test/capture", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + mu.Lock() + for _, m := range req.Messages { + if m.Role == ai.RoleSystem { + renderedSystems = append(renderedSystems, m.Text()) + } + } + mu.Unlock() + return &ai.ModelResponse{Request: req, Message: ai.NewModelTextMessage("ok")}, nil + }, + ) + ai.DefineGenerateAction(ctx, reg) + + // A same-named prompt for the default lookup, and a shared parameterized + // prompt for the WithNamedPrompt override. + ai.DefinePrompt(reg, "chef", + ai.WithModelName("test/capture"), + ai.WithSystem("You are a chef."), + ) + ai.DefinePrompt(reg, "sharedChat", + ai.WithModelName("test/capture"), + ai.WithInputType(personaInput{}), + ai.WithSystem("You are {{personality}}."), + ) + + // Default: no source option resolves the prompt named after the agent. + chef := DefinePromptAgent[testState](reg, "chef", + WithDescription[testState]("a chef agent")) + if chef.Name() != "chef" { + t.Fatalf("chef.Name() = %q, want %q", chef.Name(), "chef") + } + if got := chef.Desc().Description; got != "a chef agent" { + t.Errorf("chef description = %q, want %q", got, "a chef agent") + } + + // Override: source the shared prompt with a code-supplied input. The agent + // name ("pirate") is independent of the referenced prompt ("sharedChat"). + pirate := DefinePromptAgent[testState](reg, "pirate", + WithNamedPrompt[testState]("sharedChat", personaInput{Personality: "a pirate"})) + + if _, err := chef.RunText(ctx, "hi"); err != nil { + t.Fatalf("chef RunText: %v", err) + } + if _, err := pirate.RunText(ctx, "hi"); err != nil { + t.Fatalf("pirate RunText: %v", err) + } + + joined := strings.Join(renderedSystems, " | ") + if !strings.Contains(joined, "You are a chef.") { + t.Errorf("default same-named prompt not rendered; systems = %q", joined) + } + if !strings.Contains(joined, "You are a pirate.") { + t.Errorf("WithNamedPrompt input not rendered; systems = %q", joined) + } +} + +// TestDefinePromptAgent_Panics covers the definition-time failures: a prompt +// that is not registered, and setting the prompt source more than once. +func TestDefinePromptAgent_Panics(t *testing.T) { + reg := setupPromptTestRegistry(t) + ai.DefinePrompt(reg, "present", ai.WithModelName("test/echo")) + + t.Run("missing prompt", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic for missing prompt") + } + }() + DefinePromptAgent[testState](reg, "absent") + }) + + t.Run("duplicate prompt source", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic for duplicate WithNamedPrompt") + } + }() + DefinePromptAgent[testState](reg, "present", + WithNamedPrompt[testState]("present", nil), + WithNamedPrompt[testState]("present", nil)) + }) +} + func TestPromptAgent_Basic(t *testing.T) { ctx := context.Background() reg := setupPromptTestRegistry(t) @@ -1594,7 +1689,7 @@ func TestPromptAgent_Basic(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefineAgent[testState](reg, "testPrompt", SameNamedPrompt()) + af := DefinePromptAgent[testState](reg, "testPrompt") conn, err := af.Connect(ctx) if err != nil { @@ -1669,7 +1764,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { ai.WithSystem("system prompt"), ) - af := DefineAgent[testState](reg, "historyPrompt", SameNamedPrompt()) + af := DefinePromptAgent[testState](reg, "historyPrompt") conn, err := af.Connect(ctx) if err != nil { @@ -1743,7 +1838,7 @@ func TestPromptAgent_SnapshotResumePreservesHistory(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefineAgent[testState](reg, "snapPrompt", SameNamedPrompt(), + af := DefinePromptAgent[testState](reg, "snapPrompt", WithSessionStore(store), ) @@ -1868,7 +1963,7 @@ func TestPromptAgent_ToolLoopMessages(t *testing.T) { ai.WithTools(ai.ToolName("greet"), ai.ToolName("farewell")), ) - af := DefineAgent[testState](reg, "toolPrompt", SameNamedPrompt()) + af := DefinePromptAgent[testState](reg, "toolPrompt") conn, err := af.Connect(ctx) if err != nil { @@ -2073,7 +2168,7 @@ func TestPromptAgent_RunText(t *testing.T) { ai.WithSystem("You are a test assistant."), ) - af := DefineAgent[testState](reg, "runTextPrompt", SameNamedPrompt()) + af := DefinePromptAgent[testState](reg, "runTextPrompt") response, err := af.RunText(ctx, "hello") if err != nil { @@ -2098,7 +2193,7 @@ func TestPromptAgent_RejectsInvalidInputMessage(t *testing.T) { ctx := context.Background() reg := setupPromptTestRegistry(t) ai.DefinePrompt(reg, "rejectPrompt", ai.WithModelName("test/echo")) - af := DefineAgent[testState](reg, "rejectPrompt", SameNamedPrompt()) + af := DefinePromptAgent[testState](reg, "rejectPrompt") tests := []struct { name string @@ -2257,7 +2352,7 @@ func TestPromptAgent_RejectsResumeForUnrequestedTool(t *testing.T) { ai.DefineGenerateAction(ctx, reg) ai.DefinePrompt(reg, "plainPrompt", ai.WithModelName("test/plain")) - af := DefineAgent[testState](reg, "plainPrompt", SameNamedPrompt()) + af := DefinePromptAgent[testState](reg, "plainPrompt") conn, err := af.Connect(ctx) if err != nil { @@ -5111,7 +5206,7 @@ func TestPromptAgent_ForwardsFinishReason(t *testing.T) { ai.DefineGenerateAction(ctx, reg) ai.DefinePrompt(reg, "lengthPrompt", ai.WithModelName("test/length")) - af := DefineAgent[testState](reg, "lengthPrompt", SameNamedPrompt()) + af := DefinePromptAgent[testState](reg, "lengthPrompt") conn, err := af.Connect(ctx) if err != nil { @@ -5514,7 +5609,7 @@ func TestPromptAgent_ForwardsInterruptedFinishReason(t *testing.T) { ai.WithTools(interruptTool), ) - af := DefineAgent[testState](reg, "interruptPrompt", SameNamedPrompt()) + af := DefinePromptAgent[testState](reg, "interruptPrompt") conn, err := af.Connect(ctx) if err != nil { @@ -6482,10 +6577,10 @@ func TestPromptAgent_InlineMessages_DoesNotMutateSharedMetadata(t *testing.T) { shared := ai.NewModelTextMessage("inline context message") shared.Metadata = map[string]any{"origin": "config"} - af := DefineAgent[testState](reg, "inlineMetaPrompt", InlinePrompt( + af := DefineAgent[testState](reg, "inlineMetaPrompt", InlinePrompt{ ai.WithModelName("test/echo"), ai.WithMessages(shared), - )) + }) response, err := af.RunText(ctx, "hello") if err != nil { @@ -6515,10 +6610,10 @@ func TestPromptAgent_InlineMessages_ConcurrentInvocations(t *testing.T) { shared := ai.NewModelTextMessage("inline context message") shared.Metadata = map[string]any{"origin": "config"} - af := DefineAgent[testState](reg, "inlineConcurrentPrompt", InlinePrompt( + af := DefineAgent[testState](reg, "inlineConcurrentPrompt", InlinePrompt{ ai.WithModelName("test/echo"), ai.WithMessages(shared), - )) + }) var wg sync.WaitGroup errs := make(chan error, 8) diff --git a/go/ai/exp/inline.go b/go/ai/exp/inline.go new file mode 100644 index 0000000000..9caceb5fb9 --- /dev/null +++ b/go/ai/exp/inline.go @@ -0,0 +1,37 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import "github.com/firebase/genkit/go/ai" + +// InlinePrompt is an inline prompt definition for an agent: the list of +// [ai.PromptOption] values that configure the agent's prompt. Pass one to +// [DefineAgent], which registers the prompt under the agent's name: +// +// agent := DefineAgent(r, "pirate", +// InlinePrompt{ +// ai.WithModelName("googleai/gemini-flash-latest"), +// ai.WithSystem("You are a sarcastic pirate."), +// }, +// WithSessionStore(store), +// ) +// +// To give the template a default render input, include [ai.WithInputType] among +// the options. For an agent backed by a prompt already in the registry (e.g. +// one defined via [ai.DefinePrompt] or loaded from a .prompt file), use +// [DefinePromptAgent] instead, which takes no InlinePrompt. +type InlinePrompt []ai.PromptOption diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index 1ed5264129..9c6fb18719 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -24,10 +24,25 @@ import ( // --- AgentOption --- // AgentOption configures an agent at definition time. It is accepted by -// [DefineAgent] and [DefineCustomAgent] as a typed variadic, so a State -// mismatch fails at compile time. +// [DefineAgent], [DefineCustomAgent], and [DefinePromptAgent] as a typed +// variadic, so a State mismatch fails at compile time. +// +// Every AgentOption is also a [PromptAgentOption]: the shared options +// ([WithSessionStore], [WithStateTransform], [WithDescription]) configure all +// three constructors. The converse does not hold. A prompt-source option such +// as [WithNamedPrompt] is a [PromptAgentOption] but not an AgentOption, so +// passing it to [DefineAgent] or [DefineCustomAgent] is a compile-time error. type AgentOption[State any] interface { applyAgent(*agentOptions[State]) error + applyPromptAgent(*promptAgentOptions[State]) error +} + +// PromptAgentOption configures a prompt-backed agent defined via +// [DefinePromptAgent]. It is the wider set that additionally admits the +// prompt-source option [WithNamedPrompt]; every [AgentOption] satisfies it, so +// the shared agent options compose with [WithNamedPrompt] in a single variadic. +type PromptAgentOption[State any] interface { + applyPromptAgent(*promptAgentOptions[State]) error } // StateTransform rewrites session state on its way out to a client: it is @@ -70,6 +85,43 @@ func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { return nil } +// applyPromptAgent lets a shared agent option also configure a prompt-backed +// agent: it merges the base fields into the prompt-agent accumulator's embedded +// agentOptions. Implementing it is what makes every [AgentOption] usable +// wherever a [PromptAgentOption] is expected. +func (o *agentOptions[State]) applyPromptAgent(opts *promptAgentOptions[State]) error { + return o.applyAgent(&opts.agentOptions) +} + +// promptAgentOptions accumulates a [DefinePromptAgent] configuration: the +// shared agent options plus the optional prompt source. It doubles as the +// option value returned by [WithNamedPrompt], mirroring how ai/option.go reuses +// its accumulator structs as the options that fill them. +type promptAgentOptions[State any] struct { + agentOptions[State] + promptName string // referenced prompt; "" resolves to the agent's own name + promptInput any // render input for the referenced prompt + promptSet bool // WithNamedPrompt was used (guards against duplicates) +} + +// applyPromptAgent merges the base agent options and the prompt source. It +// shadows the promoted [agentOptions.applyPromptAgent] so a value carrying a +// prompt source contributes it in addition to the base fields. +func (o *promptAgentOptions[State]) applyPromptAgent(opts *promptAgentOptions[State]) error { + if err := o.agentOptions.applyAgent(&opts.agentOptions); err != nil { + return err + } + if o.promptSet { + if opts.promptSet { + return errors.New("cannot set prompt source more than once (WithNamedPrompt)") + } + opts.promptName = o.promptName + opts.promptInput = o.promptInput + opts.promptSet = true + } + return nil +} + // WithSessionStore sets the store for persisting snapshots. The store must // implement [SnapshotReader] and [SnapshotWriter] at minimum. Detach // support also requires [SnapshotSubscriber]; detach attempts on a store @@ -91,6 +143,27 @@ func WithDescription[State any](description string) AgentOption[State] { return &agentOptions[State]{description: description} } +// WithNamedPrompt points a [DefinePromptAgent] at the prompt registered under +// name, rendered with input on every turn (pass nil for the prompt's own +// default input). name need not match the agent's name, so a single registered +// prompt can back many agents with different inputs; pass "" to keep the +// default same-named lookup while still supplying a custom input. +// +// Without this option a prompt agent uses the prompt registered under its own +// name. This option lets a single registered prompt back many agents, each +// rendered with its own input, and composes with the other agent options in one +// variadic. +// +// input is rendered through the prompt once at definition time as a smoke +// check, so an input that fails the prompt's schema panics there rather than on +// the first invocation. +// +// This option applies only to [DefinePromptAgent]. Passing it to [DefineAgent] +// or [DefineCustomAgent] is a compile-time error. +func WithNamedPrompt[State any](name string, input any) PromptAgentOption[State] { + return &promptAgentOptions[State]{promptName: name, promptInput: input, promptSet: true} +} + // --- InvocationOption --- // InvocationOption configures an agent invocation (Connect, Run, or RunText). diff --git a/go/ai/exp/source.go b/go/ai/exp/source.go deleted file mode 100644 index e6afb53127..0000000000 --- a/go/ai/exp/source.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 - -package exp - -import "github.com/firebase/genkit/go/ai" - -// AgentSource selects the prompt backing a prompt-based agent. Pass an -// AgentSource as the third argument to [DefineAgent]. There are three -// forms: -// -// - [InlinePrompt] defines the prompt inline from a set of -// [ai.PromptOption] values; the prompt is registered with the -// registry under the agent's name. -// - [SameNamedPrompt] references the prompt already registered under the -// agent's own name (e.g. one defined via [ai.DefinePrompt] or loaded -// from a .prompt file). -// - [NamedPrompt] references any registered prompt by name and renders it -// with an input supplied from code, so a single prompt can back many -// agents with different inputs. -// -// For full control over the per-turn loop, define a custom agent via -// [DefineCustomAgent] instead. -type AgentSource interface { - isAgentSource() -} - -type inlineSource struct { - opts []ai.PromptOption -} - -func (inlineSource) isAgentSource() {} - -// InlinePrompt defines the agent's prompt inline from the given options. -// The prompt is registered with the registry under the agent's name. To -// give the template a default render input, include [ai.WithInputType] -// among the options. -func InlinePrompt(opts ...ai.PromptOption) AgentSource { - return inlineSource{opts: opts} -} - -type existingSource struct { - name string // "" => resolve by the agent's own name - input any -} - -func (existingSource) isAgentSource() {} - -// SameNamedPrompt references the prompt registered under the agent's own -// name (e.g. one defined via [ai.DefinePrompt] or loaded from a .prompt -// file). The prompt renders with its own default input each turn. It is -// shorthand for NamedPrompt(, nil). -func SameNamedPrompt() AgentSource { - return existingSource{} -} - -// NamedPrompt references the prompt registered under name, rendered with -// input on every turn (pass nil for the prompt's own default input). name -// need not match the agent's name, so a single prompt can back many agents -// with different inputs. -// -// input is rendered through the prompt once at definition time as a smoke -// check, so an input that fails the prompt's schema panics there rather -// than on the first invocation. -func NamedPrompt(name string, input any) AgentSource { - return existingSource{name: name, input: input} -} diff --git a/go/genkit/exp/routes_test.go b/go/genkit/exp/routes_test.go index c649b0fa11..b67c930126 100644 --- a/go/genkit/exp/routes_test.go +++ b/go/genkit/exp/routes_test.go @@ -65,10 +65,10 @@ func newRouteTestGenkit(t *testing.T) *genkit.Genkit { if err != nil { t.Fatal(err) } - genkit.DefineAgent(g, "serverChat", aix.InlinePrompt(ai.WithModelName("test/echo")), + genkit.DefineAgent(g, "serverChat", aix.InlinePrompt{ai.WithModelName("test/echo")}, aix.WithSessionStore(store), ) - genkit.DefineAgent[any](g, "clientChat", aix.InlinePrompt(ai.WithModelName("test/echo"))) + genkit.DefineAgent[any](g, "clientChat", aix.InlinePrompt{ai.WithModelName("test/echo")}) genkit.DefineFlow(g, "greet", func(ctx context.Context, name string) (string, error) { return "hi " + name, nil }) @@ -112,8 +112,8 @@ func TestAgentRoutes_PicksOneAgentAndMirrorsCapabilities(t *testing.T) { if err != nil { t.Fatal(err) } - server := genkit.DefineAgent(g, "srv", aix.InlinePrompt(ai.WithModelName("test/echo")), aix.WithSessionStore(store)) - client := genkit.DefineAgent[any](g, "cli", aix.InlinePrompt(ai.WithModelName("test/echo"))) + server := genkit.DefineAgent(g, "srv", aix.InlinePrompt{ai.WithModelName("test/echo")}, aix.WithSessionStore(store)) + client := genkit.DefineAgent[any](g, "cli", aix.InlinePrompt{ai.WithModelName("test/echo")}) if got, want := routeKeys(AgentRoutes(server)), []string{ "POST /agents/srv", diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 2ae9b642ca..ad9b59db0b 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -406,8 +406,8 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn core.StreamingFunc[In return core.NewStreamingFlow(name, fn) } -// DefineAgent defines a prompt-backed agent and registers it as an -// action on the registry. Returns an [aix.Agent]. +// DefineAgent defines an agent backed by an inline prompt and registers it as +// an action on the registry. Returns an [aix.Agent]. // // Experimental: This API is under active development and may change in any // minor version release. @@ -418,15 +418,9 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn core.StreamingFunc[In // handles session state, conversation history, and optional snapshot // persistence automatically. // -// source selects how the prompt is backed: -// -// - [aix.InlinePrompt] defines the prompt inline from a set of -// [ai.PromptOption] values; the prompt is registered under name. -// - [aix.SameNamedPrompt] references an existing prompt registered under -// name (e.g. one defined via [DefinePrompt] or loaded from a .prompt -// file). -// - [aix.NamedPrompt] references any registered prompt by name with an -// input supplied from code, so a single prompt can back many agents. +// The prompt is defined inline via [aix.InlinePrompt] and registered under the +// agent's name. To back the agent with a prompt already in the registry (e.g. +// one from a .prompt file), use [DefinePromptAgent] instead. // // The State type parameter is inferred from the typed agent options // (e.g. [aix.WithSessionStore], [aix.WithStateTransform]); pass an explicit @@ -445,33 +439,69 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn core.StreamingFunc[In // - [aix.WithSessionStore]: Enable snapshot persistence // - [aix.WithStateTransform]: Rewrite session state on its way out to the client // -// Example (inline prompt): +// Example: // // chatAgent := genkit.DefineAgent(g, "chat", -// aix.InlinePrompt( +// aix.InlinePrompt{ // ai.WithModelName("googleai/gemini-flash-latest"), // ai.WithSystem("You are a helpful assistant."), -// ), +// }, // aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), // ) +func DefineAgent[State any]( + g *Genkit, + name string, + prompt aix.InlinePrompt, + opts ...aix.AgentOption[State], +) *aix.Agent[State] { + return aix.DefineAgent(g.reg, name, prompt, opts...) +} + +// DefinePromptAgent defines a prompt-backed agent sourced from the registry by +// name and registers it as an action. Returns an [aix.Agent]. // -// Example (a shared .prompt file, parameterized per agent): +// Experimental: This API is under active development and may change in any +// minor version release. // -// type ChatInput struct { -// Personality string `json:"personality"` -// } +// By default the agent uses the prompt registered under its own name (e.g. one +// defined via [DefinePrompt] or loaded from a .prompt file), so no source +// option is required. Pass [aix.WithNamedPrompt] to reference a differently +// named prompt and supply its render input from code, so a single prompt can +// back many agents. // -// pirate := genkit.DefineAgent(g, "pirate", -// aix.NamedPrompt("chat", ChatInput{Personality: "a sarcastic pirate"}), +// It is the registry-backed counterpart of [DefineAgent]: where [DefineAgent] +// defines the prompt inline, DefinePromptAgent points at a prompt already in +// the registry. The prompt source is a typed option ([aix.WithNamedPrompt]) +// rather than a positional argument, so it composes with the other agent +// options in one variadic. For full control over the per-turn loop, use +// [DefineCustomAgent]. +// +// The State type parameter is inferred from the typed agent options; pass an +// explicit [State] only when no typed option provides it. +// +// # Options +// +// - [aix.WithNamedPrompt]: Source from a differently named prompt with a code-supplied input +// - [aix.WithSessionStore]: Enable snapshot persistence +// - [aix.WithStateTransform]: Rewrite session state on its way out to the client +// +// Example (same-named prompt loaded from ./prompts/chef.prompt): +// +// chef := genkit.DefinePromptAgent(g, "chef", // aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), // ) -func DefineAgent[State any]( +// +// Example (a shared prompt, parameterized per agent): +// +// pirate := genkit.DefinePromptAgent(g, "pirate", +// aix.WithNamedPrompt[any]("chat", ChatInput{Personality: "a sarcastic pirate"}), +// ) +func DefinePromptAgent[State any]( g *Genkit, name string, - source aix.AgentSource, - opts ...aix.AgentOption[State], + opts ...aix.PromptAgentOption[State], ) *aix.Agent[State] { - return aix.DefineAgent(g.reg, name, source, opts...) + return aix.DefinePromptAgent(g.reg, name, opts...) } // DefineCustomAgent defines an agent with full control over the conversation @@ -490,9 +520,8 @@ func DefineAgent[State any]( // via [Handler], with companion actions on [aix.Agent.GetSnapshotAction] // and [aix.Agent.AbortSnapshotAction]. // -// For agents backed by a prompt, use [DefineAgent] with [aix.InlinePrompt] -// (inline prompt), [aix.SameNamedPrompt], or [aix.NamedPrompt] (existing -// prompt) instead. +// For agents backed by a prompt, use [DefineAgent] (inline prompt) or +// [DefinePromptAgent] (a prompt already in the registry) instead. // // # Options // diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index e6fa1e288a..49816c6d6c 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -758,13 +758,13 @@ func TestHandlerAgent(t *testing.T) { }, nil }) - DefineAgent[any](g, "agentClient", aix.InlinePrompt(ai.WithModelName("test/echo"))) + DefineAgent[any](g, "agentClient", aix.InlinePrompt{ai.WithModelName("test/echo")}) store, err := localstore.NewFileSessionStore[any](t.TempDir()) if err != nil { t.Fatal(err) } - DefineAgent(g, "agentServer", aix.InlinePrompt(ai.WithModelName("test/echo")), + DefineAgent(g, "agentServer", aix.InlinePrompt{ai.WithModelName("test/echo")}, aix.WithSessionStore(store), ) @@ -969,7 +969,7 @@ func TestHandlerAgentRef(t *testing.T) { if err != nil { t.Fatal(err) } - agent := DefineAgent(g, "agentRef", aix.InlinePrompt(ai.WithModelName("test/echo")), + agent := DefineAgent(g, "agentRef", aix.InlinePrompt{ai.WithModelName("test/echo")}, aix.WithSessionStore(store), ) diff --git a/go/samples/basic-agents-server/main.go b/go/samples/basic-agents-server/main.go index 6099b3f4d6..7575fb5806 100644 --- a/go/samples/basic-agents-server/main.go +++ b/go/samples/basic-agents-server/main.go @@ -135,10 +135,10 @@ func main() { log.Fatalf("creating session store: %v", err) } genkit.DefineAgent(g, "chat", - aix.InlinePrompt( + aix.InlinePrompt{ ai.WithModel(model), ai.WithSystem("You are a helpful travel assistant. Keep responses to a couple of sentences."), - ), + }, aix.WithSessionStore(store), ) @@ -146,10 +146,10 @@ func main() { // the full conversation state and the client round-trips it on the next // request. This suits deployments where the server must stay stateless. genkit.DefineAgent[any](g, "statelessChat", - aix.InlinePrompt( + aix.InlinePrompt{ ai.WithModel(model), ai.WithSystem("You are a helpful travel assistant. Keep responses to a couple of sentences."), - ), + }, ) // genkitx.AllAgentRoutes lays out a default HTTP surface for every diff --git a/go/samples/basic-agents/main.go b/go/samples/basic-agents/main.go index 55b3d5d8cd..5aec9dc7ab 100644 --- a/go/samples/basic-agents/main.go +++ b/go/samples/basic-agents/main.go @@ -17,8 +17,9 @@ // // - "pirate" uses DefineAgent + aix.InlinePrompt. The prompt is declared // inline next to the agent. -// - "chef" uses DefineAgent + aix.SameNamedPrompt. The prompt is loaded -// from ./prompts/chef.prompt by the agent's name. +// - "chef" uses DefinePromptAgent. With no source option it defaults to the +// prompt registered under the agent's name, loaded from +// ./prompts/chef.prompt. // - "coder" uses DefineCustomAgent. The per-turn loop (model selection, // history management, streaming) is wired by hand. // @@ -106,34 +107,34 @@ func main() { func defineInlineAgent(g *genkit.Genkit) *aix.Agent[any] { const name = "pirate" return genkit.DefineAgent(g, name, - aix.InlinePrompt( + aix.InlinePrompt{ ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ ThinkingBudget: genai.Ptr[int32](0), }, })), ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), - ), + }, aix.WithSessionStore(mustStore(name)), aix.WithDescription[any]("Sarcastic pirate (inline-defined prompt)"), ) } -// definePromptAgent demonstrates DefineAgent with aix.SameNamedPrompt. The -// prompt is loaded from ./prompts/.prompt by genkit's prompt -// registry. Defining the prompt in a file lets you tune model, config, -// schema, template, and default input independently of the Go code, which -// is useful when prompt authors are not the people writing the agent wiring. +// definePromptAgent demonstrates DefinePromptAgent. The prompt is loaded from +// ./prompts/.prompt by genkit's prompt registry. Defining the +// prompt in a file lets you tune model, config, schema, template, and default +// input independently of the Go code, which is useful when prompt authors are +// not the people writing the agent wiring. // -// SameNamedPrompt references the prompt registered under the agent's own -// name and renders it with the prompt's own default input each turn (here, -// the personality set in chef.prompt's frontmatter). To supply an input -// from code, or to back several agents with one shared prompt, use -// aix.NamedPrompt(name, input) instead. +// With no source option, DefinePromptAgent defaults to the prompt registered +// under the agent's own name and renders it with the prompt's own default +// input each turn (here, the personality set in chef.prompt's frontmatter). +// The prompt source is a typed option, so it sits in the same variadic as the +// other agent options. To supply an input from code, or to back several agents +// with one shared prompt, add aix.WithNamedPrompt(name, input). func definePromptAgent(g *genkit.Genkit) *aix.Agent[any] { const name = "chef" - return genkit.DefineAgent(g, name, - aix.SameNamedPrompt(), + return genkit.DefinePromptAgent(g, name, aix.WithSessionStore(mustStore(name)), aix.WithDescription[any]("Michelin-starred chef (prompt loaded from ./prompts/chef.prompt)"), ) From ad5d0cded79d89d18f9548abaa1285bb7185f768 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 23 Jun 2026 14:32:41 -0700 Subject: [PATCH 132/141] docs(go): match agents README to the DefineAgent/DefinePromptAgent split The "Load the Prompt from a File" path now uses DefinePromptAgent (default same-named lookup) with WithNamedPrompt for shared prompts, and the inline example uses the InlinePrompt slice literal. --- go/README.md | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/go/README.md b/go/README.md index ea0b1f1425..271dac989f 100644 --- a/go/README.md +++ b/go/README.md @@ -96,10 +96,10 @@ import ( ) chatAgent := genkit.DefineAgent(g, "chat", - aix.InlinePrompt( + aix.InlinePrompt{ ai.WithModelName("googleai/gemini-flash-latest"), ai.WithSystem("You are a sarcastic pirate. Keep responses concise."), - ), + }, aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), ) @@ -143,7 +143,7 @@ fmt.Println(out.Message.Text()) ### Load the Prompt from a File -`aix.SameNamedPrompt` backs the agent with the prompt registered under the agent's name, including one loaded from a `.prompt` file. Prompt authors can tune the model, config, template, and default input without touching the Go wiring: +`genkit.DefinePromptAgent` backs the agent with a prompt from the registry instead of an inline one. By default it uses the prompt registered under the agent's own name, including one loaded from a `.prompt` file, so prompt authors can tune the model, config, template, and default input without touching the Go wiring: ```yaml # prompts/chat.prompt @@ -166,22 +166,21 @@ type ChatInput struct { // Register the schema so the .prompt file can reference it by name. genkit.DefineSchemaFor[ChatInput](g) -// Agent "chat" renders ./prompts/chat.prompt every turn. -chatAgent := genkit.DefineAgent(g, "chat", - aix.SameNamedPrompt(), +// Agent "chat" renders ./prompts/chat.prompt every turn (no source option needed). +chatAgent := genkit.DefinePromptAgent(g, "chat", aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), ) ``` -To back several agents with one shared prompt, reference it by name with `aix.NamedPrompt` and give each its own input. The prompt name need not match the agent name: +To back several agents with one shared prompt, point each at it with `aix.WithNamedPrompt` and give each its own input. The prompt name need not match the agent name: ```go for _, p := range []struct{ name, persona string }{ {"pirate", "a sarcastic pirate"}, {"chef", "a Michelin-starred chef"}, } { - genkit.DefineAgent(g, p.name, - aix.NamedPrompt("chat", ChatInput{Personality: p.persona}), + genkit.DefinePromptAgent(g, p.name, + aix.WithNamedPrompt[any]("chat", ChatInput{Personality: p.persona}), aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), ) } From 06e2b5af24c26dd442c5087d8db3b6a0552c8d34 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 23 Jun 2026 15:57:47 -0700 Subject: [PATCH 133/141] feat(go/ai/exp): record session state and snapshot ID on agent turn spans Each runTurn-N span now records the committed session state at turn end as its genkit:output, shaped as {state: }, and for server-managed agents carries the turn-end snapshot's ID under genkit:metadata:agent:snapshotId. The session ID stays on the root action span as before. The state is raw: StateTransform shapes only client-facing surfaces, not telemetry or persisted state, so the span output matches the snapshot its ID points to. With the turn span output now derived from session state, the per-turn chunk collection is gone: removed SessionRunner.collectTurnOutput, chunkRouter.collectTurnChunks and its turnChunks/turnMu fields, and the accumulation branch in applySideEffects (now artifact-only). --- go/ai/exp/agent.go | 113 ++++++++++++------------ go/ai/exp/agent_test.go | 188 ++++++++++++++++++++++++++-------------- 2 files changed, 181 insertions(+), 120 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 004a3fdcc9..8353afe950 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -85,9 +85,8 @@ type SessionRunner[State any] struct { // incremented by Run after each turn completes. turnIndex int - onStartTurn func() - onEndTurn func(ctx context.Context) - collectTurnOutput func() any + onStartTurn func() + onEndTurn func(ctx context.Context) // snapMu serializes the turn-end snapshot write (snapshotTurnEnd) // against the detach handler's suspend-and-capture (suspendSnapshots). @@ -160,6 +159,16 @@ type TurnResult struct { FinishReason AgentFinishReason } +// turnSpanOutput is the value recorded as a turn span's genkit:output. It +// wraps the committed session state captured at turn end under a "state" key, +// so the span output serializes as {"state": }. The state is +// raw: a configured [StateTransform] shapes only client-facing surfaces, not +// telemetry or persisted state, so this matches what a server-managed turn +// writes to its turn-end snapshot. +type turnSpanOutput[State any] struct { + State *SessionState[State] `json:"state"` +} + // Run loops over the input channel, calling fn for each turn. Each turn is // wrapped in a trace span for observability. Input messages are automatically // added to the session before fn is called. After fn returns successfully, a @@ -214,10 +223,9 @@ func (s *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont reason = tr.FinishReason } s.endTurn(ctx, reason, false) - if s.collectTurnOutput != nil { - return s.collectTurnOutput(), nil - } - return nil, nil + // The turn span's output is the committed session state at + // turn end, recorded as {state: ...} (see turnSpanOutput). + return turnSpanOutput[State]{State: s.State()}, nil }, ) if err != nil { @@ -363,8 +371,8 @@ func (s *SessionRunner[State]) snapshotTurnEnd(ctx context.Context, finishReason type Responder struct { in chan<- *AgentStreamChunk ctx context.Context - // effects applies the chunk's in-process side effects (session - // artifact add, turn-chunk accumulation) synchronously in send, in + // effects applies the chunk's in-process side effects (adding an + // artifact chunk's artifact to the session) synchronously in send, in // the sender's goroutine, so reads and snapshots that follow a Send // cannot miss the chunk. effects func(*AgentStreamChunk) @@ -862,12 +870,17 @@ type fnDoneResult[State any] struct { err error } -// sessionIDSpanAttrKey is the full span-attribute key under which an agent's -// root action span records its session ID. It is the "genkit:metadata:"-prefixed -// form of the "agent:sessionId" custom-metadata key the JS agent sets via -// setCustomMetadataAttributes; the prefix is inlined here because Go's tracing -// package exposes no setCustomMetadataAttributes helper. -const sessionIDSpanAttrKey = "genkit:metadata:agent:sessionId" +// sessionIDSpanAttrKey and snapshotIDSpanAttrKey are the full span-attribute +// keys under which an agent records its identifiers: the session ID on the +// root action span, and the turn-end snapshot ID on each server-managed turn +// span. They are the "genkit:metadata:"-prefixed forms of the +// "agent:sessionId" / "agent:snapshotId" custom-metadata keys the JS agent +// sets via setCustomMetadataAttributes; the prefix is inlined here because +// Go's tracing package exposes no setCustomMetadataAttributes helper. +const ( + sessionIDSpanAttrKey = "genkit:metadata:agent:sessionId" + snapshotIDSpanAttrKey = "genkit:metadata:agent:snapshotId" +) func newAgentRuntime[State any]( ctx context.Context, @@ -943,7 +956,6 @@ func newAgentRuntime[State any]( // make it the resume point a first-turn failure falls back to. rt.sess.lastSnapshotID = parent.SnapshotID } - rt.sess.collectTurnOutput = func() any { return rt.router.collectTurnChunks() } rt.sess.onEndTurn = rt.emitTurnEnd // Stream custom-state mutations as customPatch chunks. beginTurn is armed // per turn by the runner; the session's onCustomChange hook is wired in @@ -983,6 +995,15 @@ func (rt *agentRuntime[State]) emitTurnEnd(ctx context.Context) { if !rt.sess.lastTurnFailed { snapshotID = rt.sess.snapshotTurnEnd(ctx, reason) } + // Tag the turn span with the snapshot it persisted, so a server-managed + // turn's trace links to its snapshot. ctx is the turn span's context (this + // runs inside the runTurn-N span via onEndTurn). The ID is empty, and the + // attribute omitted, when client-managed, when the turn failed, or when a + // detach suspended snapshots. + if snapshotID != "" { + trace.SpanFromContext(ctx).SetAttributes( + attribute.String(snapshotIDSpanAttrKey, snapshotID)) + } rt.router.sendChunk(ctx, &AgentStreamChunk{TurnEnd: &TurnEnd{ SnapshotID: snapshotID, FinishReason: reason, @@ -1003,9 +1024,9 @@ func (rt *agentRuntime[State]) run( // Wire custom-state streaming now that the work context exists: every // UpdateCustom mutation during the invocation emits a customPatch chunk - // through the same responder fn uses (so the chunk is accumulated for the - // turn span and forwarded on the wire, dropping post-detach like any - // other chunk). The session mutation itself still applies regardless. + // through the same responder fn uses (so the chunk is forwarded on the + // wire, dropping post-detach like any other chunk). The session mutation + // itself still applies regardless. resp := rt.router.responder(workCtx) rt.patcher.bind(workCtx, resp.send) rt.session.onCustomChange = rt.patcher.onChange @@ -1588,13 +1609,13 @@ func resumeSessionFrom[State any](s *Session[State], snap *SessionSnapshot[State // --- chunkRouter --- // // chunkRouter owns the intermediate stream channel that all chunks flow -// through on their way to outCh. A chunk's in-process side effects -// (adding artifacts to the session, accumulating turn chunks for span -// output) are applied synchronously by Responder.send before the chunk -// enters the router, so every chunk gets them in its sender's goroutine -// regardless of whether detach has landed; the router owns only the wire -// forward to outCh, which is the one thing detach suppresses, since the -// bidi framework closes outCh shortly after bidiFn returns. The router +// through on their way to outCh. A chunk's in-process side effect (adding +// an artifact chunk's artifact to the session) is applied synchronously by +// Responder.send before the chunk enters the router, so every chunk gets it +// in its sender's goroutine regardless of whether detach has landed; the +// router owns only the wire forward to outCh, which is the one thing detach +// suppresses, since the bidi framework closes outCh shortly after bidiFn +// returns. The router // commits to not writing before we return so that close is safe, and // keeps draining its input so the user fn never blocks on a responder // send. @@ -1605,9 +1626,6 @@ type chunkRouter[State any] struct { out chan<- *AgentStreamChunk session *Session[State] - turnMu sync.Mutex - turnChunks []*AgentStreamChunk - done chan struct{} stopWriting chan struct{} writerStopped chan struct{} @@ -1646,23 +1664,18 @@ func (r *chunkRouter[State]) run() { } } -// applySideEffects records the chunk's effect on session state and turn -// span output. Invoked synchronously from Responder.send, in the -// sender's goroutine, so the effects are ordered before everything the -// sender does after Send: a state read, a turn-end snapshot, or -// [SessionRunner.Result] immediately after SendArtifact observes the -// artifact. The artifact is deep-copied on its way into the session so -// the sender's retained pointer (which also rides the wire chunk) cannot -// alias live session state. +// applySideEffects records the chunk's effect on session state: an artifact +// chunk adds its artifact to the session. Invoked synchronously from +// Responder.send, in the sender's goroutine, so the effect is ordered before +// everything the sender does after Send: a state read, a turn-end snapshot, or +// [SessionRunner.Result] immediately after SendArtifact observes the artifact. +// The artifact is deep-copied on its way into the session so the sender's +// retained pointer (which also rides the wire chunk) cannot alias live session +// state. func (r *chunkRouter[State]) applySideEffects(chunk *AgentStreamChunk) { if chunk.Artifact != nil { r.session.AddArtifacts(jsonClone(chunk.Artifact)) } - if chunk.TurnEnd == nil { - r.turnMu.Lock() - r.turnChunks = append(r.turnChunks, chunk) - r.turnMu.Unlock() - } } // forward delivers chunks to outCh until told to stop writing, the @@ -1703,9 +1716,8 @@ func (r *chunkRouter[State]) responder(ctx context.Context) Responder { // sendChunk delivers chunk to the router for producers other than the // user agent function (e.g. the runtime's emitTurnEnd). It skips the // in-process side effects (the only runtime-produced chunk is TurnEnd, -// which has none: no artifact, and TurnEnd is excluded from turn-chunk -// accumulation) and returns promptly if ctx is cancelled, dropping the -// chunk. +// which has none: no artifact) and returns promptly if ctx is cancelled, +// dropping the chunk. func (r *chunkRouter[State]) sendChunk(ctx context.Context, chunk *AgentStreamChunk) { select { case r.in <- chunk: @@ -1713,15 +1725,6 @@ func (r *chunkRouter[State]) sendChunk(ctx context.Context, chunk *AgentStreamCh } } -// collectTurnChunks returns and resets accumulated turn chunks. -func (r *chunkRouter[State]) collectTurnChunks() []*AgentStreamChunk { - r.turnMu.Lock() - defer r.turnMu.Unlock() - result := r.turnChunks - r.turnChunks = nil - return result -} - // stopAndWait tells the router to stop writing to out and blocks until it // has committed. After it returns, it is safe for the framework to close // out without risking a write-to-closed-channel panic. @@ -1754,7 +1757,7 @@ type customPatcher[State any] struct { session *Session[State] ctx context.Context // invocation work context, for the transform - send func(*AgentStreamChunk) // forwards the chunk (accumulate + wire) + send func(*AgentStreamChunk) // forwards the chunk (side effects + wire) mu sync.Mutex firstInTurn bool diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index 57f0984ec5..efe22f5e85 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -148,7 +148,7 @@ func TestAgent_BasicMultiTurn(t *testing.T) { // per-turn whole-document replace (the first patch of a turn re-bases the // client) followed by an incremental diff within the same turn, and that the // tracking carries across turns. The server-side patch emission is covered by -// TestAgent_TurnSpanOutput_WithSnapshots; this is its client-side complement. +// TestAgent_CustomPatchWholeDocumentReplace; this is its client-side complement. func TestAgentConnection_Custom_TracksStreamedPatches(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) @@ -1345,22 +1345,36 @@ func TestAgent_SetMessages(t *testing.T) { } } +// turnSpanState parses a turn span's genkit:output attribute, which the agent +// records as {"state": } (see turnSpanOutput), and returns the +// embedded state. +func turnSpanState(t *testing.T, span sdktrace.ReadOnlySpan) *SessionState[testState] { + t.Helper() + raw, ok := spanAttr(span, "genkit:output") + if !ok { + t.Fatalf("span %q: missing genkit:output attribute", span.Name()) + } + var out turnSpanOutput[testState] + if err := json.Unmarshal([]byte(raw), &out); err != nil { + t.Fatalf("span %q: genkit:output %q is not valid turn output: %v", span.Name(), raw, err) + } + if out.State == nil { + t.Fatalf("span %q: genkit:output %q carries no state", span.Name(), raw) + } + return out.State +} + +// TestAgent_TurnSpanOutput verifies each turn span's output is the committed +// session state at turn end, wrapped as {state: ...}. The agent is +// client-managed (no store), so no turn-end snapshot is written and the +// per-turn snapshot-ID attribute is absent. func TestAgent_TurnSpanOutput(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) - - var capturedOutputs []any + spans := collectSpans(t) af := DefineCustomAgent(reg, "turnOutputFlow", func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { - // Wrap collectTurnOutput to capture what each turn produces. - originalCollect := sess.collectTurnOutput - sess.collectTurnOutput = func() any { - output := originalCollect() - capturedOutputs = append(capturedOutputs, output) - return output - } - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { sess.UpdateCustom(func(s testState) testState { s.Counter++ @@ -1390,59 +1404,51 @@ func TestAgent_TurnSpanOutput(t *testing.T) { } conn.Close() - if _, err := conn.Output(); err != nil { + out, err := conn.Output() + if err != nil { t.Fatalf("Output failed: %v", err) } - // Should have captured output for each turn. - if len(capturedOutputs) != 2 { - t.Fatalf("expected 2 captured outputs, got %d", len(capturedOutputs)) - } - - for i, output := range capturedOutputs { - chunks, ok := output.([]*AgentStreamChunk) - if !ok { - t.Fatalf("turn %d: expected []*AgentStreamChunk, got %T", i, output) + // Each turn span's output carries the cumulative state through that turn: + // the running counter and one user message plus one reply per turn. + for turn := range 2 { + name := fmt.Sprintf("runTurn-%d", turn+1) + span := spans.byName(name) + if span == nil { + t.Fatalf("missing span %q", name) } - // 3 content chunks per turn: customPatch + model chunk + artifact. - if len(chunks) != 3 { - t.Errorf("turn %d: expected 3 chunks, got %d", i, len(chunks)) + state := turnSpanState(t, span) + if state.SessionID != out.SessionID { + t.Errorf("%s: state.sessionId = %q, want %q", name, state.SessionID, out.SessionID) } - for j, chunk := range chunks { - if chunk.TurnEnd != nil { - t.Errorf("turn %d, chunk %d: TurnEnd should not be in turn output", i, j) - } + if want := turn + 1; state.Custom.Counter != want { + t.Errorf("%s: state.custom.counter = %d, want %d", name, state.Custom.Counter, want) + } + if got, want := len(state.Messages), 2*(turn+1); got != want { + t.Errorf("%s: len(state.messages) = %d, want %d", name, got, want) + } + // An artifact added during the turn rides the committed state too. + if got := len(state.Artifacts); got != 1 { + t.Errorf("%s: len(state.artifacts) = %d, want 1", name, got) + } + // Client-managed: no snapshot, so no snapshot-ID attribute. + if v, ok := spanAttr(span, snapshotIDSpanAttrKey); ok { + t.Errorf("%s: unexpected %s = %q (client-managed writes no snapshot)", name, snapshotIDSpanAttrKey, v) } } } +// TestAgent_TurnSpanOutput_WithSnapshots verifies that for a server-managed +// agent the turn span both carries the persisted snapshot's ID under +// genkit:metadata:agent:snapshotId and records the committed state as its +// {state: ...} output, agreeing with the snapshot the ID points to. func TestAgent_TurnSpanOutput_WithSnapshots(t *testing.T) { ctx := context.Background() reg := newTestRegistry(t) store := newTestInMemStore[testState]() + spans := collectSpans(t) - var capturedOutputs []any - - af := DefineCustomAgent(reg, "turnOutputSnapshotFlow", - func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { - originalCollect := sess.collectTurnOutput - sess.collectTurnOutput = func() any { - output := originalCollect() - capturedOutputs = append(capturedOutputs, output) - return output - } - - return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { - sess.UpdateCustom(func(s testState) testState { - s.Counter++ - return s - }) - sess.AddMessages(ai.NewModelTextMessage("reply")) - return nil, nil - }) - }, - WithSessionStore(store), - ) + af := defineCounterAgent(reg, "turnOutputSnapshotFlow", WithSessionStore(store)) conn, err := af.Connect(ctx) if err != nil { @@ -1450,28 +1456,80 @@ func TestAgent_TurnSpanOutput_WithSnapshots(t *testing.T) { } sendText(t, conn, "hello") - var sawSnapshot bool - if nextTurnEnd(t, conn).SnapshotID != "" { - sawSnapshot = true + te := nextTurnEnd(t, conn) + if te.SnapshotID == "" { + t.Fatal("expected a snapshot ID on the turn-end chunk") } conn.Close() - conn.Output() + if _, err := conn.Output(); err != nil { + t.Fatalf("Output failed: %v", err) + } - if !sawSnapshot { - t.Fatal("expected a snapshot ID on the turn-end chunk") + span := spans.byName("runTurn-1") + if span == nil { + t.Fatal("missing span runTurn-1") + } + + // The turn span is tagged with the turn-end snapshot's ID. + got, ok := spanAttr(span, snapshotIDSpanAttrKey) + if !ok { + t.Fatalf("turn span: missing %s", snapshotIDSpanAttrKey) + } + if got != te.SnapshotID { + t.Errorf("turn span %s = %q, want %q (the turn-end snapshot)", snapshotIDSpanAttrKey, got, te.SnapshotID) } - // Turn output should contain only the customPatch chunk, not the TurnEnd signal. - if len(capturedOutputs) != 1 { - t.Fatalf("expected 1 captured output, got %d", len(capturedOutputs)) + // Its output is the committed state, matching the persisted snapshot. + state := turnSpanState(t, span) + if state.Custom.Counter != 1 { + t.Errorf("turn span state.custom.counter = %d, want 1", state.Custom.Counter) + } + snap, err := store.GetSnapshot(ctx, te.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot: %v", err) + } + if state.SessionID != snap.SessionID { + t.Errorf("turn span state.sessionId = %q, want %q (snapshot's session)", state.SessionID, snap.SessionID) } - chunks := capturedOutputs[0].([]*AgentStreamChunk) - if len(chunks) != 1 { - t.Errorf("expected 1 content chunk, got %d", len(chunks)) + if got, want := state.Custom.Counter, snap.State.Custom.Counter; got != want { + t.Errorf("turn span state.custom.counter = %d, want %d (snapshot's counter)", got, want) } - // The first (and only) patch of the turn is a whole-document replace. - if got := chunks[0].CustomPatch; len(got) != 1 || got[0].Op != JSONPatchOpReplace || got[0].Path != "" { - t.Errorf("expected a whole-document replace customPatch, got %+v", got) +} + +// TestAgent_CustomPatchWholeDocumentReplace verifies the server emits the first +// custom-state mutation of a turn as a whole-document replace: a single RFC 6902 +// replace at the root pointer, which re-bases a client that may not share the +// server's baseline. The client-side effect of this re-basing across turns is +// covered by TestAgentConnection_Custom_TracksStreamedPatches. +func TestAgent_CustomPatchWholeDocumentReplace(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := defineCounterAgent(reg, "wholeDocReplaceFlow") + + conn, err := af.Connect(ctx) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + sendText(t, conn, "hello") + + var patch JSONPatch + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive: %v", err) + } + if chunk.CustomPatch != nil { + patch = chunk.CustomPatch + break + } + if chunk.TurnEnd != nil { + break + } + } + conn.Close() + + if len(patch) != 1 || patch[0].Op != JSONPatchOpReplace || patch[0].Path != "" { + t.Errorf("first customPatch = %+v, want a single whole-document replace at root", patch) } } From 82c36216a43e20308bf8ba2f300a4f199c54e23c Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 23 Jun 2026 17:09:52 -0700 Subject: [PATCH 134/141] feat(go/ai/exp): add WithStreamTransform; transforms fail closed on error WithStreamTransform is the stream-side counterpart to WithStateTransform, rewriting each AgentStreamChunk on its way to the client. Both StateTransform and StreamTransform now return (value, error): a nil value omits the state or drops the chunk (wire-only), while a non-nil error (or a panic) fails the read or invocation closed with the transform's status preserved. Updates go/README.md and the genkit.DefineAgent option docs. --- go/README.md | 16 ++ go/ai/exp/agent.go | 221 ++++++++++++++++-- go/ai/exp/agent_test.go | 111 ++++++++- go/ai/exp/custompatch_test.go | 53 ++++- go/ai/exp/option.go | 90 +++++++- go/ai/exp/session.go | 16 +- go/ai/exp/streamtransform_test.go | 367 ++++++++++++++++++++++++++++++ go/genkit/genkit.go | 3 + 8 files changed, 835 insertions(+), 42 deletions(-) create mode 100644 go/ai/exp/streamtransform_test.go diff --git a/go/README.md b/go/README.md index 271dac989f..2f248d56e5 100644 --- a/go/README.md +++ b/go/README.md @@ -253,6 +253,22 @@ Resume from one specific point in history with `aix.WithSnapshotID`, or skip the [See full example](samples/basic-agents) +### Redact on the Way Out + +`WithStateTransform` rewrites session state as it leaves the server, on `GetSnapshot` reads, on a client-managed `out.State`, and on the streamed `CustomPatch` diffs, while the persisted snapshot and the state your agent function sees stay raw: + +```go +chatAgent := genkit.DefineAgent(g, "chat", + aix.InlinePrompt{ai.WithModelName("googleai/gemini-flash-latest")}, + aix.WithSessionStore(store), + aix.WithStateTransform[ChatState](func(ctx context.Context, s *aix.SessionState[ChatState]) (*aix.SessionState[ChatState], error) { + return redactPII(ctx, s) // ctx carries caller identity for RBAC-aware redaction + }), +) +``` + +`WithStreamTransform[State]` is the stream-side counterpart, rewriting each `AgentStreamChunk` (model tokens, artifacts, custom patches, turn-end) on its way to the client. Both transforms own a fresh deep copy: mutate it in place, return a new value, or return `nil` to omit that state (or drop that chunk) from the client's view. A non-nil error fails closed, so the read or invocation fails with the transform's status (e.g. `PERMISSION_DENIED`) instead of leaking unredacted data. + ### Background Agents `Detach` hands the rest of the work to the server and closes the connection promptly with a pending snapshot ID. The agent keeps processing in the background on a context decoupled from the client's, so a long task survives the caller walking away: diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index 8353afe950..93fcfc362b 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -668,8 +668,9 @@ func DefineAgent[State any]( // defines the prompt inline, DefinePromptAgent points at a prompt already in // the registry. The prompt source is a typed option ([WithNamedPrompt]) rather // than a positional argument, so it composes with the other agent options -// ([WithSessionStore], [WithStateTransform], [WithDescription]) in a single -// variadic. For full control over the per-turn loop, use [DefineCustomAgent]. +// ([WithSessionStore], [WithStateTransform], [WithStreamTransform], +// [WithDescription]) in a single variadic. For full control over the per-turn +// loop, use [DefineCustomAgent]. // // State is inferred from the typed agent options; pass an explicit [State] only // when no typed option provides it (e.g. only [WithNamedPrompt] and @@ -861,6 +862,47 @@ type agentRuntime[State any] struct { intake *detachIntake fnDone chan fnDoneResult[State] + // fatalErr latches the first fail-closed error from a streaming transform + // (the stream transform in the router, or the state transform behind a + // custom-state patch). Buffered to one and written non-blocking, so the + // producer never blocks and only the first error wins; the run loop drains + // it to resolve the invocation as a failed output. See failTransform. + fatalErr chan error +} + +// failTransform records a fail-closed error from a streaming transform without +// blocking the producer that hit it. The buffered, non-blocking send keeps the +// first error and discards the rest; the run loop observes it (directly via its +// select arm, or after the fact via handleFnDone) and resolves the invocation +// as a failed output. Safe to call from the router and the fn goroutines. +func (rt *agentRuntime[State]) failTransform(err error) { + select { + case rt.fatalErr <- err: + default: // a fatal error is already latched; first one wins + } +} + +// takeFatal returns the latched streaming-transform error, or nil if none. +// Non-blocking, so a terminal path can fold a fatal error that raced fn's +// completion into a failed output. +func (rt *agentRuntime[State]) takeFatal() error { + select { + case err := <-rt.fatalErr: + return err + default: + return nil + } +} + +// panicError logs a recovered panic with its stack and returns it as an +// INTERNAL error; what names the code that panicked (e.g. "agent fn"). Call it +// from a deferred recover, where the stack still reaches the panic site. It is +// the shared shape of the runtime's two recover sites: the agent fn and the +// stream transform, both of which contain a panic in user code rather than let +// it crash the process. +func panicError(ctx context.Context, what string, rec any) error { + logger.FromContext(ctx).Error(what+" panicked", "panic", rec, "stack", string(debug.Stack())) + return core.NewError(core.INTERNAL, "%s panicked: %v", what, rec) } // fnDoneResult carries the user fn's return values across the goroutine @@ -939,13 +981,16 @@ func newAgentRuntime[State any]( attribute.String(sessionIDSpanAttrKey, session.state.SessionID)) rt := &agentRuntime[State]{ - name: name, - cfg: cfg, - session: session, - router: startChunkRouter(ctx, session, outCh), - intake: startDetachIntake(inCh), - fnDone: make(chan fnDoneResult[State], 1), + name: name, + cfg: cfg, + session: session, + intake: startDetachIntake(inCh), + fnDone: make(chan fnDoneResult[State], 1), + fatalErr: make(chan error, 1), } + // Started after rt exists so the router can signal a fail-closed stream + // transform error back through rt.failTransform. + rt.router = startChunkRouter(ctx, session, outCh, cfg.streamTransform, rt.failTransform) rt.sess = &SessionRunner[State]{ Session: session, @@ -964,6 +1009,7 @@ func newAgentRuntime[State any]( transform: cfg.transform, session: session, firstInTurn: true, + fail: rt.failTransform, } rt.sess.onStartTurn = rt.patcher.beginTurn // The initial state (fresh, client-provided, or loaded from a snapshot) @@ -1065,8 +1111,7 @@ func (rt *agentRuntime[State]) run( func() { defer func() { if r := recover(); r != nil { - logger.FromContext(workCtx).Error("agent fn panicked", "panic", r, "stack", string(debug.Stack())) - fnErr = core.NewError(core.INTERNAL, "agent fn panicked: %v", r) + fnErr = panicError(workCtx, "agent fn", r) } }() result, fnErr = fn(workCtx, resp, rt.sess) @@ -1085,6 +1130,9 @@ func (rt *agentRuntime[State]) run( case res := <-rt.fnDone: return rt.handleFnDone(clientCtx, cancelWork, res) + case cause := <-rt.fatalErr: + return rt.handleTransformFailure(clientCtx, cancelWork, cause) + case <-clientCtx.Done(): res := rt.drainAndWait(cancelWork) if res.err != nil { @@ -1094,6 +1142,31 @@ func (rt *agentRuntime[State]) run( } } +// handleTransformFailure is the fail-closed terminal path for a streaming +// transform that returned an error (or panicked): the stream transform in the +// router, or the state transform behind a custom-state patch. It tears the +// invocation down like a fn error and resolves it as a failed output carrying +// the transform's cause, so no unshaped chunk reaches the client and the +// offending chunk's side effects never surface in a completed output. +// +// drainAndWait cancels the work context (stopping fn), switches the router to +// discard mode, and drains fn; the router has typically already stopped writing +// the moment shape returned the error, but a custom-patch failure trips this +// path while the router is still forwarding, so the stop here is what halts it. +func (rt *agentRuntime[State]) handleTransformFailure( + clientCtx context.Context, + cancelWork context.CancelFunc, + cause error, +) (*AgentOutput[State], error) { + rt.drainAndWait(cancelWork) + // A disconnect that raced the failure keeps error semantics: there is no + // client to hand a graceful failed output to (mirrors handleFnDone). + if clientCtx.Err() != nil { + return nil, cause + } + return rt.failedOutput(clientCtx, cause), nil +} + // checkDetachCapabilities reports whether the configured store is capable // of supporting detach. Detach requires a writable store (to persist the // pending snapshot, and to abort it and refresh its heartbeat via ordinary @@ -1153,10 +1226,30 @@ func (rt *agentRuntime[State]) handleFnDone( ) (*AgentOutput[State], error) { cancelWork() rt.intake.stopAndWait() - if res.err != nil { + // A custom-state patch whose transform failed closed latches during fn, so + // it is readable now; a failed turn likewise wants its in-flight chunks + // dropped. Either way stop router writes before close so it cannot wedge + // behind a slow or gone consumer. A stream-transform failure instead puts + // the router into discard mode the instant it occurs (forward never parks), + // so it needs no stop here and is picked up after close below. + fatal := rt.takeFatal() + if res.err != nil || fatal != nil { rt.router.stopAndWait() } rt.router.close() + if fatal == nil { + fatal = rt.takeFatal() + } + + // A streaming transform that failed closed resolves the invocation as + // failed regardless of what fn returned, so no completed output leaks the + // data it refused to shape. + if fatal != nil { + if ctx.Err() != nil { + return nil, fatal + } + return rt.failedOutput(ctx, fatal), nil + } if res.err != nil { // A disconnect-driven failure keeps its error semantics: the @@ -1188,7 +1281,17 @@ func (rt *agentRuntime[State]) handleFnDone( out.Artifacts = cloneArtifacts(res.result.Artifacts) } if rt.cfg.store == nil { - out.State = rt.outboundState(ctx, rt.session.State()) + // A final-output state transform that fails closed turns the otherwise + // successful invocation into a failed output, so unshaped state is + // never handed back. + state, err := rt.outboundState(ctx, rt.session.State()) + if err != nil { + if ctx.Err() != nil { + return nil, err + } + return rt.failedOutput(ctx, err), nil + } + out.State = state } return out, nil } @@ -1196,13 +1299,17 @@ func (rt *agentRuntime[State]) handleFnDone( // outboundState applies the configured state transform and re-stamps the // framework-owned SessionID, so the state handed to a client-managed // caller always carries the conversation's identity even if a transform -// rewrote or dropped it. Returns nil if state is nil. -func (rt *agentRuntime[State]) outboundState(ctx context.Context, state *SessionState[State]) *SessionState[State] { - out := applyTransform(ctx, rt.cfg.transform, state) +// rewrote or dropped it. Returns (nil, nil) if state is nil, and a non-nil +// error if the transform failed closed. +func (rt *agentRuntime[State]) outboundState(ctx context.Context, state *SessionState[State]) (*SessionState[State], error) { + out, err := applyTransform(ctx, rt.cfg.transform, state) + if err != nil { + return nil, err + } if out != nil { out.SessionID = rt.session.SessionID() } - return out + return out, nil } // failedOutput assembles the output for an invocation that ended in @@ -1222,7 +1329,17 @@ func (rt *agentRuntime[State]) failedOutput(ctx context.Context, cause error) *A Error: core.AsGenkitError(cause), } if rt.cfg.store == nil { - out.State = rt.outboundState(ctx, rt.sess.lastGoodState) + // This is already the failure path, so a transform that also fails + // closed while shaping the last-good state cannot escalate further: + // omit state (fail closed, no leak) rather than recurse. The original + // cause is what the caller needs and is preserved on Error above. + if state, err := rt.outboundState(ctx, rt.sess.lastGoodState); err != nil { + logger.FromContext(ctx).Error( + "agent state transform failed shaping failed-output state; omitting state", + "error", err) + } else { + out.State = state + } } else { out.SnapshotID = rt.sess.lastSnapshotID } @@ -1625,6 +1742,13 @@ type chunkRouter[State any] struct { in chan *AgentStreamChunk out chan<- *AgentStreamChunk session *Session[State] + // transform shapes each chunk on the wire; see [WithStreamTransform]. Nil + // forwards chunks verbatim. + transform StreamTransform + // fail reports a fail-closed transform error (or panic) to the runtime so + // the invocation resolves as a failed output. Nil only when no transform is + // configured, since that is the only thing that can fail here. + fail func(error) done chan struct{} stopWriting chan struct{} @@ -1635,12 +1759,16 @@ func startChunkRouter[State any]( ctx context.Context, session *Session[State], out chan<- *AgentStreamChunk, + transform StreamTransform, + fail func(error), ) *chunkRouter[State] { r := &chunkRouter[State]{ ctx: ctx, in: make(chan *AgentStreamChunk), out: out, session: session, + transform: transform, + fail: fail, done: make(chan struct{}), stopWriting: make(chan struct{}), writerStopped: make(chan struct{}), @@ -1688,6 +1816,22 @@ func (r *chunkRouter[State]) forward() bool { if !ok { return false } + shaped, err := r.shape(chunk) + if err != nil { + // The stream transform failed closed (returned an error or + // panicked). Report it so the invocation resolves as a failed + // output, and switch to discard mode so no further chunk + // reaches the wire: fail-closed means stop forwarding entirely. + r.fail(err) + return true + } + if shaped == nil { + // The stream transform dropped the chunk from the wire. Its + // side effects already applied at Send time, so there is + // nothing else to do; carry on draining the next chunk. + continue + } + chunk = shaped select { case r.out <- chunk: case <-r.stopWriting: @@ -1705,6 +1849,34 @@ func (r *chunkRouter[State]) forward() bool { } } +// shape applies the configured stream transform to chunk, returning the chunk +// to forward on the wire, nil to drop it, or a non-nil error to fail the +// invocation closed; with no transform it returns chunk unchanged. The +// transform receives a fresh deep copy it owns, so mutating it in place cannot +// disturb the chunk's already-applied side effects (an artifact recorded on the +// session) or any pointer the sender retained. r.ctx is the action context, +// which carries the caller's identity for RBAC-aware redaction; the transform +// only runs on chunks bound for a live client, since forward stops calling it +// once writes cease. +// +// The transform is user code running in the router's own goroutine, which +// nothing else recovers (unlike the agent fn and the state transform, whose +// goroutines are covered), so a panic here would crash the process rather than +// fail just the invocation. Contain it the way the fn path does (log with a +// stack) and surface it as a fail-closed error, the same outcome as an explicit +// error return: the invocation fails rather than leaking the unshaped chunk. +func (r *chunkRouter[State]) shape(chunk *AgentStreamChunk) (out *AgentStreamChunk, err error) { + if r.transform == nil { + return chunk, nil + } + defer func() { + if rec := recover(); rec != nil { + out, err = nil, panicError(r.ctx, "agent stream transform", rec) + } + }() + return r.transform(r.ctx, jsonClone(chunk)) +} + // responder returns a [Responder] that applies chunk side effects // synchronously and sends chunks into the router for the wire forward. // The returned Responder's Send methods drop the forward (returning @@ -1755,6 +1927,10 @@ func (r *chunkRouter[State]) close() { type customPatcher[State any] struct { transform StateTransform[State] session *Session[State] + // fail reports a fail-closed transform error to the runtime so the + // invocation resolves as a failed output rather than streaming a delta + // derived from state the transform refused to shape. + fail func(error) ctx context.Context // invocation work context, for the transform send func(*AgentStreamChunk) // forwards the chunk (side effects + wire) @@ -1802,8 +1978,17 @@ func (p *customPatcher[State]) onChange() { if p.transform == nil { next = p.session.customJSON() } else { + t, err := applyTransform(p.ctx, p.transform, p.session.State()) + if err != nil { + // The state transform failed closed while shaping the streamed + // custom delta. Withhold the patch and fail the invocation; the run + // loop tears it down as a failed output, the same fail-closed + // outcome as a stream-transform error in the router. + p.fail(err) + return + } var custom any - if t := applyTransform(p.ctx, p.transform, p.session.State()); t != nil { + if t != nil { custom = t.Custom } next = normalizeJSON(custom) diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go index efe22f5e85..6b435af0ee 100644 --- a/go/ai/exp/agent_test.go +++ b/go/ai/exp/agent_test.go @@ -3655,7 +3655,7 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { // Transform that scrubs a specific word from all messages. It also // (incorrectly) drops the framework-owned session ID, which the // action must re-stamp on the way out. - transform := func(_ context.Context, s *SessionState[testState]) *SessionState[testState] { + transform := func(_ context.Context, s *SessionState[testState]) (*SessionState[testState], error) { for _, msg := range s.Messages { for _, p := range msg.Content { if p.Text != "" { @@ -3664,7 +3664,7 @@ func TestAgent_GetSnapshotAction_ReturnsTransformedState(t *testing.T) { } } s.SessionID = "" - return s + return s, nil } af := DefineCustomAgent(reg, "transformedFlow", @@ -3760,7 +3760,7 @@ func TestAgent_GetSnapshot_FacadeTransformsRawStoreDoesNot(t *testing.T) { reg := newTestRegistry(t) store := newTestInMemStore[testState]() - transform := func(_ context.Context, s *SessionState[testState]) *SessionState[testState] { + transform := func(_ context.Context, s *SessionState[testState]) (*SessionState[testState], error) { for _, msg := range s.Messages { for _, p := range msg.Content { if p.Text != "" { @@ -3768,7 +3768,7 @@ func TestAgent_GetSnapshot_FacadeTransformsRawStoreDoesNot(t *testing.T) { } } } - return s + return s, nil } af := DefineCustomAgent(reg, "facadeTransform", @@ -4655,10 +4655,10 @@ func TestAgent_StateTransform_ClientManagedState(t *testing.T) { reg := newTestRegistry(t) // Client-managed state: transform should be applied to AgentOutput.State. - transform := func(_ context.Context, s *SessionState[testState]) *SessionState[testState] { + transform := func(_ context.Context, s *SessionState[testState]) (*SessionState[testState], error) { // Zero out the counter to demonstrate the transform is applied. s.Custom.Counter = -1 - return s + return s, nil } af := DefineCustomAgent(reg, "clientXformFlow", @@ -4686,6 +4686,105 @@ func TestAgent_StateTransform_ClientManagedState(t *testing.T) { } } +// TestAgent_StateTransform_ErrorFailsClientManagedOutputClosed verifies a state +// transform that returns an error fails the invocation closed: rather than +// handing back unshaped state, the otherwise-successful client-managed +// invocation resolves as a failed output, with the transform's status preserved +// and no state attached. +func TestAgent_StateTransform_ErrorFailsClientManagedOutputClosed(t *testing.T) { + reg := newTestRegistry(t) + + transform := func(_ context.Context, s *SessionState[testState]) (*SessionState[testState], error) { + return nil, core.NewError(core.PERMISSION_DENIED, "cannot shape state") + } + + af := DefineCustomAgent(reg, "clientXformErr", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("done")) + return nil, nil + }) + }, + WithStateTransform[testState](transform), + ) + + out, err := af.RunText(context.Background(), "go") + if err != nil { + t.Fatalf("expected a graceful failed output, got error: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("FinishReason = %q, want %q", out.FinishReason, AgentFinishReasonFailed) + } + if out.Error == nil || out.Error.Status != core.PERMISSION_DENIED { + t.Errorf("Error = %+v, want status %q from the transform", out.Error, core.PERMISSION_DENIED) + } + // failedOutput shapes the last-good state through the same transform, which + // errors again here; the runtime omits state rather than leaking it. + if out.State != nil { + t.Errorf("expected no state on the failed output (transform failed closed), got %+v", out.State) + } +} + +// TestAgent_StateTransform_ErrorFailsSnapshotReadClosed verifies a state +// transform that errors fails a snapshot read closed: both the typed +// Agent.GetSnapshot facade and the getSnapshot companion action surface the +// transform's error (status preserved) instead of returning unshaped state. +func TestAgent_StateTransform_ErrorFailsSnapshotReadClosed(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + transform := func(_ context.Context, s *SessionState[testState]) (*SessionState[testState], error) { + return nil, core.NewError(core.PERMISSION_DENIED, "cannot shape state") + } + + af := DefineCustomAgent(reg, "snapXformErr", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("done")) + return nil, nil + }) + }, + WithSessionStore(store), + WithStateTransform[testState](transform), + ) + + // A successful run persists a snapshot (the run itself does not read state + // back through the transform, so it is unaffected). + out, err := af.RunText(ctx, "go") + if err != nil { + t.Fatalf("RunText: %v", err) + } + if out.SnapshotID == "" { + t.Fatal("expected a persisted snapshot ID") + } + + // The typed facade fails closed with the transform's status. + if _, err := af.GetSnapshot(ctx, out.SnapshotID); err == nil { + t.Error("Agent.GetSnapshot: expected the transform error, got nil") + } else if core.AsGenkitError(err).Status != core.PERMISSION_DENIED { + t.Errorf("Agent.GetSnapshot error status = %q, want %q", core.AsGenkitError(err).Status, core.PERMISSION_DENIED) + } + + // The getSnapshot companion action fails the same way for non-Go clients. + action := core.ResolveActionFor[*GetSnapshotRequest, *SessionSnapshot[testState], struct{}]( + reg, api.ActionTypeAgentSnapshot, "snapXformErr") + if action == nil { + t.Fatal("getSnapshot action not registered") + } + if _, err := action.Run(ctx, &GetSnapshotRequest{SnapshotID: out.SnapshotID}, nil); err == nil { + t.Error("getSnapshot action: expected the transform error, got nil") + } else if core.AsGenkitError(err).Status != core.PERMISSION_DENIED { + t.Errorf("getSnapshot action error status = %q, want %q", core.AsGenkitError(err).Status, core.PERMISSION_DENIED) + } + + // The stored snapshot itself is untouched: the failure is read-time shaping, + // not corruption, so the row is still resumable. + if snap, err := store.GetSnapshot(ctx, out.SnapshotID); err != nil || snap == nil { + t.Errorf("stored GetSnapshot = (%v, %v), want the intact snapshot", snap, err) + } +} + func TestAgent_ResumeFromFinalizedDetachedSnapshot(t *testing.T) { // End-to-end: run a flow that the client detaches from, let it // finalize, then resume from its snapshot as if reconnecting later. diff --git a/go/ai/exp/custompatch_test.go b/go/ai/exp/custompatch_test.go index ced24c9f94..df1d923106 100644 --- a/go/ai/exp/custompatch_test.go +++ b/go/ai/exp/custompatch_test.go @@ -20,8 +20,10 @@ import ( "context" "sync" "testing" + "time" "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" ) // collectTurnPatches consumes one turn's chunks, returning the customPatch from @@ -173,9 +175,9 @@ func TestCustomPatch_HonorsStateTransform(t *testing.T) { }) }, // Redact Topics on the way out to the client. - WithStateTransform(func(ctx context.Context, st *SessionState[testState]) *SessionState[testState] { + WithStateTransform(func(ctx context.Context, st *SessionState[testState]) (*SessionState[testState], error) { st.Custom.Topics = nil - return st + return st, nil }), ) @@ -200,6 +202,53 @@ func TestCustomPatch_HonorsStateTransform(t *testing.T) { } } +// TestCustomPatch_StateTransformErrorFailsInvocationClosed verifies a state +// transform that errors while shaping a streamed custom delta fails the +// invocation closed: the patch is withheld (no delta reaches the wire) and the +// invocation resolves as a failed output carrying the transform's status. +func TestCustomPatch_StateTransformErrorFailsInvocationClosed(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "cp", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.UpdateCustom(func(s testState) testState { + s.Counter = 5 + return s + }) + return nil, nil + }) + }, + WithStateTransform(func(_ context.Context, st *SessionState[testState]) (*SessionState[testState], error) { + return nil, core.NewError(core.PERMISSION_DENIED, "cannot shape custom state") + }), + ) + + conn, err := af.Connect(ctx) + if err != nil { + t.Fatalf("Connect: %v", err) + } + + conn.SendText("go") + // The transform errors before any delta is emitted, so no customPatch + // reaches the wire; the stream ends as the invocation fails. + if patches := collectTurnPatches(t, conn); len(patches) != 0 { + t.Errorf("customPatches on the wire = %d, want 0 (withheld, fail-closed)", len(patches)) + } + + out, err := outputWithin(t, conn, 2*time.Second) + if err != nil { + t.Fatalf("expected a graceful failed output, got error: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("FinishReason = %q, want %q", out.FinishReason, AgentFinishReasonFailed) + } + if out.Error == nil || out.Error.Status != core.PERMISSION_DENIED { + t.Errorf("Error = %+v, want status %q from the transform", out.Error, core.PERMISSION_DENIED) + } +} + // TestCustomPatch_ConcurrentMutations exercises the patcher's locking when // custom state is mutated from several goroutines at once: the streamed patches // must converge on the final state, and there must be no data race (run with diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index 9c6fb18719..f4038e1777 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -28,10 +28,11 @@ import ( // variadic, so a State mismatch fails at compile time. // // Every AgentOption is also a [PromptAgentOption]: the shared options -// ([WithSessionStore], [WithStateTransform], [WithDescription]) configure all -// three constructors. The converse does not hold. A prompt-source option such -// as [WithNamedPrompt] is a [PromptAgentOption] but not an AgentOption, so -// passing it to [DefineAgent] or [DefineCustomAgent] is a compile-time error. +// ([WithSessionStore], [WithStateTransform], [WithStreamTransform], +// [WithDescription]) configure all three constructors. The converse does not +// hold. A prompt-source option such as [WithNamedPrompt] is a +// [PromptAgentOption] but not an AgentOption, so passing it to [DefineAgent] or +// [DefineCustomAgent] is a compile-time error. type AgentOption[State any] interface { applyAgent(*agentOptions[State]) error applyPromptAgent(*promptAgentOptions[State]) error @@ -51,16 +52,65 @@ type PromptAgentOption[State any] interface { // the store or passed to the agent function. Typical uses are PII redaction and // stripping secrets. // -// state is a fresh deep copy the transform owns: it may mutate in place, return -// a new pointer, or return nil to omit state from the response. ctx is the -// request or invocation context, carrying deadlines and values such as the +// state is a fresh deep copy the transform owns. It returns the state to expose +// and a nil error, with two special outcomes: +// +// - A nil state omits state from the response. This is an intentional, +// successful outcome: the stored snapshot and the agent's own view keep the +// data, only the outbound copy is dropped. +// - A non-nil error fails closed: the operation being shaped aborts rather +// than exposing anything. A snapshot read returns the error; the final +// [AgentOutput] of an invocation resolves as a failed output carrying it. +// Use it when redaction cannot be performed safely (e.g. an RBAC policy +// lookup failed), where omitting (nil) would be a silent under-redaction +// and returning raw state would leak. The error's status code (such as +// PERMISSION_DENIED) propagates to the caller. +// +// ctx is the request or invocation context, carrying deadlines and values such +// as the caller's identity for RBAC-aware redaction. +type StateTransform[State any] = func(ctx context.Context, state *SessionState[State]) (*SessionState[State], error) + +// StreamTransform rewrites an [AgentStreamChunk] on its way out to a client: it +// runs at the wire boundary on every chunk the runtime forwards (model chunks, +// artifacts, custom-state patches, and turn-end signals), after the chunk's +// in-process side effects have already been applied. So, like [StateTransform], +// it shapes only what the client sees, not session state, persisted snapshots, +// or anything passed to the agent function. It also leaves the final +// [AgentOutput] untouched; it is a stream-only hook. Typical uses are redacting +// streamed model tokens and stripping fields a client should not receive. +// +// chunk is a fresh deep copy the transform owns. It returns the chunk to +// forward and a nil error, with two special outcomes: +// +// - A nil chunk drops it from the stream. This is an intentional, successful +// outcome and is wire-only: the chunk's side effects (an artifact recorded +// on the session) and the final [AgentOutput] keep the underlying data. +// Dropping a chunk whose [AgentStreamChunk.TurnEnd] is set withholds the +// turn-end signal a client uses to pace the conversation, so reshape such a +// chunk rather than dropping it. +// - A non-nil error fails closed: the whole invocation fails, so no unshaped +// chunk reaches the client and the offending chunk's side effects do not +// surface in the final output either. Use it when a chunk cannot be shaped +// safely; a panic in the transform is treated the same way. +// +// ctx is the request context, carrying deadlines and values such as the // caller's identity for RBAC-aware redaction. -type StateTransform[State any] = func(ctx context.Context, state *SessionState[State]) *SessionState[State] +// +// To redact custom state, prefer [WithStateTransform]: the runtime applies it +// before diffing, so the custom-patch stream stays internally consistent. +// Rewriting an [AgentStreamChunk.CustomPatch] delta here can desync a client +// that reconstructs custom state from the patch sequence. +// +// Unlike [StateTransform] it is not parameterized by State: a chunk carries no +// state type, so [WithStreamTransform] cannot infer State from the transform +// and takes it as an explicit type argument. +type StreamTransform = func(ctx context.Context, chunk *AgentStreamChunk) (*AgentStreamChunk, error) type agentOptions[State any] struct { - store SessionStore[State] - transform StateTransform[State] - description string + store SessionStore[State] + transform StateTransform[State] + streamTransform StreamTransform + description string } func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { @@ -76,6 +126,12 @@ func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { } opts.transform = o.transform } + if o.streamTransform != nil { + if opts.streamTransform != nil { + return errors.New("cannot set stream transform more than once (WithStreamTransform)") + } + opts.streamTransform = o.streamTransform + } if o.description != "" { if opts.description != "" { return errors.New("cannot set description more than once (WithDescription)") @@ -137,6 +193,18 @@ func WithStateTransform[State any](transform StateTransform[State]) AgentOption[ return &agentOptions[State]{transform: transform} } +// WithStreamTransform registers a [StreamTransform] applied to every +// [AgentStreamChunk] on its way out to a client. Typical uses are redacting +// streamed model tokens and stripping fields a client should not receive; it is +// the stream-side counterpart to [WithStateTransform]. +// +// A chunk is not parameterized by the state type, so State cannot be inferred +// from the transform and must be supplied explicitly to match the agent's, e.g. +// WithStreamTransform[MyState](fn). +func WithStreamTransform[State any](transform StreamTransform) AgentOption[State] { + return &agentOptions[State]{streamTransform: transform} +} + // WithDescription sets a human-readable description of the agent, stored on its // action descriptor (read back via [Agent.Desc] and shown in the Dev UI). func WithDescription[State any](description string) AgentOption[State] { diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index 4f8760eccd..facabdea7b 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -31,10 +31,11 @@ import ( // --- Snapshot --- // applyTransform returns the result of applying t to state, or state -// unchanged if t is nil. A nil state is returned as-is. -func applyTransform[State any](ctx context.Context, t StateTransform[State], state *SessionState[State]) *SessionState[State] { +// unchanged if t is nil. A nil state is returned as-is. A non-nil error from +// the transform is propagated so callers can fail closed. +func applyTransform[State any](ctx context.Context, t StateTransform[State], state *SessionState[State]) (*SessionState[State], error) { if t == nil || state == nil { - return state + return state, nil } return t(ctx, state) } @@ -263,8 +264,13 @@ func readSnapshot[State any]( // Clone before transforming: the [StateTransform] contract promises a fresh // deep copy the transform may mutate in place, and the store's row may share // memory with its internal copy, which neither the transform nor the SessionID - // re-stamp below may write into. - resp.State = applyTransform(ctx, transform, jsonClone(snap.State)) + // re-stamp below may write into. A transform error fails the read closed, + // with the transform's own status (e.g. PERMISSION_DENIED) preserved. + transformed, err := applyTransform(ctx, transform, jsonClone(snap.State)) + if err != nil { + return nil, err + } + resp.State = transformed if resp.State != nil { // SessionID is framework identity, not user data: re-stamp it from the // row after the transform so outbound state always agrees with the diff --git a/go/ai/exp/streamtransform_test.go b/go/ai/exp/streamtransform_test.go new file mode 100644 index 0000000000..19ef43373e --- /dev/null +++ b/go/ai/exp/streamtransform_test.go @@ -0,0 +1,367 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" +) + +// TestStreamTransform_RedactsModelChunksOnWire verifies a stream transform that +// edits a chunk in place reaches the wire: the streamed model text is redacted +// for the client. +func TestStreamTransform_RedactsModelChunksOnWire(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "st", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + resp.SendModelChunk(&ai.ModelResponseChunk{ + Content: []*ai.Part{ai.NewTextPart("the secret is 42")}, + }) + return nil, nil + }) + }, + WithStreamTransform[testState](func(_ context.Context, c *AgentStreamChunk) (*AgentStreamChunk, error) { + if c.ModelChunk != nil { + for _, p := range c.ModelChunk.Content { + p.Text = strings.ReplaceAll(p.Text, "secret", "[REDACTED]") + } + } + return c, nil + }), + ) + + conn, err := af.Connect(ctx) + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer conn.Output() + + sendText(t, conn, "tell me") + var got string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive: %v", err) + } + if chunk.ModelChunk != nil { + got += chunk.ModelChunk.Text() + } + if chunk.TurnEnd != nil { + break + } + } + + if want := "the [REDACTED] is 42"; got != want { + t.Errorf("streamed model text = %q, want %q", got, want) + } +} + +// TestStreamTransform_NilDropsFromWireKeepsSideEffects verifies returning a nil +// chunk drops it from the wire while its already-applied side effects survive: +// an artifact dropped from the stream still lands in the session and the final +// output. nil is the wire-only, successful "omit" outcome, distinct from an +// error (fail-closed). +func TestStreamTransform_NilDropsFromWireKeepsSideEffects(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "st", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + resp.SendArtifact(&Artifact{ + Name: "secret.txt", + Parts: []*ai.Part{ai.NewTextPart("classified")}, + }) + return nil, nil + }) + if err != nil { + return nil, err + } + return &AgentResult{Artifacts: sess.Artifacts()}, nil + }, + // Drop every artifact chunk from the stream (intentional omit, not an error). + WithStreamTransform[testState](func(_ context.Context, c *AgentStreamChunk) (*AgentStreamChunk, error) { + if c.Artifact != nil { + return nil, nil + } + return c, nil + }), + ) + + conn, err := af.Connect(ctx) + if err != nil { + t.Fatalf("Connect: %v", err) + } + + sendText(t, conn, "go") + var wireArtifacts int + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive: %v", err) + } + if chunk.Artifact != nil { + wireArtifacts++ + } + if chunk.TurnEnd != nil { + break + } + } + if wireArtifacts != 0 { + t.Errorf("artifacts on the wire = %d, want 0 (dropped by transform)", wireArtifacts) + } + + conn.Close() + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + // The invocation still succeeds, and the side effect (artifact recorded on + // the session) is untouched: dropping it from the stream is a wire-only edit. + if out.FinishReason == AgentFinishReasonFailed { + t.Fatalf("expected a successful invocation, got failed: %+v", out.Error) + } + if len(out.Artifacts) != 1 || out.Artifacts[0].Name != "secret.txt" { + t.Errorf("output artifacts = %+v, want one named secret.txt", out.Artifacts) + } +} + +// TestStreamTransform_OwnsDeepCopy verifies the transform receives a fresh deep +// copy it owns: mutating a chunk in place changes only the wire copy, leaving +// the session/output artifact and the pointer the agent fn retained untouched. +func TestStreamTransform_OwnsDeepCopy(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + // The artifact the fn sends and keeps a pointer to. The transform renames + // the wire copy; this pointer must not see that rename. + var sent *Artifact + + af := DefineCustomAgent(reg, "st", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sent = &Artifact{Name: "real.go", Parts: []*ai.Part{ai.NewTextPart("x")}} + resp.SendArtifact(sent) + return nil, nil + }) + if err != nil { + return nil, err + } + return &AgentResult{Artifacts: sess.Artifacts()}, nil + }, + WithStreamTransform[testState](func(_ context.Context, c *AgentStreamChunk) (*AgentStreamChunk, error) { + if c.Artifact != nil { + c.Artifact.Name = "renamed.go" + } + return c, nil + }), + ) + + conn, err := af.Connect(ctx) + if err != nil { + t.Fatalf("Connect: %v", err) + } + + sendText(t, conn, "go") + var wireName string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive: %v", err) + } + if chunk.Artifact != nil { + wireName = chunk.Artifact.Name + } + if chunk.TurnEnd != nil { + break + } + } + + conn.Close() + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + + // The wire copy carries the transform's edit. + if wireName != "renamed.go" { + t.Errorf("wire artifact name = %q, want renamed.go", wireName) + } + // The final output (session side effect) keeps the original name. + if len(out.Artifacts) != 1 || out.Artifacts[0].Name != "real.go" { + t.Errorf("output artifacts = %+v, want one named real.go", out.Artifacts) + } + // Output() has drained the router, so the transform has run; the fn's + // retained pointer is still the original, proving the transform mutated a + // copy rather than the caller's artifact. + if sent.Name != "real.go" { + t.Errorf("agent fn's retained artifact name = %q, want real.go", sent.Name) + } +} + +// TestStreamTransform_ReshapesTurnEnd verifies the transform also sees control +// chunks: stripping the snapshot ID from a turn-end signal hides it from the +// client while the snapshot itself is still persisted and reported on the final +// output. +func TestStreamTransform_ReshapesTurnEnd(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := newTestInMemStore[testState]() + + af := DefineCustomAgent(reg, "st", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + sess.AddMessages(ai.NewModelTextMessage("done")) + return nil, nil + }) + }, + WithSessionStore(store), + // Hide the server-side snapshot ID from the streamed turn-end. + WithStreamTransform[testState](func(_ context.Context, c *AgentStreamChunk) (*AgentStreamChunk, error) { + if c.TurnEnd != nil { + c.TurnEnd.SnapshotID = "" + } + return c, nil + }), + ) + + conn, err := af.Connect(ctx) + if err != nil { + t.Fatalf("Connect: %v", err) + } + + sendText(t, conn, "go") + te := nextTurnEnd(t, conn) + if te.SnapshotID != "" { + t.Errorf("streamed TurnEnd.SnapshotID = %q, want empty (stripped on the wire)", te.SnapshotID) + } + + conn.Close() + out, err := conn.Output() + if err != nil { + t.Fatalf("Output: %v", err) + } + // The final output keeps the real snapshot ID (the transform is stream-only) + // and the snapshot is genuinely persisted. + if out.SnapshotID == "" { + t.Fatal("output SnapshotID is empty; expected the persisted snapshot ID") + } + if snap, err := store.GetSnapshot(ctx, out.SnapshotID); err != nil || snap == nil { + t.Errorf("GetSnapshot(%q) = (%v, %v), want the persisted snapshot", out.SnapshotID, snap, err) + } +} + +// TestStreamTransform_ErrorFailsInvocationClosed verifies a transform that +// returns an error fails the whole invocation closed: the offending chunk never +// reaches the wire, the invocation resolves as a failed output, and the +// transform's status code is preserved. +func TestStreamTransform_ErrorFailsInvocationClosed(t *testing.T) { + transform := func(_ context.Context, c *AgentStreamChunk) (*AgentStreamChunk, error) { + if c.ModelChunk != nil { + return nil, core.NewError(core.PERMISSION_DENIED, "cannot shape chunk") + } + return c, nil + } + assertStreamTransformFailsClosed(t, transform, core.PERMISSION_DENIED, "cannot shape chunk") +} + +// TestStreamTransform_PanicFailsInvocationClosed verifies a panicking transform +// is contained in the router goroutine (not a process crash) and treated as a +// fail-closed error: the invocation resolves as a failed output rather than +// leaking the chunk. +func TestStreamTransform_PanicFailsInvocationClosed(t *testing.T) { + transform := func(_ context.Context, c *AgentStreamChunk) (*AgentStreamChunk, error) { + if c.ModelChunk != nil { + panic("boom") + } + return c, nil + } + assertStreamTransformFailsClosed(t, transform, core.INTERNAL, "panicked") +} + +// assertStreamTransformFailsClosed runs an agent that streams one model chunk +// through transform and asserts the invocation fails closed: the chunk never +// reaches the wire and the output is a failure carrying wantStatus and a message +// containing wantMsg. +func assertStreamTransformFailsClosed(t *testing.T, transform StreamTransform, wantStatus core.StatusName, wantMsg string) { + t.Helper() + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "st", + func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentInput) (*TurnResult, error) { + resp.SendModelChunk(&ai.ModelResponseChunk{ + Content: []*ai.Part{ai.NewTextPart("leak")}, + }) + return nil, nil + }) + }, + WithStreamTransform[testState](transform), + ) + + conn, err := af.Connect(ctx) + if err != nil { + t.Fatalf("Connect: %v", err) + } + + sendText(t, conn, "go") + for chunk, err := range conn.Receive() { + if err != nil { + break // failure may surface as the stream ends; Output carries it + } + if chunk.ModelChunk != nil { + t.Error("model chunk reached the wire; expected fail-closed") + } + } + + out, err := outputWithin(t, conn, 2*time.Second) + if err != nil { + t.Fatalf("expected a graceful failed output, got error: %v", err) + } + if out.FinishReason != AgentFinishReasonFailed { + t.Errorf("FinishReason = %q, want %q", out.FinishReason, AgentFinishReasonFailed) + } + if out.Error == nil || out.Error.Status != wantStatus { + t.Errorf("Error = %+v, want status %q", out.Error, wantStatus) + } + if out.Error != nil && !strings.Contains(out.Error.Message, wantMsg) { + t.Errorf("Error.Message = %q, want it to contain %q", out.Error.Message, wantMsg) + } +} + +// TestStreamTransform_RejectsSecondOption verifies WithStreamTransform, like the +// other agent options, may be set only once. +func TestStreamTransform_RejectsSecondOption(t *testing.T) { + reg := newTestRegistry(t) + noop := func(_ context.Context, c *AgentStreamChunk) (*AgentStreamChunk, error) { return c, nil } + noopFn := func(ctx context.Context, resp Responder, sess *SessionRunner[testState]) (*AgentResult, error) { + return nil, nil + } + defer func() { + if recover() == nil { + t.Error("expected a panic when WithStreamTransform is set twice") + } + }() + DefineCustomAgent(reg, "twice", noopFn, + WithStreamTransform[testState](noop), WithStreamTransform[testState](noop)) +} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index ad9b59db0b..ad0f071393 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -438,6 +438,7 @@ func NewStreamingFlow[In, Out, Stream any](name string, fn core.StreamingFunc[In // // - [aix.WithSessionStore]: Enable snapshot persistence // - [aix.WithStateTransform]: Rewrite session state on its way out to the client +// - [aix.WithStreamTransform]: Rewrite stream chunks on their way out to the client // // Example: // @@ -484,6 +485,7 @@ func DefineAgent[State any]( // - [aix.WithNamedPrompt]: Source from a differently named prompt with a code-supplied input // - [aix.WithSessionStore]: Enable snapshot persistence // - [aix.WithStateTransform]: Rewrite session state on its way out to the client +// - [aix.WithStreamTransform]: Rewrite stream chunks on their way out to the client // // Example (same-named prompt loaded from ./prompts/chef.prompt): // @@ -527,6 +529,7 @@ func DefinePromptAgent[State any]( // // - [aix.WithSessionStore]: Enable snapshot persistence // - [aix.WithStateTransform]: Rewrite session state on its way out to the client +// - [aix.WithStreamTransform]: Rewrite stream chunks on their way out to the client // // The State type parameter is the shape of the conversation's custom state // ([aix.SessionState.Custom]); mutating it via [aix.Session.UpdateCustom] From 0abd6f5d3b17930ffef8c947f0a9c6b8c7e095fe Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 23 Jun 2026 18:45:45 -0700 Subject: [PATCH 135/141] feat(go/ai/exp): return snapshotId on AbortSnapshotResponse The abortSnapshot companion action's response carried only status, so a caller could not correlate the result with the snapshot it aborted. Add snapshotId (matching the abortSnapshot request) and populate it from the request. Authored in the shared zod schema (genkit-tools/common/src/types/agent.ts) and regenerated across genkit-schema.json, the Go bindings (go/ai/exp/gen.go), and the Python typings (_typing.py); the Go doc and noomitempty come from schemas.config. --- genkit-tools/common/src/types/agent.ts | 2 ++ genkit-tools/genkit-schema.json | 6 ++++++ go/ai/exp/gen.go | 2 ++ go/ai/exp/session.go | 2 +- go/core/schemas.config | 5 +++++ py/packages/genkit/src/genkit/_core/_typing.py | 1 + 6 files changed, 17 insertions(+), 1 deletion(-) diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts index 36def1d841..1b2c4100b3 100644 --- a/genkit-tools/common/src/types/agent.ts +++ b/genkit-tools/common/src/types/agent.ts @@ -445,6 +445,8 @@ export type AbortSnapshotRequest = z.infer; * Zod schema for the output of the `abortSnapshot` companion action. */ export const AbortSnapshotResponseSchema = z.object({ + /** Identifies the snapshot the abort attempt targeted. */ + snapshotId: z.string(), /** * Snapshot's status after the abort attempt. For a pending snapshot * this is `aborted`. For an already-terminal snapshot this is the diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 7a82bd85ad..7b114ba36f 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -16,10 +16,16 @@ "AbortSnapshotResponse": { "type": "object", "properties": { + "snapshotId": { + "type": "string" + }, "status": { "$ref": "#/$defs/SnapshotStatus" } }, + "required": [ + "snapshotId" + ], "additionalProperties": false }, "AgentFinishReason": { diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go index 50d991fe41..fbe5555c3f 100644 --- a/go/ai/exp/gen.go +++ b/go/ai/exp/gen.go @@ -32,6 +32,8 @@ type AbortSnapshotRequest struct { // AbortSnapshotResponse is the output of the abortSnapshot companion action. type AbortSnapshotResponse struct { + // SnapshotID identifies the snapshot the abort attempt targeted. + SnapshotID string `json:"snapshotId"` // Status is the snapshot's status after the abort attempt. For a // pending snapshot this is [SnapshotStatusAborted]. For an // already-terminal snapshot this is the existing terminal status (the diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go index facabdea7b..cd01daa130 100644 --- a/go/ai/exp/session.go +++ b/go/ai/exp/session.go @@ -316,7 +316,7 @@ func newSnapshotActions[State any]( if status == "" { return nil, core.NewError(core.NOT_FOUND, "abortSnapshot: snapshot %q not found", req.SnapshotID) } - return &AbortSnapshotResponse{Status: status}, nil + return &AbortSnapshotResponse{SnapshotID: req.SnapshotID, Status: status}, nil }) return getSnapshotAction, abortSnapshotAction } diff --git a/go/core/schemas.config b/go/core/schemas.config index fcfd3fc19c..e839788315 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1828,6 +1828,11 @@ AbortSnapshotResponse doc AbortSnapshotResponse is the output of the abortSnapshot companion action. . +AbortSnapshotResponse.snapshotId noomitempty +AbortSnapshotResponse.snapshotId doc +SnapshotID identifies the snapshot the abort attempt targeted. +. + AbortSnapshotResponse.status doc Status is the snapshot's status after the abort attempt. For a pending snapshot this is [SnapshotStatusAborted]. For an diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index fa47e13c7a..d1bd4b0ee4 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -132,6 +132,7 @@ class AbortSnapshotResponse(GenkitModel): """Model for abortsnapshotresponse data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + snapshot_id: str = Field(...) status: SnapshotStatus | None = None From 4412982b241d24a48d6fddcb13b59101479370be Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 25 Jun 2026 07:40:23 -0700 Subject: [PATCH 136/141] fix(go/plugins/middleware/exp): address PR review feedback - nil-guard typed-nil *AgentMetadata in isClientManaged - document that unknown/absent agent metadata is treated as not client-managed (intentionally stricter than the JS middleware) - align doc-example import alias with the sample (middlewarex) - skip nil artifacts when building the artifacts listing - count artifact size in runes, not bytes, to match the "chars" label --- go/plugins/middleware/exp/agents.go | 15 +++++++++++---- go/plugins/middleware/exp/artifacts.go | 8 ++++++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/go/plugins/middleware/exp/agents.go b/go/plugins/middleware/exp/agents.go index 350b9bdce0..c30b28747f 100644 --- a/go/plugins/middleware/exp/agents.go +++ b/go/plugins/middleware/exp/agents.go @@ -107,16 +107,16 @@ func resolveAgent(g *genkit.Genkit, ref aix.AgentRef) (api.BidiAction, error) { // ai.WithModelName("googleai/gemini-flash-latest"), // ai.WithSystem("You are a helpful project assistant."), // ai.WithUse( -// &middleware.Agents{ +// &middlewarex.Agents{ // Agents: []aix.AgentRef{ // {Name: "researcher"}, // by name // coderAgent.Ref(), // by instance (carries its description) // }, // MaxDelegations: 5, // HistoryLength: 4, -// ArtifactStrategy: middleware.ArtifactStrategySession, +// ArtifactStrategy: middlewarex.ArtifactStrategySession, // }, -// &middleware.Artifacts{}, +// &middlewarex.Artifacts{}, // ), // }, // ) @@ -326,6 +326,13 @@ func runSubAgent(ctx context.Context, agent api.BidiAction, task string, history // isClientManaged reports whether the agent owns its state on the client (no // session store), which is the only case that accepts seeded init state. +// +// Unknown or absent agent metadata is treated as not client-managed. That is +// the safe default: it avoids seeding init state into an agent that might +// reject it. This is intentionally stricter than the JS middleware, which +// forwards history unless state management is explicitly "server"; for +// genkit-defined agents the metadata is always set, so the two agree in +// practice. func isClientManaged(agent api.BidiAction) bool { meta := agent.Desc().Metadata if meta == nil { @@ -335,7 +342,7 @@ func isClientManaged(agent api.BidiAction) bool { case aix.AgentMetadata: return m.StateManagement == aix.AgentStateManagementClient case *aix.AgentMetadata: - return m.StateManagement == aix.AgentStateManagementClient + return m != nil && m.StateManagement == aix.AgentStateManagementClient case map[string]any: s, _ := m["stateManagement"].(string) return aix.AgentStateManagement(s) == aix.AgentStateManagementClient diff --git a/go/plugins/middleware/exp/artifacts.go b/go/plugins/middleware/exp/artifacts.go index 642f3cbc3a..33678e492b 100644 --- a/go/plugins/middleware/exp/artifacts.go +++ b/go/plugins/middleware/exp/artifacts.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "strings" + "unicode/utf8" "github.com/firebase/genkit/go/ai" aix "github.com/firebase/genkit/go/ai/exp" @@ -61,7 +62,7 @@ const artifactsMarker = "artifacts-listing" // aix.InlinePrompt{ // ai.WithModelName("googleai/gemini-flash-latest"), // ai.WithSystem("You are a code generator. Use write_artifact to create files."), -// ai.WithUse(&middleware.Artifacts{}), +// ai.WithUse(&middlewarex.Artifacts{}), // }, // ) type Artifacts struct { @@ -181,13 +182,16 @@ func buildArtifactsListing(arts []*aix.Artifact) string { b.WriteString("The following artifacts are available in the session. ") b.WriteString("Use the read_artifact tool to view their content.\n") for _, a := range arts { + if a == nil { + continue + } name := a.Name if name == "" { name = "(unnamed)" } fmt.Fprintf(&b, " - %s", name) if text := artifactText(a); len(text) > 0 { - fmt.Fprintf(&b, " (%d chars)", len(text)) + fmt.Fprintf(&b, " (%d chars)", utf8.RuneCountInString(text)) } if src := artifactSource(a); src != "" { fmt.Fprintf(&b, " [from: %s]", src) From ec6cf16590d60b602e90bfb8395aefcc943c4d28 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 25 Jun 2026 09:05:30 -0700 Subject: [PATCH 137/141] chore(go): move Go sample ignores out of the root .gitignore The root .gitignore carried Go-sample-specific entries that belong with the module (a go/.gitignore already ignores the basic-agents binary): - Drop go/**/.genkit: the repo-root bare `.genkit` rule already ignores .genkit dirs at any depth, so it was redundant. - Drop the stale /go/custom-agent and /go/x-agent-interrupts binary ignores; those samples no longer exist. - In go/.gitignore, generalize the binary comment and add /basic-agents-server alongside /basic-agents. Snapshot artifacts stay covered by the root bare `.genkit` rule. --- .gitignore | 4 ---- go/.gitignore | 5 +++-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index de95d32e5a..b68671dea8 100644 --- a/.gitignore +++ b/.gitignore @@ -31,10 +31,6 @@ js/testapps/firebase-functions-sample1/public/config.js .genkit js/**/.genkit samples/**/.genkit -go/**/.genkit -# Compiled Go sample binaries (e.g. `go build ./samples/` writes go/) -/go/custom-agent -/go/x-agent-interrupts ui-debug.log firebase-debug.log firestore-debug.log diff --git a/go/.gitignore b/go/.gitignore index f32f2e12ef..11c35d174d 100644 --- a/go/.gitignore +++ b/go/.gitignore @@ -1,3 +1,4 @@ -# Compiled sample binary produced by `go build ./samples/basic-agents` -# from this directory. It is a build artifact, not source. +# Compiled sample binaries produced by `go build ./samples/` from this +# directory (e.g. ./samples/basic-agents). They are build artifacts, not source. /basic-agents +/basic-agents-server From d8b3a715cf2302d639e15d34b4359b44e54115e9 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 25 Jun 2026 09:38:21 -0700 Subject: [PATCH 138/141] fix(go/ai/exp): expose the typed output schema on exp tools Tools created via NewTool/DefineTool (and the interruptible variants) wrap ai.NewMultipartTool, whose function returns *MultipartToolResponse. The inner ToolDef therefore advertised that envelope ({content, output, metadata}) as the output schema instead of the real Out type, so the schema exposed to the model and Dev UI was wrong. This made the experimental constructors strictly less capable than ai.NewTool, which infers the output schema from Out. Override Tool.Definition to set OutputSchema from the Out type parameter, matching what ai.NewTool exposes. Genkit infers schemas with DoNotReference, so the result is fully inlined and needs no registry resolution, making the override equivalent whether or not the tool is registered. InterruptibleTool embeds Tool, so it inherits the fix. Add TestTool_OutputSchemaMatchesClassic, pinning the exp output schema to ai.NewTool's for both the simple and interruptible constructors (nested struct included to exercise schema inlining). --- go/ai/exp/tools.go | 37 +++++++++++++++++++++++++---- go/ai/exp/tools_test.go | 52 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 4 deletions(-) diff --git a/go/ai/exp/tools.go b/go/ai/exp/tools.go index 95d579bd27..8f17d3021f 100644 --- a/go/ai/exp/tools.go +++ b/go/ai/exp/tools.go @@ -20,9 +20,11 @@ import ( "context" "errors" "fmt" + "reflect" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/ai/exp/tool" + "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/internal/base" ) @@ -45,15 +47,42 @@ type Tool[In, Out any] struct { inner *ai.ToolDef[In, *ai.MultipartToolResponse] // DEPRECATED(breaking): remove wrapper; Tool owns the action directly. } -// DEPRECATED(breaking): All methods below exist only to implement ai.Tool by -// delegating to the wrapped ai.ToolDef. With breaking changes, Tool would own -// the action directly and implement these natively without delegation. +// DEPRECATED(breaking): The methods below exist to implement ai.Tool on top of +// the wrapped ai.ToolDef. Most are pure delegation; Definition additionally +// restores the real output schema (see its comment). With breaking changes, Tool +// would own the action directly and implement these natively, inferring the +// output schema from Out without the override. // Name returns the name of the tool. func (t *Tool[In, Out]) Name() string { return t.inner.Name() } // Definition returns the [ai.ToolDefinition] for this tool. -func (t *Tool[In, Out]) Definition() *ai.ToolDefinition { return t.inner.Definition() } +// +// The inner tool is built on [ai.NewMultipartTool], whose function returns +// *[ai.MultipartToolResponse], so the inner definition would advertise that +// envelope as the output schema. We override OutputSchema with the schema +// inferred from the Out type parameter, making the definition equivalent to what +// [ai.NewTool] exposes (the real output type) rather than leaking the multipart +// envelope to the model and Dev UI. Genkit infers schemas with DoNotReference, +// so the result is fully inlined and needs no registry resolution. +func (t *Tool[In, Out]) Definition() *ai.ToolDefinition { + def := t.inner.Definition() + if schema := inferOutputSchema[Out](); schema != nil { + def.OutputSchema = schema + } + return def +} + +// inferOutputSchema returns the inlined JSON schema for the Out type parameter, +// or nil when Out carries no schema (e.g. any), mirroring how [ai.NewTool] +// derives its output schema from the output type. +func inferOutputSchema[Out any]() map[string]any { + var zero Out + if reflect.TypeOf(zero) == nil { + return nil + } + return core.InferSchemaMap(zero) +} // RunRaw runs the tool with raw input. func (t *Tool[In, Out]) RunRaw(ctx context.Context, input any) (any, error) { diff --git a/go/ai/exp/tools_test.go b/go/ai/exp/tools_test.go index 1214d9ce6a..f33b5e6447 100644 --- a/go/ai/exp/tools_test.go +++ b/go/ai/exp/tools_test.go @@ -18,6 +18,7 @@ package exp import ( "context" + "reflect" "strings" "sync" "testing" @@ -118,6 +119,57 @@ func TestTool_AttachParts(t *testing.T) { } } +type reportItem struct { + Name string `json:"name"` +} + +type reportOut struct { + Title string `json:"title"` + Items []reportItem `json:"items"` +} + +// TestTool_OutputSchemaMatchesClassic guards against the multipart envelope +// leaking into the tool definition. aix.NewTool wraps ai.NewMultipartTool, whose +// function returns *MultipartToolResponse; without Definition restoring the real +// output schema, the model and Dev UI would see the envelope ({content, output, +// metadata}) instead of the actual Out type. The exp tool must advertise the +// same output schema ai.NewTool would, including for the interruptible variant +// (which embeds Tool) and for a nested struct that exercises schema inlining. +func TestTool_OutputSchemaMatchesClassic(t *testing.T) { + classic := ai.NewTool("classic", "d", + func(tc *ai.ToolContext, _ weatherIn) (reportOut, error) { return reportOut{}, nil }) + want := classic.Definition().OutputSchema + if want == nil { + t.Fatal("ai.NewTool unexpectedly produced a nil output schema") + } + + simple := NewTool("exp-simple", "d", + func(ctx context.Context, _ weatherIn) (reportOut, error) { return reportOut{}, nil }) + interruptible := NewInterruptibleTool("exp-interruptible", "d", + func(ctx context.Context, _ weatherIn, _ *confirmation) (reportOut, error) { return reportOut{}, nil }) + + for _, tc := range []struct { + name string + got any + }{ + {"NewTool", simple.Definition().OutputSchema}, + {"NewInterruptibleTool", interruptible.Definition().OutputSchema}, + } { + if !reflect.DeepEqual(tc.got, want) { + t.Errorf("%s output schema = %#v\nwant %#v (matching ai.NewTool)", tc.name, tc.got, want) + } + // Explicit guard on intent: the real Out fields are present and the + // MultipartToolResponse envelope fields are not. + props, _ := tc.got.(map[string]any)["properties"].(map[string]any) + if _, ok := props["title"]; !ok { + t.Errorf("%s output schema missing the real %q field: %#v", tc.name, "title", tc.got) + } + if _, ok := props["content"]; ok { + t.Errorf("%s output schema leaked the multipart envelope (has %q): %#v", tc.name, "content", tc.got) + } + } +} + // TestTool_SendPartialNoOpWithoutStreaming confirms SendPartial is a safe no-op // when no streaming callback is wired (here, a direct RunRaw). func TestTool_SendPartialNoOpWithoutStreaming(t *testing.T) { From cd34ade9166e41944070f3401470e4f610c4c5b4 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 25 Jun 2026 09:38:28 -0700 Subject: [PATCH 139/141] refactor(go/plugins/middleware/exp): build tools with the ai/exp tool API The Agents and Artifacts middleware defined their tools with ai.NewTool, taking an *ai.ToolContext. Switch them to aix.NewTool, the experimental constructor in go/ai/exp, which takes a plain context.Context. All three tools (delegation, read_artifact, write_artifact) only ever used the ToolContext as a context for store and agent resolution, so the simpler signature is a clean fit and they need none of the legacy ToolContext fields. Behavior is unchanged: the tool results the model sees and the session artifact operations are identical; only the construction API differs. --- go/plugins/middleware/exp/agents.go | 17 ++++++++++------- go/plugins/middleware/exp/artifacts.go | 12 ++++++------ 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/go/plugins/middleware/exp/agents.go b/go/plugins/middleware/exp/agents.go index b749eecfd9..0ddd4acf5f 100644 --- a/go/plugins/middleware/exp/agents.go +++ b/go/plugins/middleware/exp/agents.go @@ -180,7 +180,7 @@ func (a *Agents) New(ctx context.Context) (*ai.Hooks, error) { if desc == "" { desc = fmt.Sprintf("Delegates a task to the %q sub-agent.", ref.Name) } - tools = append(tools, ai.NewTool(makeToolName(prefix, ref.Name), desc, a.delegate(ref, st))) + tools = append(tools, aix.NewTool(makeToolName(prefix, ref.Name), desc, a.delegate(ref, st))) } wrapGenerate := func(ctx context.Context, params *ai.GenerateParams, next ai.GenerateNext) (*ai.ModelResponse, error) { @@ -222,9 +222,12 @@ type delegatedArtifact struct { Content string `json:"content,omitempty"` } -// delegate builds the delegation tool function for one sub-agent. -func (a *Agents) delegate(ref aix.AgentRef, st *agentsState) func(*ai.ToolContext, delegateInput) (delegationResult, error) { - return func(tc *ai.ToolContext, in delegateInput) (delegationResult, error) { +// delegate builds the delegation tool function for one sub-agent. The function +// uses the experimental [aix.NewTool] signature: a plain [context.Context] +// rather than an [ai.ToolContext], since delegation needs only the context for +// agent resolution, sub-agent execution, and artifact merging. +func (a *Agents) delegate(ref aix.AgentRef, st *agentsState) func(context.Context, delegateInput) (delegationResult, error) { + return func(ctx context.Context, in delegateInput) (delegationResult, error) { // Guard rail: enforce the delegation cap and reserve this delegation's // number, atomically, before doing any work. st.mu.Lock() @@ -239,7 +242,7 @@ func (a *Agents) delegate(ref aix.AgentRef, st *agentsState) func(*ai.ToolContex history := recentTextHistory(st.conversation, a.HistoryLength) st.mu.Unlock() - agent, err := resolveAgent(genkit.FromContext(tc), ref) + agent, err := resolveAgent(genkit.FromContext(ctx), ref) if err != nil { return delegationResult{Response: "Error: " + err.Error()}, nil } @@ -250,7 +253,7 @@ func (a *Agents) delegate(ref aix.AgentRef, st *agentsState) func(*ai.ToolContex history = nil } - out, err := runSubAgent(tc, agent, in.Task, history) + out, err := runSubAgent(ctx, agent, in.Task, history) if err != nil { // The agent runtime resolves failures and interrupts gracefully (see // below), so this only fires for exceptions outside that handling @@ -284,7 +287,7 @@ func (a *Agents) delegate(ref aix.AgentRef, st *agentsState) func(*ai.ToolContex invocationID := fmt.Sprintf("%s_%d", ref.Name, invocationNum) // Merge into the parent session under both strategies (no-op if there // is no active session, e.g. a plain genkit.Generate call). - mergeArtifacts(tc, ref.Name, invocationID, subArtifacts) + mergeArtifacts(ctx, ref.Name, invocationID, subArtifacts) result.Artifacts = delegatedArtifacts(invocationID, subArtifacts, a.strategy()) } return result, nil diff --git a/go/plugins/middleware/exp/artifacts.go b/go/plugins/middleware/exp/artifacts.go index 33c5bed8ad..4aa2d69b59 100644 --- a/go/plugins/middleware/exp/artifacts.go +++ b/go/plugins/middleware/exp/artifacts.go @@ -107,11 +107,11 @@ type readArtifactOutput struct { } func newReadArtifactTool() ai.Tool { - return ai.NewTool("read_artifact", + return aix.NewTool("read_artifact", "Reads the content of a named artifact from the session. Use this to "+ "inspect artifacts produced by sub-agents or previously created artifacts.", - func(tc *ai.ToolContext, in readArtifactInput) (readArtifactOutput, error) { - store := aix.ArtifactStoreFromContext(tc) + func(ctx context.Context, in readArtifactInput) (readArtifactOutput, error) { + store := aix.ArtifactStoreFromContext(ctx) if store == nil { return readArtifactOutput{Name: in.Name, Content: "Error: no active session.", Found: false}, nil } @@ -134,12 +134,12 @@ type writeArtifactOutput struct { } func newWriteArtifactTool() ai.Tool { - return ai.NewTool("write_artifact", + return aix.NewTool("write_artifact", "Creates or updates a named artifact in the session. If an artifact with "+ "the same name already exists, it is replaced. Use this to produce "+ "files, reports, code, or other deliverables.", - func(tc *ai.ToolContext, in writeArtifactInput) (writeArtifactOutput, error) { - store := aix.ArtifactStoreFromContext(tc) + func(ctx context.Context, in writeArtifactInput) (writeArtifactOutput, error) { + store := aix.ArtifactStoreFromContext(ctx) if store == nil { return writeArtifactOutput{Status: "Error: no active session."}, nil } From 898a9280a573e47c49b6d827249a6c7599aa81f8 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 25 Jun 2026 10:53:57 -0700 Subject: [PATCH 140/141] refactor(go/ai/exp): seed the genkit instance internally, drop public WithContextFunc The WithContextFunc agent option existed solely so genkit/exp's constructors could seed the *genkit.Genkit into each agent turn, letting middleware reach it via genkit.FromContext. Exposing that as a public option leaked an internal seam. Move the seeding down into ai/exp's registry-level constructors: they now derive the decorator from the registry they already receive, through a new internal bridge hook (genkitbridge.SeedContextForRegistry) that the genkit package installs and backs by reconstructing &Genkit{reg}. genkit/exp no longer injects anything and WithContextFunc is removed from the public API; the contextFunc field remains an unexported implementation detail. The hook is nil-safe: an agent defined on a bare registry without the genkit package linked carries no decorator. The genkit.Generate seeding path is untouched, so middleware attached to a direct genkit.Generate call behaves as before. --- go/ai/exp/agent.go | 25 ++++++++++++++++++ go/ai/exp/option.go | 33 +++++++----------------- go/genkit/bridge.go | 11 +++++--- go/genkit/exp/agent.go | 15 ----------- go/internal/genkitbridge/genkitbridge.go | 19 ++++++++------ 5 files changed, 53 insertions(+), 50 deletions(-) diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go index aab9d6ae12..8b68243e79 100644 --- a/go/ai/exp/agent.go +++ b/go/ai/exp/agent.go @@ -39,6 +39,7 @@ import ( "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal/base" + "github.com/firebase/genkit/go/internal/genkitbridge" "github.com/google/uuid" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -727,6 +728,24 @@ func DefineAgent[State any]( return DefineCustomAgent(r, name, agentLoop[State](r, p, nil), opts...) } +// genkitContextSeed returns a context decorator that seeds the host Genkit +// instance into each agent invocation, so the agent's prompt, tools, and +// middleware can retrieve it via genkit.FromContext and resolve or run other +// actions. The instance is reconstructed from r by the genkit package through +// the internal bridge, so the registry-level constructors below wire seeding up +// themselves and there is no public option for it. +// +// It returns nil when the genkit package is not linked into the build, leaving +// an agent defined directly on a bare registry untouched. +func genkitContextSeed(r api.Registry) func(context.Context) context.Context { + if genkitbridge.SeedContextForRegistry == nil { + return nil + } + return func(ctx context.Context) context.Context { + return genkitbridge.SeedContextForRegistry(ctx, r) + } +} + // DefinePromptAgent defines a prompt-backed agent and registers it, sourcing // its prompt from the registry by name. Each turn renders the prompt, appends // conversation history, calls the model with streaming, and updates session @@ -754,6 +773,9 @@ func DefinePromptAgent[State any]( name string, opts ...PromptAgentOption[State], ) *Agent[State] { + if seed := genkitContextSeed(r); seed != nil { + opts = append(opts, &agentOptions[State]{contextFunc: seed}) + } cfg := &promptAgentOptions[State]{} for _, opt := range opts { if err := opt.applyPromptAgent(cfg); err != nil { @@ -886,6 +908,9 @@ func DefineCustomAgent[State any]( fn AgentFunc[State], opts ...AgentOption[State], ) *Agent[State] { + if seed := genkitContextSeed(r); seed != nil { + opts = append(opts, &agentOptions[State]{contextFunc: seed}) + } a := NewCustomAgent(name, fn, opts...) a.Register(r) return a diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go index b68eaab1ed..797e7b882c 100644 --- a/go/ai/exp/option.go +++ b/go/ai/exp/option.go @@ -111,7 +111,11 @@ type agentOptions[State any] struct { transform StateTransform[State] streamTransform StreamTransform description string - contextFunc func(context.Context) context.Context + // contextFunc decorates each invocation's context once before the turn + // loop runs. It has no public option: the registry-level constructors set + // it internally to seed the genkit instance (see genkitContextSeed in + // agent.go), so callers reach the instance via genkit.FromContext. + contextFunc func(context.Context) context.Context } func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { @@ -140,16 +144,10 @@ func (o *agentOptions[State]) applyAgent(opts *agentOptions[State]) error { opts.description = o.description } if o.contextFunc != nil { - // Context decorators compose rather than conflict: each WithContextFunc - // adds a layer applied in registration order. This lets the genkit/exp - // agent constructors seed the genkit instance while applications add - // their own decorators on the same agent. - if prev := opts.contextFunc; prev != nil { - next := o.contextFunc - opts.contextFunc = func(ctx context.Context) context.Context { return next(prev(ctx)) } - } else { - opts.contextFunc = o.contextFunc - } + // Seeded internally by the registry-level constructors + // (genkitContextSeed), at most once per agent, so this is a plain set + // rather than a compose of multiple decorators. + opts.contextFunc = o.contextFunc } return nil } @@ -224,19 +222,6 @@ func WithDescription[State any](description string) AgentOption[State] { return &agentOptions[State]{description: description} } -// WithContextFunc registers a function that decorates the context for each -// agent invocation, applied once before the turn loop runs so the returned -// context flows to the prompt, tools, and middleware of every turn. -// -// The genkit package uses it to seed the [genkit.Genkit] instance (retrievable -// with genkit.FromContext) so middleware can resolve and run other actions -// without direct registry access. Applications may also use it to attach -// invocation-scoped values (e.g. request identity). Multiple decorators compose -// in registration order. -func WithContextFunc[State any](fn func(context.Context) context.Context) AgentOption[State] { - return &agentOptions[State]{contextFunc: fn} -} - // WithNamedPrompt points a [DefinePromptAgent] at the prompt registered under // name, rendered with input on every turn (pass nil for the prompt's own // default input). name need not match the agent's name, so a single registered diff --git a/go/genkit/bridge.go b/go/genkit/bridge.go index 5f8795a592..1947168462 100644 --- a/go/genkit/bridge.go +++ b/go/genkit/bridge.go @@ -21,17 +21,22 @@ import ( "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/internal/genkitbridge" + "github.com/firebase/genkit/go/internal/registry" ) // Expose first-party hooks into a *Genkit to subpackages (genkit/exp) without // adding public accessors. genkitbridge lives under go/internal, so only code // inside the Genkit module can read it. See [genkitbridge.RegistryOf] and -// [genkitbridge.SeedContext]. +// [genkitbridge.SeedContextForRegistry]. func init() { genkitbridge.RegistryOf = func(host any) api.Registry { return host.(*Genkit).reg } - genkitbridge.SeedContext = func(ctx context.Context, host any) context.Context { - return genkitCtxKey.NewContext(ctx, host.(*Genkit)) + genkitbridge.SeedContextForRegistry = func(ctx context.Context, reg api.Registry) context.Context { + r, ok := reg.(*registry.Registry) + if !ok { + return ctx + } + return genkitCtxKey.NewContext(ctx, &Genkit{reg: r}) } } diff --git a/go/genkit/exp/agent.go b/go/genkit/exp/agent.go index e9f774822c..5af9412cda 100644 --- a/go/genkit/exp/agent.go +++ b/go/genkit/exp/agent.go @@ -17,7 +17,6 @@ package exp import ( - "context" "sort" aix "github.com/firebase/genkit/go/ai/exp" @@ -26,17 +25,6 @@ import ( "github.com/firebase/genkit/go/internal/genkitbridge" ) -// seedGenkitContext returns an agent option that seeds g into each agent -// invocation's context, so middleware and other code can retrieve it via -// [genkit.FromContext] during the agent's turns, just as [genkit.Generate] -// seeds it. Agents run on the registry alone, so without this the Genkit -// instance would be absent from an agent's turn context. -func seedGenkitContext[State any](g *genkit.Genkit) aix.AgentOption[State] { - return aix.WithContextFunc[State](func(ctx context.Context) context.Context { - return genkitbridge.SeedContext(ctx, g) - }) -} - // DefineAgent defines an agent backed by an inline prompt and registers it as // an action on the registry. Returns an [aix.Agent]. // @@ -83,7 +71,6 @@ func DefineAgent[State any]( prompt aix.InlinePrompt, opts ...aix.AgentOption[State], ) *aix.Agent[State] { - opts = append(opts, seedGenkitContext[State](g)) return aix.DefineAgent(genkitbridge.RegistryOf(g), name, prompt, opts...) } @@ -129,7 +116,6 @@ func DefinePromptAgent[State any]( name string, opts ...aix.PromptAgentOption[State], ) *aix.Agent[State] { - opts = append(opts, seedGenkitContext[State](g)) return aix.DefinePromptAgent(genkitbridge.RegistryOf(g), name, opts...) } @@ -197,7 +183,6 @@ func DefineCustomAgent[State any]( fn aix.AgentFunc[State], opts ...aix.AgentOption[State], ) *aix.Agent[State] { - opts = append(opts, seedGenkitContext[State](g)) return aix.DefineCustomAgent(genkitbridge.RegistryOf(g), name, fn, opts...) } diff --git a/go/internal/genkitbridge/genkitbridge.go b/go/internal/genkitbridge/genkitbridge.go index 2d41799151..a8fb92fed1 100644 --- a/go/internal/genkitbridge/genkitbridge.go +++ b/go/internal/genkitbridge/genkitbridge.go @@ -38,12 +38,15 @@ import ( // install the extractor). First-party callers always pass a *genkit.Genkit. var RegistryOf func(host any) api.Registry -// SeedContext returns ctx with host (a *genkit.Genkit) attached so it can be -// retrieved with genkit.FromContext. It is installed by the genkit package's -// init and used by genkit/exp's agent constructors to seed the Genkit instance -// into every agent turn, so middleware can resolve and run other actions -// without direct registry access. +// SeedContextForRegistry returns ctx with the *genkit.Genkit backing reg +// attached, so it can be retrieved with genkit.FromContext. It is installed by +// the genkit package's init and called by ai/exp's agent constructors to seed +// the Genkit instance into every agent turn, so an agent's prompt, tools, and +// middleware can resolve and run other actions without direct registry access. // -// As with [RegistryOf], host is typed as any to avoid an import cycle with -// genkit; first-party callers always pass a *genkit.Genkit. -var SeedContext func(ctx context.Context, host any) context.Context +// The Genkit instance is reconstructed from reg (a *genkit.Genkit is a thin +// wrapper over its registry), so ai/exp need not hold a *genkit.Genkit itself +// and the registry-level agent constructors stay genkit-agnostic. It is nil +// until the genkit package is linked into the build; ai/exp treats a nil hook +// as "no seeding", leaving agents defined on a bare registry untouched. +var SeedContextForRegistry func(ctx context.Context, reg api.Registry) context.Context From 9102800cb08c2b151bc0f0aa28d9682396cfe9bb Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 25 Jun 2026 11:05:06 -0700 Subject: [PATCH 141/141] docs(go): document the experimental sub-agent delegation middleware Add a "Delegate to Sub-Agents" example to the Agents section of the Go README, showing the experimental Agents middleware (plugins/middleware/exp): defining sub-agents with descriptions, wiring an orchestrator with ai.WithUse, the key knobs (MaxDelegations, HistoryLength), and how it composes with the Artifacts middleware. Mirrors the orchestrator in samples/basic-agents. --- go/README.md | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/go/README.md b/go/README.md index fd567be815..8d73fa43d3 100644 --- a/go/README.md +++ b/go/README.md @@ -298,6 +298,59 @@ Detach requires a store that implements `SnapshotSubscriber` (both bundled local [See full example](samples/basic-agents) +### Delegate to Sub-Agents + +The experimental `Agents` middleware (in `plugins/middleware/exp`) lets one agent delegate to others. It injects one `delegate_to_` tool per sub-agent and a `` listing into the orchestrator's system prompt, then runs the chosen sub-agent and returns its result when the model calls the tool. Each sub-agent's `aix.WithDescription` (captured by `agent.Ref()`) tells the orchestrator when to reach for it: + +```go +import ( + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/ai/exp/localstore" + genkitx "github.com/firebase/genkit/go/genkit/exp" + middlewarex "github.com/firebase/genkit/go/plugins/middleware/exp" +) + +researcher := genkitx.DefineAgent(g, "researcher", + aix.InlinePrompt{ + ai.WithModelName("googleai/gemini-flash-latest"), + ai.WithSystem("You are a thorough research assistant. Summarize well-sourced findings."), + }, + aix.WithDescription[any]("Researches a topic and summarizes well-sourced findings."), +) + +engineer := genkitx.DefineAgent(g, "engineer", + aix.InlinePrompt{ + ai.WithModelName("googleai/gemini-flash-latest"), + ai.WithSystem("You are an expert programmer. Write clean, well-commented code."), + }, + aix.WithDescription[any]("Writes and explains code."), +) + +// The orchestrator delegates instead of answering directly: the model calls +// delegate_to_researcher / delegate_to_engineer and the middleware runs them. +orchestrator := genkitx.DefineAgent(g, "orchestrator", + aix.InlinePrompt{ + ai.WithModelName("googleai/gemini-flash-latest"), + ai.WithSystem("You are a project coordinator. Delegate to the right sub-agent, " + + "then synthesize a final answer."), + ai.WithUse(&middlewarex.Agents{ + Agents: []aix.AgentRef{researcher.Ref(), engineer.Ref()}, + MaxDelegations: 5, // cap delegation tool calls per turn (0 = unlimited) + HistoryLength: 4, // recent messages forwarded to client-managed sub-agents + }), + }, + aix.WithSessionStore(localstore.NewInMemorySessionStore[any]()), +) + +out, _ := orchestrator.RunText(ctx, "Research goroutine scheduling, then sketch a worker pool.") +fmt.Println(out.Message.Text()) +``` + +Sub-agents are named by `aix.AgentRef`, either captured from an agent value with `agent.Ref()` or written by hand (`aix.AgentRef{Name: "researcher"}`). The middleware composes with the `Artifacts` middleware: give the sub-agents `&middlewarex.Artifacts{}` so they can save output, set `ArtifactStrategy: middlewarex.ArtifactStrategySession` to merge those artifacts into the orchestrator's session instead of inlining them in the tool result, and add `&middlewarex.Artifacts{Readonly: true}` on the orchestrator so it can review them before answering. + +[See full example](samples/basic-agents) + ### Serve Agents over HTTP An `Agent` is an `api.BidiAction`, so it serves over HTTP one turn per request. The `genkit/exp` package lays out a default route surface for every registered agent, including the snapshot companion endpoints for store-backed agents: