From 6f2f66882cc2c775541f9294703465307d9cd6f9 Mon Sep 17 00:00:00 2001 From: Dean Eckert Date: Fri, 6 Mar 2026 12:45:05 +0100 Subject: [PATCH] fix(apiproxy): Fixed an issue, when very large tool calls failed --- apiproxy/response.go | 504 +++++++++++++-------- apiproxy/response_sse_nested_usage_test.go | 103 +++++ db/migrate.go | 50 +- nix-modules/devshell.nix | 3 + 4 files changed, 472 insertions(+), 188 deletions(-) diff --git a/apiproxy/response.go b/apiproxy/response.go index e238da8..7aa2262 100644 --- a/apiproxy/response.go +++ b/apiproxy/response.go @@ -11,6 +11,7 @@ import ( db "openai-api-proxy/db" "os" "regexp" + "strconv" "strings" ) @@ -54,12 +55,9 @@ func (rc *ResponseConf) NewResponse(in *http.Response) error { // Create a pipe to intercept the stream without blocking it. // One end goes to the client (via in.Body), the other to our parser. pr, pw := io.Pipe() - tr := io.TeeReader(in.Body, pw) - in.Body = &readCloserWithCallback{ - Reader: tr, - CloseFunc: func() error { - return pw.Close() - }, + in.Body = &teeReadCloser{ + src: in.Body, + sink: pw, } // Parse SSE events in a separate goroutine @@ -338,6 +336,75 @@ func modelAliasFromResponseMap(resp map[string]interface{}) (string, string) { return "", "" } +const defaultSSEMaxLineBytes = 16 * 1024 * 1024 +const defaultSSEEventMaxBytes = 8 * 1024 * 1024 + +var errSSELineTooLong = errors.New("sse line exceeds max length") + +func sseMaxLineBytes() int { + if raw := strings.TrimSpace(os.Getenv("SSE_MAX_LINE_BYTES")); raw != "" { + if n, err := strconv.Atoi(raw); err == nil && n > 0 { + return n + } + } + return defaultSSEMaxLineBytes +} + +func sseEventMaxBytes() int { + if raw := strings.TrimSpace(os.Getenv("SSE_EVENT_MAX_BYTES")); raw != "" { + if n, err := strconv.Atoi(raw); err == nil && n > 0 { + return n + } + } + return defaultSSEEventMaxBytes +} + +func readSSELine(br *bufio.Reader, maxBytes int) (string, error) { + var line bytes.Buffer + for { + chunk, err := br.ReadSlice('\n') + if len(chunk) > 0 { + if line.Len()+len(chunk) > maxBytes { + if !bytes.HasSuffix(chunk, []byte{'\n'}) { + for { + tail, derr := br.ReadSlice('\n') + if len(tail) > 0 && bytes.HasSuffix(tail, []byte{'\n'}) { + break + } + if derr == io.EOF { + break + } + if derr != nil && derr != bufio.ErrBufferFull { + return "", derr + } + } + } + return "", errSSELineTooLong + } + line.Write(chunk) + if bytes.HasSuffix(chunk, []byte{'\n'}) { + break + } + } + + if err == nil { + break + } + if err == bufio.ErrBufferFull { + continue + } + if err == io.EOF { + if line.Len() == 0 { + return "", io.EOF + } + break + } + return "", err + } + + return strings.TrimRight(line.String(), "\r\n"), nil +} + func (rc *ResponseConf) parseSSEStream(r io.Reader, req *http.Request) { // Re-use logic from the previous implementation but for a stream var cumPrompt, cumCompletion, cumCached int @@ -348,210 +415,270 @@ func (rc *ResponseConf) parseSSEStream(r io.Reader, req *http.Request) { wrote := false // Track whether we had to estimate any token usage (e.g., output tokens from accumulated text) estimatedUsed := false + oversizedEventChars := 0 - // Use a scanner to read line by line from the pipe - // SSE events are separated by double newlines - // We need to handle the case where a single data: block is split across lines (though rare in OpenAI) - // or where multiple data: lines exist in one event. - - // A simple approach is to read until we find a blank line or EOF - // then process what we have. - - // However, we can also just read line by line and if it starts with data: process it. - // Most usage info in OpenAI SSE is contained in a single data: line. - - // We'll use a scanner. - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - line = strings.TrimSpace(line) - if strings.HasPrefix(line, "data:") { - jsonText := strings.TrimSpace(strings.TrimPrefix(line, "data:")) - if jsonText == "[DONE]" { - continue + // Parse SSE as events separated by blank lines. We buffer only one event + // and enforce an event-size cap so parser memory is bounded while streaming + // to clients remains lossless. + + processEventData := func(jsonText string) { + if jsonText == "" || jsonText == "[DONE]" { + return + } + var raw map[string]interface{} + if os.Getenv("DEV_LOG_TOKEN_DEBUG") == "1" { + preview := jsonText + if len(preview) > 1024 { + preview = preview[:1024] + "..." } - var raw map[string]interface{} + log.Printf("DEV DEBUG: SSE event received (data line: %s)", preview) + } + if err := json.Unmarshal([]byte(jsonText), &raw); err != nil { if os.Getenv("DEV_LOG_TOKEN_DEBUG") == "1" { - preview := jsonText - if len(preview) > 1024 { - preview = preview[:1024] + "..." - } - log.Printf("DEV DEBUG: SSE event received (data line: %s)", jsonText) - } - if err := json.Unmarshal([]byte(jsonText), &raw); err != nil { - if os.Getenv("DEV_LOG_TOKEN_DEBUG") == "1" { - log.Printf("DEV DEBUG: SSE data-line JSON unmarshal error: %v; json=%s", err, jsonText) - } - continue + log.Printf("DEV DEBUG: SSE data-line JSON unmarshal error: %v; json=%s", err, jsonText) } + return + } - // Extract text for fallback estimation - if t, ok := raw["text"].(string); ok { + // Extract text for fallback estimation + if t, ok := raw["text"].(string); ok { + accumulatedText.WriteString(t) + } else if part, ok := raw["part"].(map[string]interface{}); ok { + if t, ok := part["text"].(string); ok { accumulatedText.WriteString(t) - } else if part, ok := raw["part"].(map[string]interface{}); ok { - if t, ok := part["text"].(string); ok { - accumulatedText.WriteString(t) - } - } else if choices, ok := raw["choices"].([]interface{}); ok && len(choices) > 0 { - if choice, ok := choices[0].(map[string]interface{}); ok { - if delta, ok := choice["delta"].(map[string]interface{}); ok { - if content, ok := delta["content"].(string); ok { - accumulatedText.WriteString(content) - } + } + } else if choices, ok := raw["choices"].([]interface{}); ok && len(choices) > 0 { + if choice, ok := choices[0].(map[string]interface{}); ok { + if delta, ok := choice["delta"].(map[string]interface{}); ok { + if content, ok := delta["content"].(string); ok { + accumulatedText.WriteString(content) } } } + } - // Track ID and Model for fallback - if idv, ok := raw["id"].(string); ok && idv != "" { - lastID = idv + // Track ID and Model for fallback + if idv, ok := raw["id"].(string); ok && idv != "" { + lastID = idv + } + if modelv, ok := raw["model"].(string); ok && modelv != "" { + lastModel = modelv + } + + var usageFound bool + var pcount, ccount, cached int + if u, ok := raw["usage"].(map[string]interface{}); ok { + totals, details := parseUsageMap(u) + pcount, ccount, _, cached = extractTokenCounts(totals, details) + cumPrompt = max(cumPrompt, pcount) + cumCompletion = max(cumCompletion, ccount) + cumCached = max(cumCached, cached) + foundAny = true + eventIdx++ + usageFound = true + if os.Getenv("DEV_LOG_TOKEN_COUNT") == "1" { + log.Printf("DEV LOG: SSE event %d usage: prompt=%d completion=%d cached=%d (top-level)", eventIdx, pcount, ccount, cached) } - if modelv, ok := raw["model"].(string); ok && modelv != "" { - lastModel = modelv + } + // also check nested response.usage + if !usageFound { + if respObj, ok := raw["response"].(map[string]interface{}); ok { + if u2, ok2 := respObj["usage"].(map[string]interface{}); ok2 { + totals, details := parseUsageMap(u2) + pcount, ccount, _, cached = extractTokenCounts(totals, details) + cumPrompt = max(cumPrompt, pcount) + cumCompletion = max(cumCompletion, ccount) + cumCached = max(cumCached, cached) + foundAny = true + eventIdx++ + usageFound = true + if os.Getenv("DEV_LOG_TOKEN_COUNT") == "1" { + log.Printf("DEV LOG: SSE event %d usage: prompt=%d completion=%d cached=%d (nested response)", eventIdx, pcount, ccount, cached) + } + } + if idv, ok := respObj["id"].(string); ok && idv != "" { + lastID = idv + } + if modelv, ok := respObj["model"].(string); ok && modelv != "" { + lastModel = modelv + } } + } - var usageFound bool - var pcount, ccount, cached int - if u, ok := raw["usage"].(map[string]interface{}); ok { - totals, details := parseUsageMap(u) - pcount, ccount, _, cached = extractTokenCounts(totals, details) - cumPrompt = max(cumPrompt, pcount) - cumCompletion = max(cumCompletion, ccount) - cumCached = max(cumCached, cached) - foundAny = true - eventIdx++ - usageFound = true - if os.Getenv("DEV_LOG_TOKEN_COUNT") == "1" { - log.Printf("DEV LOG: SSE event %d usage: prompt=%d completion=%d cached=%d (top-level)", eventIdx, pcount, ccount, cached) + // If this event signals completion, write to DB. + if !wrote { + isCompleted := false + if tstr, ok := raw["type"].(string); ok && tstr == "response.completed" { + isCompleted = true + } + // Also check for finished choices in standard chat completion stream + if choices, ok := raw["choices"].([]interface{}); ok && len(choices) > 0 { + if choice, ok := choices[0].(map[string]interface{}); ok { + if finish, ok := choice["finish_reason"].(string); ok && finish != "" { + // Not necessarily the very last event if usage follows, + // but it's a signal. + // Actually, we should wait for the usage event if possible. + } } } - // also check nested response.usage - if !usageFound { + + if isCompleted { + respID := lastID + respModel := lastModel if respObj, ok := raw["response"].(map[string]interface{}); ok { - if u2, ok2 := respObj["usage"].(map[string]interface{}); ok2 { - totals, details := parseUsageMap(u2) - pcount, ccount, _, cached = extractTokenCounts(totals, details) - cumPrompt = max(cumPrompt, pcount) - cumCompletion = max(cumCompletion, ccount) - cumCached = max(cumCached, cached) - foundAny = true - eventIdx++ - usageFound = true - if os.Getenv("DEV_LOG_TOKEN_COUNT") == "1" { - log.Printf("DEV LOG: SSE event %d usage: prompt=%d completion=%d cached=%d (nested response)", eventIdx, pcount, ccount, cached) - } - } if idv, ok := respObj["id"].(string); ok && idv != "" { - lastID = idv + respID = idv } if modelv, ok := respObj["model"].(string); ok && modelv != "" { - lastModel = modelv + respModel = modelv } } - } - // If this event signals completion, write to DB. - if !wrote { - isCompleted := false - if tstr, ok := raw["type"].(string); ok && tstr == "response.completed" { - isCompleted = true - } - // Also check for finished choices in standard chat completion stream - if choices, ok := raw["choices"].([]interface{}); ok && len(choices) > 0 { - if choice, ok := choices[0].(map[string]interface{}); ok { - if finish, ok := choice["finish_reason"].(string); ok && finish != "" { - // Not necessarily the very last event if usage follows, - // but it's a signal. - // Actually, we should wait for the usage event if possible. - } + if respID != "" { + apiKeyID := "" + header := "" + if req != nil { + header = req.Header.Get(authHeader) } - } - - if isCompleted { - respID := lastID - respModel := lastModel - if respObj, ok := raw["response"].(map[string]interface{}); ok { - if idv, ok := respObj["id"].(string); ok && idv != "" { - respID = idv - } - if modelv, ok := respObj["model"].(string); ok && modelv != "" { - respModel = modelv + apiKey := strings.TrimPrefix(header, "Bearer ") + if hashes, err := rc.db.LookupApiKeys("*"); err == nil { + if uid, err := CompareToken(hashes, apiKey); err == nil { + apiKeyID = uid } } - if respID != "" { - apiKeyID := "" - header := "" - if req != nil { - header = req.Header.Get(authHeader) - } - apiKey := strings.TrimPrefix(header, "Bearer ") - if hashes, err := rc.db.LookupApiKeys("*"); err == nil { - if uid, err := CompareToken(hashes, apiKey); err == nil { - apiKeyID = uid - } - } + // Final counts: prefer current event, then cumulative, then estimation + finalPrompt := pcount + finalCompletion := ccount + finalCached := cached - // Final counts: prefer current event, then cumulative, then estimation - finalPrompt := pcount - finalCompletion := ccount - finalCached := cached + if finalPrompt == 0 { + finalPrompt = cumPrompt + } + if finalCompletion == 0 { + finalCompletion = cumCompletion + } + if finalCached == 0 { + finalCached = cumCached + } - if finalPrompt == 0 { - finalPrompt = cumPrompt - } - if finalCompletion == 0 { - finalCompletion = cumCompletion - } - if finalCached == 0 { - finalCached = cumCached + // Estimation fallback for completion tokens. + estimatedChars := accumulatedText.Len() + oversizedEventChars + if finalCompletion == 0 && estimatedChars > 0 { + // Simple heuristic: 1 token ≈ 4 characters + finalCompletion = estimatedChars / 4 + if finalCompletion == 0 && estimatedChars > 0 { + finalCompletion = 1 } - - // Estimation fallback for completion tokens - if finalCompletion == 0 && accumulatedText.Len() > 0 { - // Simple heuristic: 1 token ≈ 4 characters - finalCompletion = accumulatedText.Len() / 4 - if finalCompletion == 0 && accumulatedText.Len() > 0 { - finalCompletion = 1 - } - estimatedUsed = true - if os.Getenv("DEV_LOG_TOKEN_COUNT") == "1" { - log.Printf("DEV LOG: using estimated completion tokens: %d (chars: %d)", finalCompletion, accumulatedText.Len()) - } + estimatedUsed = true + if os.Getenv("DEV_LOG_TOKEN_COUNT") == "1" { + log.Printf("DEV LOG: using estimated completion tokens: %d (chars: %d)", finalCompletion, estimatedChars) } + } - modelAlias, snapshot := splitModelSnapshot(respModel) - promptTokens := dedupPromptTokens(finalPrompt, finalCached) - rq := db.Request{ - ID: respID, - ApiKeyID: apiKeyID, - TokenCountPrompt: promptTokens, - TokenCountComplete: finalCompletion, - InputTokenCount: finalPrompt, - CachedInputTokenCount: finalCached, - OutputTokenCount: finalCompletion, - Model: modelAlias, - SnapshotVersion: snapshot, - IsApproximated: estimatedUsed, - } - if err := rc.db.WriteRequest(&rq); err != nil { - log.Printf("DEV LOG: failed to write request for SSE completed id=%s: %v", respID, err) - } else { - if os.Getenv("DEV_LOG_TOKEN_COUNT") == "1" { - log.Printf("DEV LOG: wrote SSE completed request id=%s prompt=%d completion=%d api_key_id=%s", respID, finalPrompt, finalCompletion, apiKeyID) - } + modelAlias, snapshot := splitModelSnapshot(respModel) + promptTokens := dedupPromptTokens(finalPrompt, finalCached) + rq := db.Request{ + ID: respID, + ApiKeyID: apiKeyID, + TokenCountPrompt: promptTokens, + TokenCountComplete: finalCompletion, + InputTokenCount: finalPrompt, + CachedInputTokenCount: finalCached, + OutputTokenCount: finalCompletion, + Model: modelAlias, + SnapshotVersion: snapshot, + IsApproximated: estimatedUsed, + } + if err := rc.db.WriteRequest(&rq); err != nil { + log.Printf("DEV LOG: failed to write request for SSE completed id=%s: %v", respID, err) + } else { + if os.Getenv("DEV_LOG_TOKEN_COUNT") == "1" { + log.Printf("DEV LOG: wrote SSE completed request id=%s prompt=%d completion=%d api_key_id=%s", respID, finalPrompt, finalCompletion, apiKeyID) } - wrote = true } + wrote = true } } } } - if err := scanner.Err(); err != nil { - log.Printf("DEV LOG: SSE stream scanner error: %v", err) + maxLineBytes := sseMaxLineBytes() + maxEventBytes := sseEventMaxBytes() + br := bufio.NewReader(r) + var eventData strings.Builder + eventTooLarge := false + eventDataBytes := 0 + flushEvent := func() { + if eventTooLarge { + oversizedEventChars += eventDataBytes + estimatedUsed = true + if os.Getenv("DEV_LOG_TOKEN_COUNT") == "1" { + log.Printf("DEV LOG: SSE event exceeded cap (%d bytes), token accounting approximated", maxEventBytes) + } + } else if eventData.Len() > 0 { + processEventData(eventData.String()) + } + eventData.Reset() + eventTooLarge = false + eventDataBytes = 0 } + for { + line, err := readSSELine(br, maxLineBytes) + if err == io.EOF { + break + } + if err != nil { + if errors.Is(err, errSSELineTooLong) { + eventTooLarge = true + eventDataBytes += maxLineBytes + estimatedUsed = true + log.Printf("DEV LOG: oversized SSE line encountered (max=%d bytes); token accounting may be approximated", maxLineBytes) + continue + } + log.Printf("DEV LOG: SSE stream read error: %v", err) + break + } + + trimmed := strings.TrimSpace(line) + if trimmed == "" { + flushEvent() + continue + } + + if !strings.HasPrefix(trimmed, "data:") { + continue + } + jsonText := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + if jsonText == "" { + continue + } + + if eventTooLarge { + eventDataBytes += len(jsonText) + continue + } + + nextBytes := eventDataBytes + len(jsonText) + if eventData.Len() > 0 { + nextBytes++ + } + if nextBytes > maxEventBytes { + eventTooLarge = true + eventDataBytes = nextBytes + eventData.Reset() + continue + } + + if eventData.Len() > 0 { + eventData.WriteByte('\n') + } + eventData.WriteString(jsonText) + eventDataBytes = nextBytes + } + flushEvent() + if !wrote && lastID != "" { // Try fallback if we haven't written yet (e.g. stream ended without response.completed but we have an ID) apiKeyID := "" @@ -571,9 +698,10 @@ func (rc *ResponseConf) parseSSEStream(r io.Reader, req *http.Request) { finalCached := cumCached estimated := false - if finalCompletion == 0 && accumulatedText.Len() > 0 { - finalCompletion = accumulatedText.Len() / 4 - if finalCompletion == 0 && accumulatedText.Len() > 0 { + estimatedChars := accumulatedText.Len() + oversizedEventChars + if finalCompletion == 0 && estimatedChars > 0 { + finalCompletion = estimatedChars / 4 + if finalCompletion == 0 && estimatedChars > 0 { finalCompletion = 1 } estimated = true @@ -607,19 +735,31 @@ func (rc *ResponseConf) parseSSEStream(r io.Reader, req *http.Request) { } } -type readCloserWithCallback struct { - io.Reader - CloseFunc func() error +// teeReadCloser mirrors bytes from src to sink for side-channel parsing. +// Sink write errors are intentionally swallowed so client streaming remains lossless. +type teeReadCloser struct { + src io.ReadCloser + sink io.WriteCloser } -func (r *readCloserWithCallback) Close() error { - if r.CloseFunc != nil { - return r.CloseFunc() +func (t *teeReadCloser) Read(p []byte) (int, error) { + n, err := t.src.Read(p) + if n > 0 && t.sink != nil { + if _, werr := t.sink.Write(p[:n]); werr != nil { + log.Printf("DEV LOG: SSE parser tap disabled after sink write error: %v", werr) + _ = t.sink.Close() + t.sink = nil + } } - if closer, ok := r.Reader.(io.Closer); ok { - return closer.Close() + return n, err +} + +func (t *teeReadCloser) Close() error { + if t.sink != nil { + _ = t.sink.Close() + t.sink = nil } - return nil + return t.src.Close() } func parseUsageMap(raw map[string]interface{}) (map[string]int, map[string]map[string]int) { diff --git a/apiproxy/response_sse_nested_usage_test.go b/apiproxy/response_sse_nested_usage_test.go index 1204d91..638a753 100644 --- a/apiproxy/response_sse_nested_usage_test.go +++ b/apiproxy/response_sse_nested_usage_test.go @@ -89,6 +89,109 @@ func TestNewResponse_SSENestedUsage(t *testing.T) { } } +func TestNewResponse_SSEOverlongLineDoesNotAbortStream(t *testing.T) { + os.Setenv("SSE_MAX_LINE_BYTES", "256") + defer os.Unsetenv("SSE_MAX_LINE_BYTES") + + longChunk := strings.Repeat("x", 512) + sse := strings.Join([]string{ + "event: response.output_text.delta\n", + `data: {"type":"response.output_text.delta","text":"` + longChunk + `"}` + "\n\n", + "event: response.completed\n", + `data: {"type":"response.completed","response":{"id":"r-long","object":"response","usage":{"prompt_tokens":3,"completion_tokens":4}}}` + "\n\n", + }, "") + + req, _ := http.NewRequest("POST", "https://example.local/openai/responses", nil) + req.Header.Set(authHeader, "Bearer TESTTOKEN") + resp := &http.Response{ + Request: req, + Header: http.Header{"Content-Type": []string{"text/event-stream; charset=utf-8"}}, + Body: io.NopCloser(strings.NewReader(sse)), + } + + fb := &fakeDBForTest{} + hash, _ := bcrypt.GenerateFromPassword([]byte("TESTTOKEN"), 5) + fb.apiKeys = []db.ApiKey{{UUID: "uid-1", ApiKey: string(hash), Owner: "owner1"}} + + rc := &ResponseConf{db: fb} + if err := rc.NewResponse(resp); err != nil { + t.Fatalf("NewResponse error: %v", err) + } + + if _, err := io.Copy(io.Discard, resp.Body); err != nil { + t.Fatalf("stream relay failed: %v", err) + } + resp.Body.Close() + + for i := 0; i < 10 && len(fb.writes) == 0; i++ { + time.Sleep(50 * time.Millisecond) + } + if len(fb.writes) != 1 { + t.Fatalf("expected 1 DB write, got %d", len(fb.writes)) + } + w := fb.writes[0] + if w.ID != "r-long" || w.TokenCountPrompt != 3 || w.TokenCountComplete != 4 { + t.Fatalf("unexpected DB write: %+v", w) + } +} + +func TestNewResponse_SSEOverCapEventApproximatesCompletionTokens(t *testing.T) { + os.Setenv("SSE_EVENT_MAX_BYTES", "256") + defer os.Unsetenv("SSE_EVENT_MAX_BYTES") + os.Setenv("SSE_MAX_LINE_BYTES", "4096") + defer os.Unsetenv("SSE_MAX_LINE_BYTES") + + longChunk := strings.Repeat("z", 1200) + sse := strings.Join([]string{ + "event: response.output_text.delta\n", + `data: {"type":"response.output_text.delta","id":"too-big-1","text":"` + longChunk + `","usage":{"completion_tokens":999}}` + "\n\n", + "event: response.completed\n", + `data: {"type":"response.completed","response":{"id":"r-approx","object":"response","model":"gpt-4o-2025-01-01"}}` + "\n\n", + }, "") + + req, _ := http.NewRequest("POST", "https://example.local/openai/responses", nil) + req.Header.Set(authHeader, "Bearer TESTTOKEN") + resp := &http.Response{ + Request: req, + Header: http.Header{"Content-Type": []string{"text/event-stream; charset=utf-8"}}, + Body: io.NopCloser(strings.NewReader(sse)), + } + + fb := &fakeDBForTest{} + hash, _ := bcrypt.GenerateFromPassword([]byte("TESTTOKEN"), 5) + fb.apiKeys = []db.ApiKey{{UUID: "uid-1", ApiKey: string(hash), Owner: "owner1"}} + + rc := &ResponseConf{db: fb} + if err := rc.NewResponse(resp); err != nil { + t.Fatalf("NewResponse error: %v", err) + } + + if _, err := io.Copy(io.Discard, resp.Body); err != nil { + t.Fatalf("stream relay failed: %v", err) + } + resp.Body.Close() + + for i := 0; i < 10 && len(fb.writes) == 0; i++ { + time.Sleep(50 * time.Millisecond) + } + if len(fb.writes) != 1 { + t.Fatalf("expected 1 DB write, got %d", len(fb.writes)) + } + w := fb.writes[0] + if w.ID != "r-approx" { + t.Fatalf("unexpected ID: %+v", w) + } + if !w.IsApproximated { + t.Fatalf("expected approximated request, got %+v", w) + } + if w.TokenCountComplete <= 0 { + t.Fatalf("expected approximated completion tokens > 0, got %+v", w) + } + if w.TokenCountComplete == 999 { + t.Fatalf("expected oversized event usage to be ignored under cap, got %+v", w) + } +} + // fake DB implementation for test type fakeDBForTest struct { apiKeys []db.ApiKey diff --git a/db/migrate.go b/db/migrate.go index 79d6464..68b3acb 100644 --- a/db/migrate.go +++ b/db/migrate.go @@ -6,6 +6,8 @@ import ( "log" "os" "regexp" + "strconv" + "time" "ariga.io/atlas-go-sdk/atlasexec" ) @@ -36,11 +38,47 @@ func (d *Database) Migrate() { atlasurl = re.ReplaceAllString(databasePath, `postgres://$2?search_path=public&$3`) // Run `atlas migrate apply` on a PSQL database - res, err := client.MigrateApply(context.Background(), &atlasexec.MigrateApplyParams{ - URL: atlasurl, - }) - if err != nil { - log.Fatalf("failed to apply migrations: %v", err) + maxAttempts := migrationMaxAttempts() + retryDelay := migrationRetryDelay() + var res *atlasexec.MigrateApply + for attempt := 1; attempt <= maxAttempts; attempt++ { + res, err = client.MigrateApply(context.Background(), &atlasexec.MigrateApplyParams{ + URL: atlasurl, + }) + if err == nil { + fmt.Printf("Applied %d migrations\n", len(res.Applied)) + return + } + if attempt < maxAttempts { + log.Printf("migration attempt %d/%d failed: %v; retrying in %s", attempt, maxAttempts, err, retryDelay) + time.Sleep(retryDelay) + } + } + log.Fatalf("failed to apply migrations after %d attempts: %v", maxAttempts, err) +} + +func migrationMaxAttempts() int { + raw := os.Getenv("MIGRATE_MAX_ATTEMPTS") + if raw == "" { + return 30 + } + v, err := strconv.Atoi(raw) + if err != nil || v < 1 { + log.Printf("invalid MIGRATE_MAX_ATTEMPTS=%q; using default 30", raw) + return 30 + } + return v +} + +func migrationRetryDelay() time.Duration { + raw := os.Getenv("MIGRATE_RETRY_DELAY") + if raw == "" { + return 2 * time.Second + } + d, err := time.ParseDuration(raw) + if err != nil || d <= 0 { + log.Printf("invalid MIGRATE_RETRY_DELAY=%q; using default 2s", raw) + return 2 * time.Second } - fmt.Printf("Applied %d migrations\n", len(res.Applied)) + return d } diff --git a/nix-modules/devshell.nix b/nix-modules/devshell.nix index 6625a4a..e38a914 100644 --- a/nix-modules/devshell.nix +++ b/nix-modules/devshell.nix @@ -9,6 +9,9 @@ # Default dev shell pulls in Go tooling and treefmt formatting. devShells.default = pkgs.mkShell { name = "openai-api-proxy-shell"; + buildInputs = [ + pkgs.nodejs_24 + ]; inputsFrom = [ config.devShells.openai-api-proxy-shell config.treefmt.build.devShell