diff --git a/cli.php b/cli.php index 1ec21f30..1e56ca72 100755 --- a/cli.php +++ b/cli.php @@ -10,6 +10,9 @@ * OPENAI_API_KEY=123456 php cli.php 'Your prompt here' --providerId=openai * GOOGLE_API_KEY=123456 OPENAI_API_KEY=123456 php cli.php 'Your prompt here' * + * To stream the response as it arrives, use --outputFormat=stream-text: + * OPENAI_API_KEY=123456 php cli.php 'Your prompt here' --providerId=openai --outputFormat=stream-text + * * For large prompts (e.g., with images), use stdin or file input: * cat prompt.json | php cli.php - --providerId=openai --modelId=gpt-4o * php cli.php @prompt.json --providerId=openai --modelId=gpt-4o @@ -190,7 +193,15 @@ static function ($item) { } try { - if ($outputFormat === 'image-json' || $outputFormat === 'image-base64') { + if ($outputFormat === 'stream-text') { + $stream = $promptBuilder->streamGenerateTextResult(); + foreach ($stream as $chunk) { + echo $chunk->getDeltaText(); + flush(); + } + echo PHP_EOL; + $result = $stream->getFinalResult(); + } elseif ($outputFormat === 'image-json' || $outputFormat === 'image-base64') { $result = $promptBuilder->generateImageResult(); } else { $result = $promptBuilder->generateTextResult(); @@ -204,7 +215,11 @@ static function ($item) { logInfo("Using provider ID: \"{$result->getProviderMetadata()->getId()}\""); logInfo("Using model ID: \"{$result->getModelMetadata()->getId()}\""); +$output = null; switch ($outputFormat) { + case 'stream-text': + // The text was already streamed to stdout above. + break; case 'result-json': $output = json_encode($result, JSON_PRETTY_PRINT); break; @@ -222,4 +237,6 @@ static function ($item) { $output = $result->toText(); } -printOutput($output); +if (is_string($output)) { + printOutput($output); +} diff --git a/phpcs.xml.dist b/phpcs.xml.dist index 2e49db44..f3f26d3e 100644 --- a/phpcs.xml.dist +++ b/phpcs.xml.dist @@ -21,6 +21,9 @@ + + + diff --git a/src/AiClient.php b/src/AiClient.php index ebfeec75..623b78c0 100644 --- a/src/AiClient.php +++ b/src/AiClient.php @@ -15,6 +15,7 @@ use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\ProviderRegistry; use WordPress\AiClient\Results\DTO\GenerativeAiResult; +use WordPress\AiClient\Results\StreamedGenerativeAiResult; /** * Main AI Client class providing both fluent and traditional APIs for AI operations. @@ -290,6 +291,33 @@ public static function generateTextResult( return self::getConfiguredPromptBuilder($prompt, $modelOrConfig, $registry)->generateTextResult(); } + /** + * Streams a text result using the traditional API approach. + * + * Iterate the returned object to consume chunks as they arrive, then call + * its getFinalResult() for the complete result. + * + * @since n.e.x.t + * + * @param Prompt $prompt The prompt content. + * @param ModelInterface|ModelConfig|null $modelOrConfig Optional specific model to use, + * or model configuration for auto-discovery, + * or null for defaults. + * @param ProviderRegistry|null $registry Optional custom registry. If null, uses default. + * @return StreamedGenerativeAiResult The streamed result. + * + * @throws \InvalidArgumentException If the prompt format is invalid. + * @throws \RuntimeException If no suitable model is found or it does not support streaming. + */ + public static function streamGenerateTextResult( + $prompt, + $modelOrConfig = null, + ?ProviderRegistry $registry = null + ): StreamedGenerativeAiResult { + self::validateModelOrConfigParameter($modelOrConfig); + return self::getConfiguredPromptBuilder($prompt, $modelOrConfig, $registry)->streamGenerateTextResult(); + } + /** * Generates an image using the traditional API approach. * diff --git a/src/Builders/PromptBuilder.php b/src/Builders/PromptBuilder.php index 130fc574..c9e56a2d 100644 --- a/src/Builders/PromptBuilder.php +++ b/src/Builders/PromptBuilder.php @@ -9,6 +9,7 @@ use WordPress\AiClient\Common\Exception\RuntimeException; use WordPress\AiClient\Events\AfterGenerateResultEvent; use WordPress\AiClient\Events\BeforeGenerateResultEvent; +use WordPress\AiClient\Events\GenerateResultErrorEvent; use WordPress\AiClient\Files\DTO\File; use WordPress\AiClient\Files\Enums\FileTypeEnum; use WordPress\AiClient\Files\Enums\MediaOrientationEnum; @@ -26,11 +27,13 @@ use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface; use WordPress\AiClient\Providers\Models\SpeechGeneration\Contracts\SpeechGenerationModelInterface; +use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\StreamingTextGenerationModelInterface; use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\TextGenerationModelInterface; use WordPress\AiClient\Providers\Models\TextToSpeechConversion\Contracts\TextToSpeechConversionModelInterface; use WordPress\AiClient\Providers\Models\VideoGeneration\Contracts\VideoGenerationModelInterface; use WordPress\AiClient\Providers\ProviderRegistry; use WordPress\AiClient\Results\DTO\GenerativeAiResult; +use WordPress\AiClient\Results\StreamedGenerativeAiResult; use WordPress\AiClient\Tools\DTO\FunctionDeclaration; use WordPress\AiClient\Tools\DTO\FunctionResponse; use WordPress\AiClient\Tools\DTO\WebSearch; @@ -1048,6 +1051,65 @@ public function generateTextResult(): GenerativeAiResult return $this->generateResult(CapabilityEnum::textGeneration()); } + /** + * Streams a text result from the prompt. + * + * @since n.e.x.t + * + * @return StreamedGenerativeAiResult The streamed result. + * @throws InvalidArgumentException If the prompt or model validation fails. + * @throws RuntimeException If the model does not support streaming text generation. + */ + public function streamGenerateTextResult(): StreamedGenerativeAiResult + { + $this->includeOutputModalities(ModalityEnum::text()); + $this->validateMessages(); + + $capability = CapabilityEnum::textGeneration(); + $model = $this->getConfiguredModel($capability); + + if (!$model instanceof StreamingTextGenerationModelInterface) { + throw new RuntimeException( + sprintf( + 'Model "%s" does not support streaming text generation.', + $model->metadata()->getId() + ) + ); + } + + $messages = $this->messages; + + return $model->streamGenerateTextResult($messages) + ->onStart(function () use ($messages, $model, $capability): void { + $this->dispatchEvent(new BeforeGenerateResultEvent($messages, $model, $capability)); + }) + ->onComplete(function (GenerativeAiResult $result) use ($messages, $model, $capability): void { + $this->dispatchEvent(new AfterGenerateResultEvent($messages, $model, $capability, $result)); + }) + ->onError(function (\Throwable $error) use ($messages, $model, $capability): void { + $this->dispatchEvent(new GenerateResultErrorEvent($messages, $model, $capability, $error)); + }); + } + + /** + * Streams generated text from the prompt as it arrives. + * + * @since n.e.x.t + * + * @return iterable The text deltas, in order. + * @throws InvalidArgumentException If the prompt or model validation fails. + * @throws RuntimeException If the model does not support streaming text generation. + */ + public function streamGenerateText(): iterable + { + foreach ($this->streamGenerateTextResult() as $chunk) { + $delta = $chunk->getDeltaText(); + if ($delta !== '') { + yield $delta; + } + } + } + /** * Generates an image result from the prompt. * diff --git a/src/Events/GenerateResultErrorEvent.php b/src/Events/GenerateResultErrorEvent.php new file mode 100644 index 00000000..6d3ff499 --- /dev/null +++ b/src/Events/GenerateResultErrorEvent.php @@ -0,0 +1,122 @@ + The messages that were sent to the model. + */ + private array $messages; + + /** + * @var ModelInterface The model that processed the prompt. + */ + private ModelInterface $model; + + /** + * @var CapabilityEnum|null The capability that was used for generation. + */ + private ?CapabilityEnum $capability; + + /** + * @var Throwable The error that occurred during generation. + */ + private Throwable $error; + + /** + * Constructor. + * + * @since n.e.x.t + * + * @param list $messages The messages that were sent to the model. + * @param ModelInterface $model The model that processed the prompt. + * @param CapabilityEnum|null $capability The capability that was used for generation. + * @param Throwable $error The error that occurred during generation. + */ + public function __construct( + array $messages, + ModelInterface $model, + ?CapabilityEnum $capability, + Throwable $error + ) { + $this->messages = $messages; + $this->model = $model; + $this->capability = $capability; + $this->error = $error; + } + + /** + * Gets the messages that were sent to the model. + * + * @since n.e.x.t + * + * @return list The messages. + */ + public function getMessages(): array + { + return $this->messages; + } + + /** + * Gets the model that processed the prompt. + * + * @since n.e.x.t + * + * @return ModelInterface The model. + */ + public function getModel(): ModelInterface + { + return $this->model; + } + + /** + * Gets the capability that was used for generation. + * + * @since n.e.x.t + * + * @return CapabilityEnum|null The capability, or null if not specified. + */ + public function getCapability(): ?CapabilityEnum + { + return $this->capability; + } + + /** + * Gets the error that occurred during generation. + * + * @since n.e.x.t + * + * @return Throwable The error. + */ + public function getError(): Throwable + { + return $this->error; + } + + /** + * Performs a deep clone of the event. + * + * @since n.e.x.t + */ + public function __clone() + { + $clonedMessages = []; + foreach ($this->messages as $message) { + $clonedMessages[] = clone $message; + } + $this->messages = $clonedMessages; + } +} diff --git a/src/Providers/Http/DTO/RequestOptions.php b/src/Providers/Http/DTO/RequestOptions.php index 9b101bef..aefd8a49 100644 --- a/src/Providers/Http/DTO/RequestOptions.php +++ b/src/Providers/Http/DTO/RequestOptions.php @@ -17,7 +17,8 @@ * @phpstan-type RequestOptionsArrayShape array{ * timeout?: float|null, * connectTimeout?: float|null, - * maxRedirects?: int|null + * maxRedirects?: int|null, + * stream?: bool|null * } * * @extends AbstractDataTransferObject @@ -27,6 +28,7 @@ class RequestOptions extends AbstractDataTransferObject public const KEY_TIMEOUT = 'timeout'; public const KEY_CONNECT_TIMEOUT = 'connectTimeout'; public const KEY_MAX_REDIRECTS = 'maxRedirects'; + public const KEY_STREAM = 'stream'; /** * @var float|null Maximum duration in seconds to wait for the full response. @@ -43,6 +45,11 @@ class RequestOptions extends AbstractDataTransferObject */ protected ?int $maxRedirects = null; + /** + * @var bool|null Whether to request the response body as a stream. Null is unspecified. + */ + protected ?bool $stream = null; + /** * Sets the request timeout in seconds. * @@ -153,6 +160,31 @@ public function getMaxRedirects(): ?int return $this->maxRedirects; } + /** + * Sets whether to request the response body as a stream. + * + * @since n.e.x.t + * + * @param bool $stream Whether to stream the response body. + * @return void + */ + public function setStream(bool $stream): void + { + $this->stream = $stream; + } + + /** + * Gets whether the response body should be requested as a stream. + * + * @since n.e.x.t + * + * @return bool|null True to stream, false to buffer, or null when unspecified. + */ + public function isStream(): ?bool + { + return $this->stream; + } + /** * {@inheritDoc} * @@ -176,6 +208,10 @@ public function toArray(): array $data[self::KEY_MAX_REDIRECTS] = $this->maxRedirects; } + if ($this->stream !== null) { + $data[self::KEY_STREAM] = $this->stream; + } + return $data; } @@ -200,6 +236,10 @@ public static function fromArray(array $array): self $instance->setMaxRedirects((int) $array[self::KEY_MAX_REDIRECTS]); } + if (isset($array[self::KEY_STREAM])) { + $instance->setStream((bool) $array[self::KEY_STREAM]); + } + return $instance; } @@ -228,6 +268,10 @@ public static function getJsonSchema(): array 'minimum' => 0, 'description' => 'Maximum redirects to follow. 0 disables, null is unspecified.', ], + self::KEY_STREAM => [ + 'type' => ['boolean', 'null'], + 'description' => 'Whether to request the response body as a stream.', + ], ], 'additionalProperties' => false, ]; diff --git a/src/Providers/Http/DTO/Response.php b/src/Providers/Http/DTO/Response.php index 623ab770..14795616 100644 --- a/src/Providers/Http/DTO/Response.php +++ b/src/Providers/Http/DTO/Response.php @@ -4,6 +4,8 @@ namespace WordPress\AiClient\Providers\Http\DTO; +use Nyholm\Psr7\Stream; +use Psr\Http\Message\StreamInterface; use WordPress\AiClient\Common\AbstractDataTransferObject; use WordPress\AiClient\Common\Exception\InvalidArgumentException; use WordPress\AiClient\Providers\Http\Collections\HeadersCollection; @@ -41,9 +43,19 @@ class Response extends AbstractDataTransferObject protected HeadersCollection $headers; /** - * @var string|null The response body. + * @var string|null The response body as a string, once resolved. */ - protected ?string $body; + protected ?string $body = null; + + /** + * @var StreamInterface|null The response body stream, when the response is streamed. + */ + protected ?StreamInterface $stream = null; + + /** + * @var bool Whether the string body has been resolved from the stream. + */ + private bool $bodyResolved; /** * Constructor. @@ -52,11 +64,11 @@ class Response extends AbstractDataTransferObject * * @param int $statusCode The HTTP status code. * @param array> $headers The response headers. - * @param string|null $body The response body. + * @param string|StreamInterface|null $body The response body, as a string or a stream. * * @throws InvalidArgumentException If the status code is invalid. */ - public function __construct(int $statusCode, array $headers, ?string $body = null) + public function __construct(int $statusCode, array $headers, $body = null) { if ($statusCode < 100 || $statusCode >= 600) { throw new InvalidArgumentException('Invalid HTTP status code: ' . $statusCode); @@ -64,14 +76,24 @@ public function __construct(int $statusCode, array $headers, ?string $body = nul $this->statusCode = $statusCode; $this->headers = new HeadersCollection($headers); - $this->body = $body; + + if ($body instanceof StreamInterface) { + $this->stream = $body; + $this->bodyResolved = false; + } else { + $this->body = $body; + $this->bodyResolved = true; + } } /** - * Creates a deep clone of this response. + * Creates a copy of this response. * - * Clones the headers collection to ensure the cloned - * response is independent of the original. + * Headers are cloned so the new response can modify them independently of + * the original. + * + * The body stream is not cloned. Both responses share the same stream + * instance, so consuming it from one response also consumes it from the other. * * @since 0.4.2 */ @@ -132,17 +154,44 @@ public function getHeaderAsString(string $name): ?string } /** - * Gets the response body. + * Gets the response body as a string. + * + * When the response is streamed, this reads the stream to completion, which + * consumes it unless the stream is seekable. * * @since 0.1.0 * - * @return string|null The body. + * @return string|null The body, or null if empty. */ public function getBody(): ?string { + if (!$this->bodyResolved) { + $this->bodyResolved = true; + if ($this->stream !== null) { + $contents = $this->readStream($this->stream); + $this->body = $contents === '' ? null : $contents; + } + } + return $this->body; } + /** + * Gets the response body as a PSR-7 stream. + * + * @since n.e.x.t + * + * @return StreamInterface The body stream. + */ + public function getStream(): StreamInterface + { + if ($this->stream !== null) { + return $this->stream; + } + + return Stream::create($this->body ?? ''); + } + /** * Checks if the response has a header. * @@ -181,11 +230,12 @@ public function isSuccessful(): bool */ public function getData(): ?array { - if ($this->body === null || $this->body === '') { + $body = $this->getBody(); + if ($body === null || $body === '') { return null; } - $data = json_decode($this->body, true); + $data = json_decode($body, true); if (json_last_error() !== JSON_ERROR_NONE) { return null; @@ -231,6 +281,9 @@ public static function getJsonSchema(): array /** * {@inheritDoc} * + * When the response is streamed, this reads the stream + * to serialize the body. + * * @since 0.1.0 * * @return ResponseArrayShape @@ -242,8 +295,9 @@ public function toArray(): array self::KEY_HEADERS => $this->headers->getAll(), ]; - if ($this->body !== null) { - $data[self::KEY_BODY] = $this->body; + $body = $this->getBody(); + if ($body !== null) { + $data[self::KEY_BODY] = $body; } return $data; @@ -267,4 +321,21 @@ public static function fromArray(array $array): self $array[self::KEY_BODY] ?? null ); } + + /** + * Reads a stream to a string, rewinding first when possible. + * + * @since n.e.x.t + * + * @param StreamInterface $stream The stream to read. + * @return string The stream contents. + */ + private function readStream(StreamInterface $stream): string + { + if ($stream->isSeekable()) { + $stream->rewind(); + } + + return $stream->getContents(); + } } diff --git a/src/Providers/Http/Exception/ResponseException.php b/src/Providers/Http/Exception/ResponseException.php index 01e2bd70..5c90ddd2 100644 --- a/src/Providers/Http/Exception/ResponseException.php +++ b/src/Providers/Http/Exception/ResponseException.php @@ -47,4 +47,23 @@ public static function fromInvalidData(string $apiName, string $fieldName, strin { return new self(sprintf('Unexpected %s API response: Invalid "%s" key: %s', $apiName, $fieldName, $message)); } + + /** + * Creates a ResponseException for an error encountered while streaming a response. + * + * @since n.e.x.t + * + * @param string $apiName The name of the API/provider. + * @param string $message The error message. + * @param \Throwable|null $previous The underlying exception, when wrapping one. + * @return self The exception describing the stream error. + */ + public static function fromStreamError(string $apiName, string $message, ?\Throwable $previous = null): self + { + return new self( + sprintf('Error while streaming the %s API response: %s', $apiName, $message), + 0, + $previous + ); + } } diff --git a/src/Providers/Http/HttpTransporter.php b/src/Providers/Http/HttpTransporter.php index 00917cec..f8a63188 100644 --- a/src/Providers/Http/HttpTransporter.php +++ b/src/Providers/Http/HttpTransporter.php @@ -101,7 +101,9 @@ public function send(Request $request, ?RequestOptions $options = null): Respons ); } - return $this->convertFromPsr7Response($psr7Response); + $streaming = $mergedOptions !== null && $mergedOptions->isStream() === true; + + return $this->convertFromPsr7Response($psr7Response, $streaming); } /** @@ -145,6 +147,11 @@ private function mergeOptions(?RequestOptions $requestOptions, ?RequestOptions $ $merged->setMaxRedirects($requestOptions->getMaxRedirects()); } + $requestStream = $requestOptions->isStream(); + if ($requestStream !== null) { + $merged->setStream($requestStream); + } + // Override with parameter options (higher precedence) if ($parameterOptions->getTimeout() !== null) { $merged->setTimeout($parameterOptions->getTimeout()); @@ -158,6 +165,11 @@ private function mergeOptions(?RequestOptions $requestOptions, ?RequestOptions $ $merged->setMaxRedirects($parameterOptions->getMaxRedirects()); } + $parameterStream = $parameterOptions->isStream(); + if ($parameterStream !== null) { + $merged->setStream($parameterStream); + } + return $merged; } @@ -270,6 +282,10 @@ private function buildGuzzleOptions(RequestOptions $options): array } } + if ($options->isStream() === true) { + $guzzleOptions['stream'] = true; + } + return $guzzleOptions; } @@ -313,14 +329,25 @@ private function convertToPsr7Request(Request $request): RequestInterface * @param ResponseInterface $psr7Response The PSR-7 response. * @return Response The custom response. */ - private function convertFromPsr7Response(ResponseInterface $psr7Response): Response + private function convertFromPsr7Response(ResponseInterface $psr7Response, bool $stream): Response { + /** + * PSR-7 always returns headers as arrays, but HeadersCollection handles this. + * + * @var array> $headers + */ + $headers = $psr7Response->getHeaders(); + + // When streaming, pass the body stream through so it is consumed lazily. + if ($stream) { + return new Response($psr7Response->getStatusCode(), $headers, $psr7Response->getBody()); + } + $body = (string) $psr7Response->getBody(); - // PSR-7 always returns headers as arrays, but HeadersCollection handles this return new Response( $psr7Response->getStatusCode(), - $psr7Response->getHeaders(), // @phpstan-ignore-line + $headers, $body === '' ? null : $body ); } diff --git a/src/Providers/Http/Streaming/Contracts/EventStreamParserInterface.php b/src/Providers/Http/Streaming/Contracts/EventStreamParserInterface.php new file mode 100644 index 00000000..b67f38ba --- /dev/null +++ b/src/Providers/Http/Streaming/Contracts/EventStreamParserInterface.php @@ -0,0 +1,26 @@ + The decoded events. + */ + public function parse(StreamInterface $stream): iterable; +} diff --git a/src/Providers/Http/Streaming/SseEventStreamParser.php b/src/Providers/Http/Streaming/SseEventStreamParser.php new file mode 100644 index 00000000..c9dd0e95 --- /dev/null +++ b/src/Providers/Http/Streaming/SseEventStreamParser.php @@ -0,0 +1,253 @@ + The decoded events. + */ + public function parse(StreamInterface $stream): iterable + { + $event = ''; + $data = ''; + $lastId = ''; + $retry = null; + $hasData = false; + + try { + foreach ($this->toLines($stream) as $line) { + // A blank line ends the event. + if ($line === '') { + if ($hasData) { + yield $this->createEvent($event, $data, $lastId, $retry); + } + $event = ''; + $data = ''; + $retry = null; + $hasData = false; + // The last ID persists across events, so it is not reset. + continue; + } + + // Skip comment lines. + if ($line[0] === ':') { + continue; + } + + $colon = strpos($line, ':'); + if ($colon === false) { + $field = $line; + $value = ''; + } else { + $field = (string) substr($line, 0, $colon); + $value = (string) substr($line, $colon + 1); + // Strip one leading space from the value. + if (isset($value[0]) && $value[0] === ' ') { + $value = (string) substr($value, 1); + } + } + + switch ($field) { + case 'event': + $event = $value; + break; + case 'data': + $data .= $value . "\n"; + $hasData = true; + break; + case 'id': + // Ignore IDs that contain a NUL byte. + if (strpos($value, "\0") === false) { + $lastId = $value; + } + break; + case 'retry': + if ($value !== '' && ctype_digit($value)) { + $retry = (int) $value; + } + break; + default: + break; + } + } + + /* + * Per the spec: + * Once the end of the file is reached, any pending data must be discarded. (If the file ends + * in the middle of an event, before the final empty line, the incomplete event is not dispatched.) + * + * @see https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation + */ + } finally { + $stream->close(); + } + } + + /** + * Builds an event from the accumulated field state. + * + * @since n.e.x.t + * + * @param string $event The accumulated event name. + * @param string $data The accumulated data buffer (newline-joined). + * @param string $id The current last event ID. + * @param int|null $retry The current reconnection time. + * @return ServerSentEvent The event. + */ + private function createEvent(string $event, string $data, string $id, ?int $retry): ServerSentEvent + { + return new ServerSentEvent( + $event !== '' ? $event : 'message', + (string) substr($data, 0, -1), // data always ends with a newline, so drop it. + $id, + $retry + ); + } + + /** + * Reads the stream and yields complete lines as they become available. + * + * Buffers partial lines across reads, strips a leading BOM, supports the + * `\n`, `\r\n`, and `\r` terminators (including a `\r\n` split across reads), + * and emits any trailing unterminated content once the stream ends. + * + * @since n.e.x.t + * + * @param StreamInterface $stream The response body stream. + * @return \Generator Complete lines, without terminators. + */ + private function toLines(StreamInterface $stream): \Generator + { + $buffer = ''; + $bomChecked = false; + + while (!$stream->eof()) { + $chunk = $stream->read(self::READ_CHUNK_BYTES); + if ($chunk === '') { + continue; + } + + $buffer .= $chunk; + + if (!$bomChecked) { + if (strncmp($buffer, self::BOM, 3) === 0) { + $buffer = substr($buffer, 3); + $bomChecked = true; + } elseif (strlen($buffer) >= 3) { + $bomChecked = true; + } elseif ($buffer === substr(self::BOM, 0, strlen($buffer))) { + // Might be a partial BOM, wait for more bytes. + continue; + } else { + $bomChecked = true; + } + } + + [$lines, $buffer] = $this->extractLines($buffer, false); + foreach ($lines as $line) { + yield $line; + } + } + + [$lines] = $this->extractLines($buffer, true); + foreach ($lines as $line) { + yield $line; + } + } + + /** + * Splits a buffer into complete lines and the unconsumed remainder. + * + * When not at end of stream, a trailing lone `\r` is held back (it may be the + * first half of a `\r\n` arriving next) along with any final unterminated + * line. At end of stream, a trailing `\r` terminates a line and any remaining + * content is emitted as a final line. + * + * @since n.e.x.t + * + * @param string $buffer The byte buffer. + * @param bool $atEof Whether the stream has ended. + * @return array{0: list, 1: string} The complete lines and the remainder. + */ + private function extractLines(string $buffer, bool $atEof): array + { + $lines = []; + $len = strlen($buffer); + $start = 0; + $i = 0; + + while ($i < $len) { + $c = $buffer[$i]; + + if ($c === "\n") { + $lines[] = (string) substr($buffer, $start, $i - $start); + $i++; + $start = $i; + } elseif ($c === "\r") { + if ($i + 1 < $len) { + $lines[] = (string) substr($buffer, $start, $i - $start); + $i += ($buffer[$i + 1] === "\n") ? 2 : 1; + $start = $i; + } elseif ($atEof) { + $lines[] = (string) substr($buffer, $start, $i - $start); + $i++; + $start = $i; + } else { + // A trailing CR might start a split CRLF, so hold it. + break; + } + } else { + $i++; + } + } + + $remaining = (string) substr($buffer, $start); + if ($atEof && $remaining !== '') { + $lines[] = $remaining; + $remaining = ''; + } + + return [$lines, $remaining]; + } +} diff --git a/src/Providers/Http/Streaming/ValueObjects/ServerSentEvent.php b/src/Providers/Http/Streaming/ValueObjects/ServerSentEvent.php new file mode 100644 index 00000000..ecbdd87f --- /dev/null +++ b/src/Providers/Http/Streaming/ValueObjects/ServerSentEvent.php @@ -0,0 +1,104 @@ +event = $event; + $this->data = $data; + $this->id = $id; + $this->retry = $retry; + } + + /** + * Gets the event name. + * + * @since n.e.x.t + * + * @return string The event name ("message" when unspecified). + */ + public function getEvent(): string + { + return $this->event; + } + + /** + * Gets the event payload. + * + * @since n.e.x.t + * + * @return string The payload. + */ + public function getData(): string + { + return $this->data; + } + + /** + * Gets the last event ID. + * + * @since n.e.x.t + * + * @return string The event ID, or an empty string when none was set. + */ + public function getId(): string + { + return $this->id; + } + + /** + * Gets the reconnection time in milliseconds, if the event set one. + * + * Parsed for spec completeness only. The SDK does not reconnect: provider + * streams are one-shot and cannot be resumed, so this value is informational. + * + * @since n.e.x.t + * + * @return int|null The reconnection time, or null when none was set. + */ + public function getRetry(): ?int + { + return $this->retry; + } +} diff --git a/src/Providers/Models/TextGeneration/Contracts/StreamingTextGenerationModelInterface.php b/src/Providers/Models/TextGeneration/Contracts/StreamingTextGenerationModelInterface.php new file mode 100644 index 00000000..bfd58589 --- /dev/null +++ b/src/Providers/Models/TextGeneration/Contracts/StreamingTextGenerationModelInterface.php @@ -0,0 +1,26 @@ + $prompt Array of messages containing the text generation prompt. + * @return StreamedGenerativeAiResult The streamed result. + */ + public function streamGenerateTextResult(array $prompt): StreamedGenerativeAiResult; +} diff --git a/src/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModel.php b/src/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModel.php index e0e2e71b..ebf9282c 100644 --- a/src/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModel.php +++ b/src/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModel.php @@ -4,6 +4,7 @@ namespace WordPress\AiClient\Providers\OpenAiCompatibleImplementation; +use Generator; use WordPress\AiClient\Common\Exception\InvalidArgumentException; use WordPress\AiClient\Common\Exception\RuntimeException; use WordPress\AiClient\Messages\DTO\Message; @@ -13,15 +14,23 @@ use WordPress\AiClient\Messages\Enums\ModalityEnum; use WordPress\AiClient\Providers\ApiBasedImplementation\AbstractApiBasedModel; use WordPress\AiClient\Providers\Http\DTO\Request; +use WordPress\AiClient\Providers\Http\DTO\RequestOptions; use WordPress\AiClient\Providers\Http\DTO\Response; use WordPress\AiClient\Providers\Http\Enums\HttpMethodEnum; use WordPress\AiClient\Providers\Http\Exception\ResponseException; +use WordPress\AiClient\Providers\Http\Streaming\Contracts\EventStreamParserInterface; +use WordPress\AiClient\Providers\Http\Streaming\SseEventStreamParser; use WordPress\AiClient\Providers\Http\Util\ResponseUtil; +use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\StreamingTextGenerationModelInterface; use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\TextGenerationModelInterface; use WordPress\AiClient\Results\DTO\Candidate; use WordPress\AiClient\Results\DTO\GenerativeAiResult; use WordPress\AiClient\Results\DTO\TokenUsage; use WordPress\AiClient\Results\Enums\FinishReasonEnum; +use WordPress\AiClient\Results\StreamedGenerativeAiResult; +use WordPress\AiClient\Results\ValueObjects\CandidateDelta; +use WordPress\AiClient\Results\ValueObjects\GenerativeAiResultChunk; +use WordPress\AiClient\Results\ValueObjects\ToolCallDelta; use WordPress\AiClient\Tools\DTO\FunctionCall; use WordPress\AiClient\Tools\DTO\FunctionDeclaration; @@ -55,16 +64,46 @@ * @phpstan-type UsageData array{ * prompt_tokens?: int, * completion_tokens?: int, - * total_tokens?: int + * total_tokens?: int, + * completion_tokens_details?: array{ + * reasoning_tokens?: int + * } * } * @phpstan-type ResponseData array{ * id?: string, * choices?: list, * usage?: UsageData * } + * @phpstan-type StreamToolCallData array{ + * index?: int, + * id?: string, + * type?: string, + * function?: array{ + * name?: string, + * arguments?: string + * } + * } + * @phpstan-type StreamDeltaData array{ + * role?: string, + * reasoning_content?: string, + * reasoning?: string, + * content?: string, + * tool_calls?: list + * } + * @phpstan-type StreamChoiceData array{ + * index?: int, + * delta?: StreamDeltaData, + * finish_reason?: string|null + * } + * @phpstan-type StreamEventData array{ + * id?: string, + * choices?: list, + * usage?: UsageData + * } */ abstract class AbstractOpenAiCompatibleTextGenerationModel extends AbstractApiBasedModel implements - TextGenerationModelInterface + TextGenerationModelInterface, + StreamingTextGenerationModelInterface { /** * {@inheritDoc} @@ -93,6 +132,288 @@ final public function generateTextResult(array $prompt): GenerativeAiResult return $this->parseResponseToGenerativeAiResult($response); } + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public function streamGenerateTextResult(array $prompt): StreamedGenerativeAiResult + { + return new StreamedGenerativeAiResult( + $this->streamTextChunks($prompt), + $this->providerMetadata(), + $this->metadata() + ); + } + + /** + * Sends the request and yields result chunks as the event stream arrives. + * + * @since n.e.x.t + * + * @param list $prompt The prompt to generate text for. + * @return Generator The result chunks as they arrive. + */ + private function streamTextChunks(array $prompt): Generator + { + $httpTransporter = $this->getHttpTransporter(); + + $params = $this->prepareGenerateTextParams($prompt); + $params['stream'] = true; + $params['stream_options'] = ['include_usage' => true]; + + $request = $this->createRequest( + HttpMethodEnum::POST(), + 'chat/completions', + ['Content-Type' => 'application/json'], + $params + ); + + $request = $this->getRequestAuthentication()->authenticateRequest($request); + + // Tell the transporter to stream the response body instead of buffering it. + $streamOptions = new RequestOptions(); + $streamOptions->setStream(true); + + $response = $httpTransporter->send($request, $streamOptions); + $this->throwIfNotSuccessful($response); + + try { + foreach ($this->getEventStreamParser()->parse($response->getStream()) as $event) { + $data = $event->getData(); + if ($data === '' || $data === '[DONE]') { + continue; + } + + $decoded = json_decode($data, true); + if (!is_array($decoded)) { + continue; + } + + $this->throwIfStreamError($decoded); + + /** @var StreamEventData $decoded */ + $chunk = $this->parseStreamEvent($decoded); + if ($chunk !== null) { + yield $chunk; + } + } + } catch (ResponseException $e) { + throw $e; + } catch (\RuntimeException $e) { + throw ResponseException::fromStreamError($this->providerMetadata()->getName(), $e->getMessage(), $e); + } + } + + /** + * Throws if a decoded stream event reports a provider error. + * + * @since n.e.x.t + * + * @param array $event The decoded stream event. + * @return void + * + * @throws ResponseException If the event carries an error payload. + */ + protected function throwIfStreamError(array $event): void + { + if (!isset($event['error'])) { + return; + } + + $error = $event['error']; + $message = is_array($error) && isset($error['message']) && is_string($error['message']) + ? $error['message'] + : 'The provider reported an error.'; + + throw ResponseException::fromStreamError($this->providerMetadata()->getName(), $message); + } + + /** + * Maps one decoded stream event into a result chunk. + * + * @since n.e.x.t + * + * @param StreamEventData $data The decoded event payload. + * @return GenerativeAiResultChunk|null The chunk, or null when the event carries nothing usable. + */ + protected function parseStreamEvent(array $data): ?GenerativeAiResultChunk + { + $id = isset($data['id']) && is_string($data['id']) ? $data['id'] : null; + $tokenUsage = isset($data['usage']) && is_array($data['usage']) + ? $this->parseUsageData($data['usage']) + : null; + $additionalData = $this->extractAdditionalData($data); + $choices = isset($data['choices']) && is_array($data['choices']) ? $data['choices'] : []; + + $candidateDeltas = []; + foreach ($choices as $choice) { + if (!is_array($choice)) { + continue; + } + $candidateDeltas[] = $this->parseStreamChoice($choice); + } + + // Skip events that carry no metadata and no candidate deltas. + if ($id === null && $tokenUsage === null && $additionalData === [] && $candidateDeltas === []) { + return null; + } + + return new GenerativeAiResultChunk($id, $tokenUsage, $additionalData, $candidateDeltas); + } + + /** + * Maps one streamed choice into a candidate delta. + * + * @since n.e.x.t + * + * @param StreamChoiceData $choice The choice payload from the event. + * @return CandidateDelta The parsed candidate delta. + */ + protected function parseStreamChoice(array $choice): CandidateDelta + { + $index = isset($choice['index']) && is_int($choice['index']) ? $choice['index'] : 0; + $delta = isset($choice['delta']) && is_array($choice['delta']) ? $choice['delta'] : []; + $finishReason = isset($choice['finish_reason']) && is_string($choice['finish_reason']) + ? FinishReasonEnum::tryFrom($choice['finish_reason']) + : null; + + return new CandidateDelta( + $index, + $this->parseStreamDeltaParts($delta), + $finishReason, + $this->parseStreamToolCallDeltas($delta) + ); + } + + /** + * Maps a streamed delta into message parts. + * + * @since n.e.x.t + * + * @param StreamDeltaData $delta The delta payload from a choice. + * @return list The parsed message parts. + */ + protected function parseStreamDeltaParts(array $delta): array + { + $parts = []; + + if (isset($delta['reasoning_content']) && is_string($delta['reasoning_content'])) { + $parts[] = new MessagePart($delta['reasoning_content'], MessagePartChannelEnum::thought()); + } + + if (isset($delta['reasoning']) && is_string($delta['reasoning'])) { + $parts[] = new MessagePart($delta['reasoning'], MessagePartChannelEnum::thought()); + } + + if (isset($delta['content']) && is_string($delta['content'])) { + $parts[] = new MessagePart($delta['content']); + } + + return $parts; + } + + /** + * Maps a streamed delta's tool calls into tool call fragments. + * + * @since n.e.x.t + * + * @param StreamDeltaData $delta The delta payload from a choice. + * @return list The parsed tool call fragments. + */ + protected function parseStreamToolCallDeltas(array $delta): array + { + if (!isset($delta['tool_calls']) || !is_array($delta['tool_calls'])) { + return []; + } + + $deltas = []; + foreach ($delta['tool_calls'] as $position => $toolCall) { + if (!is_array($toolCall)) { + continue; + } + + // Providers key parallel tool calls by "index"; fall back to position. + $slot = isset($toolCall['index']) && is_int($toolCall['index']) + ? $toolCall['index'] + : (int) $position; + + $id = isset($toolCall['id']) && is_string($toolCall['id']) ? $toolCall['id'] : null; + + $function = isset($toolCall['function']) && is_array($toolCall['function']) + ? $toolCall['function'] + : []; + $name = isset($function['name']) && is_string($function['name']) ? $function['name'] : null; + $arguments = isset($function['arguments']) && is_string($function['arguments']) + ? $function['arguments'] + : ''; + + $deltas[] = new ToolCallDelta($slot, $id, $name, $arguments); + } + + return $deltas; + } + + /** + * Parses usage data into a token usage object. + * + * @since n.e.x.t + * + * @param UsageData $usage The usage payload. + * @return TokenUsage The parsed token usage. + */ + protected function parseUsageData(array $usage): TokenUsage + { + $reasoningTokens = $usage['completion_tokens_details']['reasoning_tokens'] ?? null; + + return new TokenUsage( + $this->toIntOrZero($usage['prompt_tokens'] ?? 0), + $this->toIntOrZero($usage['completion_tokens'] ?? 0), + $this->toIntOrZero($usage['total_tokens'] ?? 0), + is_numeric($reasoningTokens) ? (int) $reasoningTokens : null + ); + } + + /** + * Coerces an untrusted usage value to an integer, defaulting to zero. + * + * @since n.e.x.t + * + * @param mixed $value The raw value from the decoded usage payload. + * @return int The coerced integer, or 0 when the value is not numeric. + */ + private function toIntOrZero($value): int + { + return is_numeric($value) ? (int) $value : 0; + } + + /** + * Extracts provider-specific metadata from a response or stream event. + * + * @since n.e.x.t + * + * @param array $data The decoded response or event payload. + * @return array The remaining provider metadata. + */ + protected function extractAdditionalData(array $data): array + { + unset($data['id'], $data['choices'], $data['usage']); + + return $data; + } + + /** + * Returns the parser used to decode the streamed event stream. + * + * @since n.e.x.t + * + * @return EventStreamParserInterface The event stream parser. + */ + protected function getEventStreamParser(): EventStreamParserInterface + { + return new SseEventStreamParser(); + } + /** * Prepares the given prompt and the model configuration into parameters for the API request. * @@ -601,21 +922,12 @@ protected function parseResponseToGenerativeAiResult(Response $response): Genera $id = isset($responseData['id']) && is_string($responseData['id']) ? $responseData['id'] : ''; - if (isset($responseData['usage']) && is_array($responseData['usage'])) { - $usage = $responseData['usage']; - - $tokenUsage = new TokenUsage( - $usage['prompt_tokens'] ?? 0, - $usage['completion_tokens'] ?? 0, - $usage['total_tokens'] ?? 0 - ); - } else { - $tokenUsage = new TokenUsage(0, 0, 0); - } + $tokenUsage = isset($responseData['usage']) && is_array($responseData['usage']) + ? $this->parseUsageData($responseData['usage']) + : new TokenUsage(0, 0, 0); // Use any other data from the response as provider-specific response metadata. - $additionalData = $responseData; - unset($additionalData['id'], $additionalData['choices'], $additionalData['usage']); + $additionalData = $this->extractAdditionalData($responseData); return new GenerativeAiResult( $id, diff --git a/src/Results/ChunkAccumulator.php b/src/Results/ChunkAccumulator.php new file mode 100644 index 00000000..ed3937e8 --- /dev/null +++ b/src/Results/ChunkAccumulator.php @@ -0,0 +1,380 @@ + Canonical channel order for assembled parts, matching the buffered parser. + */ + private const CANONICAL_CHANNEL_ORDER = [ + MessagePartChannelEnum::THOUGHT, + MessagePartChannelEnum::CONTENT, + ]; + + private ProviderMetadata $providerMetadata; + + private ModelMetadata $modelMetadata; + + /** + * @var string|null The result id, captured from the first chunk that carries one. + */ + private ?string $id = null; + + /** + * @var TokenUsage|null The token usage, captured from the chunk that carries it. + */ + private ?TokenUsage $tokenUsage = null; + + /** + * @var array Merged result-level provider metadata. + */ + private array $additionalData = []; + + /** + * @var array Candidate indices seen while accumulating. + */ + private array $candidates = []; + + /** + * @var array> Accumulated text, per candidate, per channel. + */ + private array $text = []; + + /** + * @var array> Thought signature, per candidate, per channel. + */ + private array $thoughtSignatures = []; + + /** + * @var array> Non-text parts, per candidate, in arrival order. + */ + private array $otherParts = []; + + /** + * @var array Finish reason, per candidate. + */ + private array $finishReasons = []; + + /** + * Tool call slots being stitched together, per candidate, per slot index. + * + * @var array> + */ + private array $toolCalls = []; + + /** + * Constructor. + * + * @since n.e.x.t + * + * @param ProviderMetadata $providerMetadata Provider metadata for the assembled result. + * @param ModelMetadata $modelMetadata Model metadata for the assembled result. + */ + public function __construct(ProviderMetadata $providerMetadata, ModelMetadata $modelMetadata) + { + $this->providerMetadata = $providerMetadata; + $this->modelMetadata = $modelMetadata; + } + + /** + * Folds a single chunk into the accumulated state. + * + * @since n.e.x.t + * + * @param GenerativeAiResultChunk $chunk The chunk to fold in. + * @return void + */ + public function add(GenerativeAiResultChunk $chunk): void + { + $id = $chunk->getId(); + if ($id !== null && $this->id === null) { + $this->id = $id; + } + + $tokenUsage = $chunk->getTokenUsage(); + if ($tokenUsage !== null) { + $this->tokenUsage = $tokenUsage; + } + + $additionalData = $chunk->getAdditionalData(); + if ($additionalData !== []) { + $this->additionalData = array_merge($this->additionalData, $additionalData); + } + + foreach ($chunk->getCandidateDeltas() as $candidateDelta) { + $this->addCandidateDelta($candidateDelta); + } + } + + /** + * Reports whether any candidate has been accumulated. + * + * @since n.e.x.t + * + * @return bool True if there is at least one candidate to build. + */ + public function hasCandidates(): bool + { + return $this->candidates !== []; + } + + /** + * Assembles the accumulated state into a result. + * + * @since n.e.x.t + * + * @return GenerativeAiResult The assembled result. + * @throws RuntimeException If no candidates were accumulated. + */ + public function build(): GenerativeAiResult + { + if ($this->candidates === []) { + throw new RuntimeException('The stream produced no candidates.'); + } + + $indices = array_keys($this->candidates); + sort($indices); + + $candidates = []; + foreach ($indices as $index) { + $candidates[] = $this->buildCandidate($index); + } + + return new GenerativeAiResult( + $this->id ?? '', + $candidates, + $this->tokenUsage ?? new TokenUsage(0, 0, 0), + $this->providerMetadata, + $this->modelMetadata, + $this->additionalData + ); + } + + /** + * Folds a candidate delta into the per-candidate state. + * + * @since n.e.x.t + * + * @param CandidateDelta $delta The candidate delta to fold in. + * @return void + */ + private function addCandidateDelta(CandidateDelta $delta): void + { + $index = $delta->getIndex(); + $this->candidates[$index] = true; + + $finishReason = $delta->getFinishReason(); + if ($finishReason !== null) { + $this->finishReasons[$index] = $finishReason; + } + + foreach ($delta->getParts() as $part) { + $this->addPart($index, $part); + } + + foreach ($delta->getToolCallDeltas() as $toolCallDelta) { + $this->addToolCallDelta($index, $toolCallDelta); + } + } + + /** + * Folds a content part into the candidate state. + * + * @since n.e.x.t + * + * @param int $index The candidate index. + * @param MessagePart $part The part to fold in. + * @return void + */ + private function addPart(int $index, MessagePart $part): void + { + $text = $part->getText(); + if ($text === null) { + $this->otherParts[$index][] = $part; + return; + } + + $channel = $part->getChannel()->value; + if (!isset($this->text[$index][$channel])) { + $this->text[$index][$channel] = ''; + } + $this->text[$index][$channel] .= $text; + + $signature = $part->getThoughtSignature(); + if ($signature !== null) { + $this->thoughtSignatures[$index][$channel] = $signature; + } + } + + /** + * Stores a tool call fragment into the candidate's tool call slots. + * + * @since n.e.x.t + * + * @param int $index The candidate index. + * @param ToolCallDelta $delta The tool call fragment. + * @return void + */ + private function addToolCallDelta(int $index, ToolCallDelta $delta): void + { + $slot = $delta->getIndex() ?? 0; + + if (!isset($this->toolCalls[$index][$slot])) { + $this->toolCalls[$index][$slot] = ['id' => null, 'name' => null, 'args' => '']; + } + + $id = $delta->getId(); + if ($id !== null && $this->toolCalls[$index][$slot]['id'] === null) { + $this->toolCalls[$index][$slot]['id'] = $id; + } + + $name = $delta->getFunctionName(); + if ($name !== null && $this->toolCalls[$index][$slot]['name'] === null) { + $this->toolCalls[$index][$slot]['name'] = $name; + } + + $this->toolCalls[$index][$slot]['args'] .= $delta->getArgumentsFragment(); + } + + /** + * Builds a single candidate from its accumulated state. + * + * @since n.e.x.t + * + * @param int $index The candidate index. + * @return Candidate The assembled candidate. + */ + private function buildCandidate(int $index): Candidate + { + $parts = []; + + foreach ($this->orderedChannels($index) as $channel) { + $parts[] = new MessagePart( + $this->text[$index][$channel], + MessagePartChannelEnum::from($channel), + $this->thoughtSignatures[$index][$channel] ?? null + ); + } + + foreach ($this->otherParts[$index] ?? [] as $part) { + $parts[] = $part; + } + + // Tool calls last, matching the non-streamed response part order. + foreach ($this->buildToolCallParts($index) as $part) { + $parts[] = $part; + } + + $message = new Message(MessageRoleEnum::model(), $parts); + $finishReason = $this->finishReasons[$index] ?? FinishReasonEnum::stop(); + + return new Candidate($message, $finishReason); + } + + /** + * Returns the candidate's text channels in canonical order, with any unknown channel last. + * + * @since n.e.x.t + * + * @param int $index The candidate index. + * @return list The present channels, ordered. + */ + private function orderedChannels(int $index): array + { + $present = array_keys($this->text[$index] ?? []); + + $ordered = []; + foreach (self::CANONICAL_CHANNEL_ORDER as $channel) { + if (in_array($channel, $present, true)) { + $ordered[] = $channel; + } + } + + foreach ($present as $channel) { + if (!in_array($channel, self::CANONICAL_CHANNEL_ORDER, true)) { + // Add any unknown channel last, in arrival order in case the provider sends multiple unknown channels. + // @codeCoverageIgnoreStart + $ordered[] = $channel; + // @codeCoverageIgnoreEnd + } + } + + return $ordered; + } + + /** + * Assembles the stored tool call slots for a candidate into message parts. + * + * @since n.e.x.t + * + * @param int $index The candidate index. + * @return list The assembled function call parts, in slot order. + */ + private function buildToolCallParts(int $index): array + { + if (!isset($this->toolCalls[$index])) { + return []; + } + + $slots = $this->toolCalls[$index]; + ksort($slots); + + $parts = []; + foreach ($slots as $slot) { + // A function call needs at least an id or a name; skip a slot that received + // neither (a malformed stream) rather than failing the whole result. + if ($slot['id'] === null && $slot['name'] === null) { + continue; + } + + $parts[] = new MessagePart( + new FunctionCall($slot['id'], $slot['name'], $this->decodeToolCallArgs($slot['args'])) + ); + } + + return $parts; + } + + /** + * Decodes accumulated tool call arguments. + * + * @since n.e.x.t + * + * @param string $arguments The accumulated arguments string. + * @return mixed The decoded arguments, the raw string on failure, or null when empty. + */ + private function decodeToolCallArgs(string $arguments) + { + if ($arguments === '') { + return null; + } + + $decoded = json_decode($arguments, true); + + return json_last_error() === JSON_ERROR_NONE ? $decoded : $arguments; + } +} diff --git a/src/Results/StreamedGenerativeAiResult.php b/src/Results/StreamedGenerativeAiResult.php new file mode 100644 index 00000000..03bbdc2a --- /dev/null +++ b/src/Results/StreamedGenerativeAiResult.php @@ -0,0 +1,269 @@ + + */ +final class StreamedGenerativeAiResult implements IteratorAggregate +{ + /** + * @var Iterator The source chunk stream. + */ + private Iterator $chunks; + + private ChunkAccumulator $accumulator; + + /** + * @var list Callbacks run once when consumption begins. + */ + private array $startCallbacks = []; + + /** + * @var list Callbacks run once when the result is assembled. + */ + private array $completionCallbacks = []; + + /** + * @var list Callbacks run once when consumption fails. + */ + private array $errorCallbacks = []; + + /** + * @var bool Whether the source stream has been started. + */ + private bool $started = false; + + /** + * @var bool Whether the source stream has been fully read. + */ + private bool $finished = false; + + /** + * @var GenerativeAiResult|null The assembled result, once built. + */ + private ?GenerativeAiResult $result = null; + + /** + * @var Throwable|null The error that ended consumption, once it failed. + */ + private ?Throwable $error = null; + + /** + * Constructor. + * + * @since n.e.x.t + * + * @param Iterator $chunks The source chunk stream. + * @param ProviderMetadata $providerMetadata Provider metadata for the assembled result. + * @param ModelMetadata $modelMetadata Model metadata for the assembled result. + */ + public function __construct(Iterator $chunks, ProviderMetadata $providerMetadata, ModelMetadata $modelMetadata) + { + $this->chunks = $chunks; + $this->accumulator = new ChunkAccumulator($providerMetadata, $modelMetadata); + } + + /** + * Registers a callback to run once, when consumption begins. + * + * @since n.e.x.t + * + * @param callable(): void $callback The callback. + * @return self + */ + public function onStart(callable $callback): self + { + $this->startCallbacks[] = $callback; + + return $this; + } + + /** + * Registers a callback to run once, when the final result is first assembled. + * + * @since n.e.x.t + * + * @param callable(GenerativeAiResult): void $callback Receives the assembled result. + * @return self + */ + public function onComplete(callable $callback): self + { + $this->completionCallbacks[] = $callback; + + return $this; + } + + /** + * Registers a callback to run once, when consumption fails. + * + * @since n.e.x.t + * + * @param callable(Throwable): void $callback Receives the error. + * @return self + */ + public function onError(callable $callback): self + { + $this->errorCallbacks[] = $callback; + + return $this; + } + + /** + * Yields each chunk as it is read, folding it into the accumulated state. + * + * @since n.e.x.t + * + * @return Generator The chunks, in order. + * + * @throws RuntimeException If the source stream has already been consumed. + */ + public function getIterator(): Generator + { + if ($this->started) { + throw new RuntimeException( + 'This streamed result has already been consumed; the stream can be read only once.' + ); + } + + try { + while (true) { + $chunk = $this->pull(); + if ($chunk === null) { + break; + } + yield $chunk; + } + + $this->finalize(); + } catch (Throwable $e) { + $this->fail($e); + throw $e; + } + } + + /** + * Returns the complete result, draining any unread chunks first. + * + * @since n.e.x.t + * + * @return GenerativeAiResult The assembled result. + * + * @throws Throwable The original stream error, if consumption failed. + * @throws RuntimeException If the stream produced no candidates. + */ + public function getFinalResult(): GenerativeAiResult + { + if ($this->error !== null) { + throw $this->error; + } + + if ($this->result === null) { + try { + while ($this->pull() !== null) { + // Drain any remaining chunks so the result is complete. + } + + $this->finalize(); + } catch (Throwable $e) { + $this->fail($e); + throw $e; + } + } + + if ($this->result === null) { + $error = new RuntimeException('The stream produced no candidates.'); + $this->fail($error); + throw $error; + } + + return $this->result; + } + + /** + * Assembles the result and runs the completion callbacks. + * + * @since n.e.x.t + * + * @return void + */ + private function finalize(): void + { + if (!$this->accumulator->hasCandidates()) { + return; + } + + $this->result = $this->accumulator->build(); + + foreach ($this->completionCallbacks as $callback) { + $callback($this->result); + } + } + + /** + * Handles a failure in the stream, storing the error and running the error callbacks. + * + * @since n.e.x.t + * + * @param Throwable $error The error that ended consumption. + * @return void + */ + private function fail(Throwable $error): void + { + $this->error = $error; + + foreach ($this->errorCallbacks as $callback) { + $callback($error); + } + } + + /** + * Reads the next chunk from the source and folds it into the accumulated state. + * + * @since n.e.x.t + * + * @return GenerativeAiResultChunk|null The next chunk, or null when the stream is exhausted. + */ + private function pull(): ?GenerativeAiResultChunk + { + if ($this->finished) { + return null; + } + + if (!$this->started) { + $this->started = true; + foreach ($this->startCallbacks as $callback) { + $callback(); + } + $this->chunks->rewind(); + } else { + $this->chunks->next(); + } + + if (!$this->chunks->valid()) { + $this->finished = true; + return null; + } + + $chunk = $this->chunks->current(); + $this->accumulator->add($chunk); + + return $chunk; + } +} diff --git a/src/Results/ValueObjects/CandidateDelta.php b/src/Results/ValueObjects/CandidateDelta.php new file mode 100644 index 00000000..cbaa6101 --- /dev/null +++ b/src/Results/ValueObjects/CandidateDelta.php @@ -0,0 +1,153 @@ + The partial content parts for this candidate. + */ + private array $parts; + + /** + * @var FinishReasonEnum|null The finish reason, when this delta reports it. + */ + private ?FinishReasonEnum $finishReason; + + /** + * @var list The partial tool calls for this candidate. + */ + private array $toolCallDeltas; + + /** + * Constructor. + * + * @since n.e.x.t + * + * @param int $index The candidate index this delta contributes to. + * @param list $parts The partial content parts. + * @param FinishReasonEnum|null $finishReason The finish reason, when reported. + * @param list $toolCallDeltas The partial tool calls. + */ + public function __construct( + int $index, + array $parts = [], + ?FinishReasonEnum $finishReason = null, + array $toolCallDeltas = [] + ) { + $this->index = $index; + $this->parts = $parts; + $this->finishReason = $finishReason; + $this->toolCallDeltas = $toolCallDeltas; + } + + /** + * Gets the candidate index this delta contributes to. + * + * @since n.e.x.t + * + * @return int The candidate index. + */ + public function getIndex(): int + { + return $this->index; + } + + /** + * Gets the partial content parts. + * + * @since n.e.x.t + * + * @return list The content parts. + */ + public function getParts(): array + { + return $this->parts; + } + + /** + * Gets the finish reason. + * + * @since n.e.x.t + * + * @return FinishReasonEnum|null The finish reason, or null when not reported by this delta. + */ + public function getFinishReason(): ?FinishReasonEnum + { + return $this->finishReason; + } + + /** + * Gets the partial tool calls. + * + * @since n.e.x.t + * + * @return list The tool call fragments, possibly empty. + */ + public function getToolCallDeltas(): array + { + return $this->toolCallDeltas; + } + + /** + * Gets the delta text of this candidate's content channel. + * + * @since n.e.x.t + * + * @return string The content text delta, or an empty string when this delta carries none. + */ + public function getDeltaText(): string + { + return $this->deltaTextForChannel(MessagePartChannelEnum::content()); + } + + /** + * Gets the delta text of this candidate's reasoning (thought) channel. + * + * @since n.e.x.t + * + * @return string The reasoning text delta, or an empty string when this delta carries none. + */ + public function getReasoningDeltaText(): string + { + return $this->deltaTextForChannel(MessagePartChannelEnum::thought()); + } + + /** + * Concatenates the delta text of this candidate's parts on the given channel. + * + * @since n.e.x.t + * + * @param MessagePartChannelEnum $channel The channel to read. + * @return string The concatenated delta text, or an empty string when there is none. + */ + private function deltaTextForChannel(MessagePartChannelEnum $channel): string + { + $text = ''; + foreach ($this->parts as $part) { + if ($part->getChannel()->is($channel) && $part->getText() !== null) { + $text .= $part->getText(); + } + } + + return $text; + } +} diff --git a/src/Results/ValueObjects/GenerativeAiResultChunk.php b/src/Results/ValueObjects/GenerativeAiResultChunk.php new file mode 100644 index 00000000..6ce6b820 --- /dev/null +++ b/src/Results/ValueObjects/GenerativeAiResultChunk.php @@ -0,0 +1,164 @@ + Result-level provider metadata carried by this chunk. + */ + private array $additionalData; + + /** + * @var list The per-candidate deltas carried by this chunk. + */ + private array $candidateDeltas; + + /** + * Constructor. + * + * @since n.e.x.t + * + * @param string|null $id The result id, when reported. + * @param TokenUsage|null $tokenUsage The token usage, when reported. + * @param array $additionalData Result-level provider metadata. + * @param list $candidateDeltas The per-candidate deltas. + */ + public function __construct( + ?string $id = null, + ?TokenUsage $tokenUsage = null, + array $additionalData = [], + array $candidateDeltas = [] + ) { + $this->id = $id; + $this->tokenUsage = $tokenUsage; + $this->additionalData = $additionalData; + $this->candidateDeltas = $candidateDeltas; + } + + /** + * Gets the result id. + * + * @since n.e.x.t + * + * @return string|null The id, or null when not reported by this chunk. + */ + public function getId(): ?string + { + return $this->id; + } + + /** + * Gets the token usage. + * + * @since n.e.x.t + * + * @return TokenUsage|null The token usage, or null when not reported by this chunk. + */ + public function getTokenUsage(): ?TokenUsage + { + return $this->tokenUsage; + } + + /** + * Gets the result-level provider metadata carried by this chunk. + * + * @since n.e.x.t + * + * @return array The provider metadata, possibly empty. + */ + public function getAdditionalData(): array + { + return $this->additionalData; + } + + /** + * Gets the per-candidate deltas carried by this chunk. + * + * @since n.e.x.t + * + * @return list The candidate deltas, possibly empty (metadata-only event). + */ + public function getCandidateDeltas(): array + { + return $this->candidateDeltas; + } + + /** + * Gets the content text delta for a single candidate. + * + * @since n.e.x.t + * + * @param int $candidateIndex The candidate index to read. + * @return string The content text delta, or an empty string when the candidate carries none. + */ + public function getDeltaText(int $candidateIndex = 0): string + { + foreach ($this->candidateDeltas as $delta) { + if ($delta->getIndex() === $candidateIndex) { + return $delta->getDeltaText(); + } + } + + return ''; + } + + /** + * Gets the reasoning (thought) text delta for a single candidate. + * + * @since n.e.x.t + * + * @param int $candidateIndex The candidate index to read. + * @return string The reasoning text delta, or an empty string when the candidate carries none. + */ + public function getReasoningDeltaText(int $candidateIndex = 0): string + { + foreach ($this->candidateDeltas as $delta) { + if ($delta->getIndex() === $candidateIndex) { + return $delta->getReasoningDeltaText(); + } + } + + return ''; + } + + /** + * Gets the tool call fragments carried by this chunk. + * + * @since n.e.x.t + * + * @return list The tool call fragments, possibly empty. + */ + public function getToolCallDeltas(): array + { + $deltas = []; + foreach ($this->candidateDeltas as $candidateDelta) { + foreach ($candidateDelta->getToolCallDeltas() as $toolCallDelta) { + $deltas[] = $toolCallDelta; + } + } + + return $deltas; + } +} diff --git a/src/Results/ValueObjects/ToolCallDelta.php b/src/Results/ValueObjects/ToolCallDelta.php new file mode 100644 index 00000000..58a60b39 --- /dev/null +++ b/src/Results/ValueObjects/ToolCallDelta.php @@ -0,0 +1,105 @@ +index = $index; + $this->id = $id; + $this->functionName = $functionName; + $this->argumentsFragment = $argumentsFragment; + } + + /** + * Gets the tool call slot index this fragment contributes to. + * + * @since n.e.x.t + * + * @return int|null The slot index, or null when not reported. + */ + public function getIndex(): ?int + { + return $this->index; + } + + /** + * Gets the tool call id. + * + * @since n.e.x.t + * + * @return string|null The id, or null when not reported by this fragment. + */ + public function getId(): ?string + { + return $this->id; + } + + /** + * Gets the function name. + * + * @since n.e.x.t + * + * @return string|null The function name, or null when not reported by this fragment. + */ + public function getFunctionName(): ?string + { + return $this->functionName; + } + + /** + * Gets the partial function arguments carried by this fragment. + * + * @since n.e.x.t + * + * @return string The arguments fragment (may be empty). + */ + public function getArgumentsFragment(): string + { + return $this->argumentsFragment; + } +} diff --git a/tests/mocks/ChunkStream.php b/tests/mocks/ChunkStream.php new file mode 100644 index 00000000..8e279fdb --- /dev/null +++ b/tests/mocks/ChunkStream.php @@ -0,0 +1,146 @@ + The chunks to return, in order. + */ + private array $chunks; + + /** + * @var int Index of the next chunk to return. + */ + private int $position = 0; + + /** + * @var int Number of read() calls that returned a chunk. + */ + private int $readCount = 0; + + /** + * @var bool Whether close() has been called. + */ + private bool $closed = false; + + /** + * @param list $chunks The chunks to return, in order. + */ + public function __construct(array $chunks) + { + $this->chunks = array_values($chunks); + } + + /** + * @return int Number of read() calls that returned a chunk. + */ + public function getReadCount(): int + { + return $this->readCount; + } + + /** + * @return bool Whether the stream was closed. + */ + public function isClosed(): bool + { + return $this->closed; + } + + public function eof(): bool + { + return $this->position >= count($this->chunks); + } + + public function read(int $length): string + { + if ($this->eof()) { + return ''; + } + + $this->readCount++; + + return $this->chunks[$this->position++]; + } + + public function close(): void + { + $this->closed = true; + } + + public function __toString(): string + { + return implode('', array_slice($this->chunks, $this->position)); + } + + public function detach() + { + return null; + } + + public function getSize(): ?int + { + return null; + } + + public function tell(): int + { + return $this->position; + } + + public function isSeekable(): bool + { + return false; + } + + public function seek(int $offset, int $whence = SEEK_SET): void + { + throw new RuntimeException('Not seekable.'); + } + + public function rewind(): void + { + throw new RuntimeException('Not seekable.'); + } + + public function isWritable(): bool + { + return false; + } + + public function write(string $string): int + { + throw new RuntimeException('Not writable.'); + } + + public function isReadable(): bool + { + return true; + } + + public function getContents(): string + { + $contents = implode('', array_slice($this->chunks, $this->position)); + $this->position = count($this->chunks); + + return $contents; + } + + /** + * @param string|null $key The metadata key. + * @return array|null The metadata. + */ + public function getMetadata(?string $key = null) + { + return $key === null ? [] : null; + } +} diff --git a/tests/mocks/FailingChunkStream.php b/tests/mocks/FailingChunkStream.php new file mode 100644 index 00000000..a931791b --- /dev/null +++ b/tests/mocks/FailingChunkStream.php @@ -0,0 +1,48 @@ + $chunks The chunks to deliver before failing. + * @param string $errorMessage The message of the error thrown afterwards. + */ + public function __construct(array $chunks, string $errorMessage = 'Connection reset by peer') + { + parent::__construct($chunks); + $this->errorMessage = $errorMessage; + } + + /** + * Never reports end-of-stream, so the reader keeps reading until the failure. + */ + public function eof(): bool + { + return false; + } + + /** + * Returns the next chunk, or throws once the chunks are exhausted. + */ + public function read(int $length): string + { + if (parent::eof()) { + throw new RuntimeException($this->errorMessage); + } + + return parent::read($length); + } +} diff --git a/tests/traits/MockModelCreationTrait.php b/tests/traits/MockModelCreationTrait.php index d330c518..4a6c9a3d 100644 --- a/tests/traits/MockModelCreationTrait.php +++ b/tests/traits/MockModelCreationTrait.php @@ -4,6 +4,7 @@ namespace WordPress\AiClient\Tests\traits; +use ArrayIterator; use WordPress\AiClient\Messages\DTO\MessagePart; use WordPress\AiClient\Messages\DTO\ModelMessage; use WordPress\AiClient\Providers\DTO\ProviderMetadata; @@ -13,6 +14,7 @@ use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface; +use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\StreamingTextGenerationModelInterface; use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\TextGenerationModelInterface; use WordPress\AiClient\Providers\Models\VideoGeneration\Contracts\VideoGenerationModelInterface; use WordPress\AiClient\Providers\ProviderRegistry; @@ -20,6 +22,9 @@ use WordPress\AiClient\Results\DTO\GenerativeAiResult; use WordPress\AiClient\Results\DTO\TokenUsage; use WordPress\AiClient\Results\Enums\FinishReasonEnum; +use WordPress\AiClient\Results\StreamedGenerativeAiResult; +use WordPress\AiClient\Results\ValueObjects\CandidateDelta; +use WordPress\AiClient\Results\ValueObjects\GenerativeAiResultChunk; use WordPress\AiClient\Tests\mocks\MockProvider; /** @@ -368,4 +373,102 @@ protected function createMockUnsupportedModel(string $modelId = 'unsupported-mod return $mockModel; } + + /** + * Creates a chunk carrying a content text delta on candidate 0. + * + * @param string $text The content delta. + * @param FinishReasonEnum|null $finishReason Optional finish reason for the candidate. + * @return GenerativeAiResultChunk + */ + protected function createStreamingTextChunk( + string $text, + ?FinishReasonEnum $finishReason = null + ): GenerativeAiResultChunk { + return new GenerativeAiResultChunk(null, null, [], [ + new CandidateDelta(0, [new MessagePart($text)], $finishReason), + ]); + } + + /** + * Creates a mock model that streams the given chunks. + * + * @param iterable $chunks The chunks to stream. + * @param ModelMetadata|null $metadata Optional metadata (uses default if not provided). + * @return ModelInterface&StreamingTextGenerationModelInterface The mock model. + */ + protected function createMockStreamingTextGenerationModel( + iterable $chunks, + ?ModelMetadata $metadata = null + ): ModelInterface { + $metadata = $metadata ?? $this->createTestTextModelMetadata(); + + $providerMetadata = new ProviderMetadata( + 'mock', + 'Mock Provider', + ProviderTypeEnum::cloud() + ); + + return new class ( + $metadata, + $providerMetadata, + $chunks + ) implements ModelInterface, TextGenerationModelInterface, StreamingTextGenerationModelInterface { + private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; + /** @var iterable */ + private iterable $chunks; + private ModelConfig $config; + + /** + * @param iterable $chunks + */ + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata, + iterable $chunks + ) { + $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; + $this->chunks = $chunks; + $this->config = new ModelConfig(); + } + + public function metadata(): ModelMetadata + { + return $this->metadata; + } + + public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + + public function setConfig(ModelConfig $config): void + { + $this->config = $config; + } + + public function getConfig(): ModelConfig + { + return $this->config; + } + + public function generateTextResult(array $prompt): GenerativeAiResult + { + throw new \RuntimeException('Non-streaming generation is not exercised by streaming tests.'); + } + + public function streamGenerateTextResult(array $prompt): StreamedGenerativeAiResult + { + $source = is_array($this->chunks) ? new ArrayIterator($this->chunks) : $this->chunks; + + return new StreamedGenerativeAiResult( + $source, + $this->providerMetadata, + $this->metadata + ); + } + }; + } } diff --git a/tests/unit/AiClientTest.php b/tests/unit/AiClientTest.php index c573c756..3cbc3563 100644 --- a/tests/unit/AiClientTest.php +++ b/tests/unit/AiClientTest.php @@ -12,6 +12,8 @@ use WordPress\AiClient\Providers\Contracts\ProviderAvailabilityInterface; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\ProviderRegistry; +use WordPress\AiClient\Results\Enums\FinishReasonEnum; +use WordPress\AiClient\Results\StreamedGenerativeAiResult; use WordPress\AiClient\Tests\mocks\MockProvider; use WordPress\AiClient\Tests\traits\MockModelCreationTrait; @@ -798,4 +800,30 @@ public function testGetConfiguredPromptBuilderHelperIntegration(): void $this->expectExceptionMessageMatches('/No models found that support/'); AiClient::generateResult($prompt, null, $this->createMockEmptyRegistry()); } + + /** + * Tests that streamGenerateTextResult delegates to the prompt builder and returns a handle. + */ + public function testStreamGenerateTextResultWithStringAndModel(): void + { + $model = $this->createMockStreamingTextGenerationModel([ + $this->createStreamingTextChunk('Hello', FinishReasonEnum::stop()), + ]); + $registry = $this->createRegistryWithMockProvider(); + + $handle = AiClient::streamGenerateTextResult('Generate text', $model, $registry); + + $this->assertInstanceOf(StreamedGenerativeAiResult::class, $handle); + $this->assertSame('Hello', $handle->getFinalResult()->toText()); + } + + /** + * Tests that streamGenerateTextResult validates the model-or-config parameter. + */ + public function testStreamGenerateTextResultWithInvalidModel(): void + { + $this->expectException(\InvalidArgumentException::class); + $this->expectExceptionMessage('Parameter must be a ModelInterface'); + AiClient::streamGenerateTextResult('Generate text', 'invalid', $this->createRegistryWithMockProvider()); + } } diff --git a/tests/unit/Builders/PromptBuilderEventDispatchingTest.php b/tests/unit/Builders/PromptBuilderEventDispatchingTest.php index 092a9330..a281ca40 100644 --- a/tests/unit/Builders/PromptBuilderEventDispatchingTest.php +++ b/tests/unit/Builders/PromptBuilderEventDispatchingTest.php @@ -5,11 +5,16 @@ namespace WordPress\AiClient\Tests\unit\Builders; use PHPUnit\Framework\TestCase; +use RuntimeException; use WordPress\AiClient\Builders\PromptBuilder; use WordPress\AiClient\Events\AfterGenerateResultEvent; use WordPress\AiClient\Events\BeforeGenerateResultEvent; +use WordPress\AiClient\Events\GenerateResultErrorEvent; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; use WordPress\AiClient\Providers\ProviderRegistry; +use WordPress\AiClient\Results\DTO\TokenUsage; +use WordPress\AiClient\Results\Enums\FinishReasonEnum; +use WordPress\AiClient\Results\ValueObjects\GenerativeAiResultChunk; use WordPress\AiClient\Tests\mocks\MockEventDispatcher; use WordPress\AiClient\Tests\mocks\MockProvider; use WordPress\AiClient\Tests\traits\MockModelCreationTrait; @@ -156,4 +161,158 @@ public function testEventsDispatchedInCorrectOrder(): void $this->assertInstanceOf(BeforeGenerateResultEvent::class, $events[0]); $this->assertInstanceOf(AfterGenerateResultEvent::class, $events[1]); } + + /** + * Creates a builder wired with the injected dispatcher and a model that streams the given chunks. + * + * @return PromptBuilder + */ + private function createStreamingBuilderWithDispatcher(): PromptBuilder + { + $model = $this->createMockStreamingTextGenerationModel([ + $this->createStreamingTextChunk('Hel'), + $this->createStreamingTextChunk('lo', FinishReasonEnum::stop()), + new GenerativeAiResultChunk(null, new TokenUsage(3, 5, 8), [], []), + ]); + + $builder = new PromptBuilder($this->registry, 'Hello', $this->dispatcher); + $builder->usingModel($model); + + return $builder; + } + + /** + * Tests that the Before event is not dispatched until the stream is consumed. + * + * @return void + */ + public function testStreamingDoesNotDispatchBeforeEventUntilConsumed(): void + { + $handle = $this->createStreamingBuilderWithDispatcher()->streamGenerateTextResult(); + + $this->assertCount(0, $this->dispatcher->getDispatchedEventsOfType(BeforeGenerateResultEvent::class)); + + $handle->getFinalResult(); + + $this->assertCount(1, $this->dispatcher->getDispatchedEventsOfType(BeforeGenerateResultEvent::class)); + } + + /** + * Tests that a stream failure dispatches the error event and no After event. + * + * @return void + */ + public function testStreamingDispatchesErrorEventOnStreamFailure(): void + { + $chunk = $this->createStreamingTextChunk('Hello'); + $source = (function () use ($chunk) { + yield $chunk; + throw new RuntimeException('stream failed'); + })(); + $model = $this->createMockStreamingTextGenerationModel($source); + + $builder = new PromptBuilder($this->registry, 'Hello', $this->dispatcher); + $builder->usingModel($model); + + try { + foreach ($builder->streamGenerateTextResult() as $streamChunk) { + } + } catch (RuntimeException $e) { + } + + $this->assertCount(1, $this->dispatcher->getDispatchedEventsOfType(BeforeGenerateResultEvent::class)); + $this->assertCount(0, $this->dispatcher->getDispatchedEventsOfType(AfterGenerateResultEvent::class)); + + $errorEvents = $this->dispatcher->getDispatchedEventsOfType(GenerateResultErrorEvent::class); + $this->assertCount(1, $errorEvents); + $this->assertSame('stream failed', $errorEvents[0]->getError()->getMessage()); + } + + /** + * Tests that the After event fires once with the assembled result. + * + * @return void + */ + public function testStreamingDispatchesAfterEventOnceWithAssembledResult(): void + { + $result = $this->createStreamingBuilderWithDispatcher() + ->streamGenerateTextResult() + ->getFinalResult(); + + $afterEvents = $this->dispatcher->getDispatchedEventsOfType(AfterGenerateResultEvent::class); + $this->assertCount(1, $afterEvents); + + $event = $afterEvents[0]; + $this->assertSame($result, $event->getResult()); + $this->assertSame('Hello', $event->getResult()->toText()); + $this->assertSame(8, $event->getResult()->getTokenUsage()->getTotalTokens()); + $this->assertEquals(CapabilityEnum::textGeneration(), $event->getCapability()); + $this->assertCount(1, $event->getMessages()); + } + + /** + * Tests that the After event is not dispatched on an early break. + * + * @return void + */ + public function testStreamingDoesNotDispatchAfterEventOnEarlyBreak(): void + { + foreach ($this->createStreamingBuilderWithDispatcher()->streamGenerateTextResult() as $chunk) { + break; + } + + $this->assertCount(1, $this->dispatcher->getDispatchedEventsOfType(BeforeGenerateResultEvent::class)); + $this->assertCount(0, $this->dispatcher->getDispatchedEventsOfType(AfterGenerateResultEvent::class)); + } + + /** + * Tests that the After event fires only once across iteration and getFinalResult(). + * + * @return void + */ + public function testStreamingDispatchesAfterEventOnlyOnce(): void + { + $handle = $this->createStreamingBuilderWithDispatcher()->streamGenerateTextResult(); + + foreach ($handle as $chunk) { + } + $handle->getFinalResult(); + + $this->assertCount(1, $this->dispatcher->getDispatchedEventsOfType(AfterGenerateResultEvent::class)); + } + + /** + * Tests that streaming dispatches nothing when no dispatcher is set. + * + * @return void + */ + public function testStreamingDispatchesNoEventsWithoutDispatcher(): void + { + $model = $this->createMockStreamingTextGenerationModel([ + $this->createStreamingTextChunk('Hello', FinishReasonEnum::stop()), + ]); + + $builder = new PromptBuilder($this->registry, 'Hello'); + $builder->usingModel($model); + + $result = $builder->streamGenerateTextResult()->getFinalResult(); + + $this->assertSame('Hello', $result->toText()); + $this->assertCount(0, $this->dispatcher->getDispatchedEvents()); + } + + /** + * Tests that streaming dispatches Before then After. + * + * @return void + */ + public function testStreamingDispatchesEventsInOrder(): void + { + $this->createStreamingBuilderWithDispatcher()->streamGenerateTextResult()->getFinalResult(); + + $events = $this->dispatcher->getDispatchedEvents(); + $this->assertCount(2, $events); + $this->assertInstanceOf(BeforeGenerateResultEvent::class, $events[0]); + $this->assertInstanceOf(AfterGenerateResultEvent::class, $events[1]); + } } diff --git a/tests/unit/Builders/PromptBuilderTest.php b/tests/unit/Builders/PromptBuilderTest.php index ce68a223..0f35cdb5 100644 --- a/tests/unit/Builders/PromptBuilderTest.php +++ b/tests/unit/Builders/PromptBuilderTest.php @@ -36,6 +36,8 @@ use WordPress\AiClient\Results\DTO\GenerativeAiResult; use WordPress\AiClient\Results\DTO\TokenUsage; use WordPress\AiClient\Results\Enums\FinishReasonEnum; +use WordPress\AiClient\Results\StreamedGenerativeAiResult; +use WordPress\AiClient\Results\ValueObjects\GenerativeAiResultChunk; use WordPress\AiClient\Tests\traits\MockModelCreationTrait; use WordPress\AiClient\Tools\DTO\FunctionDeclaration; use WordPress\AiClient\Tools\DTO\FunctionResponse; @@ -4143,4 +4145,68 @@ public function testUsingStopSequencesSetsProperty(): void $this->assertEquals(['STOP', 'END'], $config->getStopSequences()); } + + /** + * Tests streamGenerateTextResult returns a StreamedGenerativeAiResult handle and yields the final result. + * + * @return void + */ + public function testStreamGenerateTextResultReturnsStreamedHandle(): void + { + $model = $this->createMockStreamingTextGenerationModel([ + $this->createStreamingTextChunk('Hel'), + $this->createStreamingTextChunk('lo', FinishReasonEnum::stop()), + new GenerativeAiResultChunk(null, new TokenUsage(3, 5, 8), [], []), + ]); + + $builder = new PromptBuilder($this->registry, 'Hello'); + $builder->usingModel($model); + + $handle = $builder->streamGenerateTextResult(); + + $this->assertInstanceOf(StreamedGenerativeAiResult::class, $handle); + $result = $handle->getFinalResult(); + $this->assertSame('Hello', $result->toText()); + $this->assertSame(8, $result->getTokenUsage()->getTotalTokens()); + } + + /** + * Tests streamGenerateTextResult throws an exception when the model does not support streaming. + * + * @return void + */ + public function testStreamGenerateTextResultThrowsWhenModelDoesNotSupportStreaming(): void + { + // A plain text-generation model does not implement StreamingTextGenerationModelInterface. + $model = $this->createMockTextGenerationModel($this->createTestResult()); + + $builder = new PromptBuilder($this->registry, 'Hello'); + $builder->usingModel($model); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('does not support streaming text generation'); + $builder->streamGenerateTextResult(); + } + + /** + * Tests streamGenerateText yields the text deltas in order and skips chunks without text. + * + * @return void + */ + public function testStreamGenerateTextYieldsTextDeltasAndSkipsEmpty(): void + { + $model = $this->createMockStreamingTextGenerationModel([ + $this->createStreamingTextChunk('Hel'), + // A metadata-only chunk carries no text and must not be yielded. + new GenerativeAiResultChunk(null, new TokenUsage(1, 1, 2), [], []), + $this->createStreamingTextChunk('lo', FinishReasonEnum::stop()), + ]); + + $builder = new PromptBuilder($this->registry, 'Hello'); + $builder->usingModel($model); + + $deltas = iterator_to_array($builder->streamGenerateText(), false); + + $this->assertSame(['Hel', 'lo'], $deltas); + } } diff --git a/tests/unit/Events/GenerateResultErrorEventTest.php b/tests/unit/Events/GenerateResultErrorEventTest.php new file mode 100644 index 00000000..fdc96e0c --- /dev/null +++ b/tests/unit/Events/GenerateResultErrorEventTest.php @@ -0,0 +1,75 @@ +createMockTextGenerationModel($this->createTestResult()); + $capability = CapabilityEnum::textGeneration(); + $error = new RuntimeException('stream failed'); + + $event = new GenerateResultErrorEvent($messages, $model, $capability, $error); + + $this->assertSame($messages, $event->getMessages()); + $this->assertSame($model, $event->getModel()); + $this->assertSame($capability, $event->getCapability()); + $this->assertSame($error, $event->getError()); + } + + /** + * Tests event construction with null capability. + * + * @return void + */ + public function testConstructionWithNullCapability(): void + { + $messages = [new UserMessage([new MessagePart('Hello')])]; + $model = $this->createMockTextGenerationModel($this->createTestResult()); + + $event = new GenerateResultErrorEvent($messages, $model, null, new RuntimeException('boom')); + + $this->assertNull($event->getCapability()); + } + + /** + * Tests that cloning copies messages but keeps the model and error instances. + * + * @return void + */ + public function testCloneClonesMessagesOnly(): void + { + $messages = [new UserMessage([new MessagePart('Hello')])]; + $model = $this->createMockTextGenerationModel($this->createTestResult()); + $error = new RuntimeException('boom'); + + $original = new GenerateResultErrorEvent($messages, $model, CapabilityEnum::textGeneration(), $error); + $cloned = clone $original; + + $this->assertNotSame($original->getMessages()[0], $cloned->getMessages()[0]); + $this->assertSame($original->getModel(), $cloned->getModel()); + $this->assertSame($original->getError(), $cloned->getError()); + } +} diff --git a/tests/unit/Providers/Http/DTO/RequestOptionsTest.php b/tests/unit/Providers/Http/DTO/RequestOptionsTest.php index 12f77446..b30808dc 100644 --- a/tests/unit/Providers/Http/DTO/RequestOptionsTest.php +++ b/tests/unit/Providers/Http/DTO/RequestOptionsTest.php @@ -94,4 +94,68 @@ public function testGetJsonSchemaDefinesNullableMaxRedirects(): void $this->assertSame(['integer', 'null'], $schema['properties'][RequestOptions::KEY_MAX_REDIRECTS]['type']); } + + /** + * Tests that setStream toggles the stream flag. + * + * @return void + */ + public function testSetStreamTogglesStreamFlag(): void + { + $options = new RequestOptions(); + $options->setStream(true); + $this->assertTrue($options->isStream()); + + $options->setStream(false); + $this->assertFalse($options->isStream()); + } + + /** + * Tests that the stream flag is null by default. + * + * @return void + */ + public function testStreamIsNullByDefault(): void + { + $this->assertNull((new RequestOptions())->isStream()); + } + + /** + * Tests that toArray includes the stream flag only when it is set. + * + * @return void + */ + public function testToArrayIncludesStreamWhenSet(): void + { + $options = new RequestOptions(); + $this->assertArrayNotHasKey(RequestOptions::KEY_STREAM, $options->toArray()); + + $options->setStream(true); + $this->assertTrue($options->toArray()[RequestOptions::KEY_STREAM]); + } + + /** + * Tests that fromArray reads the stream flag. + * + * @return void + */ + public function testFromArrayReadsStream(): void + { + $options = RequestOptions::fromArray([RequestOptions::KEY_STREAM => true]); + + $this->assertInstanceOf(RequestOptions::class, $options); + $this->assertTrue($options->isStream()); + } + + /** + * Tests that the JSON schema defines a nullable boolean stream flag. + * + * @return void + */ + public function testGetJsonSchemaDefinesNullableStream(): void + { + $schema = RequestOptions::getJsonSchema(); + + $this->assertSame(['boolean', 'null'], $schema['properties'][RequestOptions::KEY_STREAM]['type']); + } } diff --git a/tests/unit/Providers/Http/DTO/ResponseTest.php b/tests/unit/Providers/Http/DTO/ResponseTest.php index c5dd7815..f7ff91f6 100644 --- a/tests/unit/Providers/Http/DTO/ResponseTest.php +++ b/tests/unit/Providers/Http/DTO/ResponseTest.php @@ -4,8 +4,10 @@ namespace WordPress\AiClient\Tests\unit\Providers\Http\DTO; +use GuzzleHttp\Psr7\Utils; use PHPUnit\Framework\TestCase; use WordPress\AiClient\Providers\Http\DTO\Response; +use WordPress\AiClient\Tests\mocks\ChunkStream; /** * @covers \WordPress\AiClient\Providers\Http\DTO\Response @@ -76,4 +78,83 @@ public function testCloneWorksWithNullBody(): void $this->assertEquals(204, $cloned->getStatusCode()); $this->assertTrue($cloned->hasHeader('X-Request-Id')); } + + /** + * Tests that a stream body is returned as-is by getStream. + * + * @return void + */ + public function testGetStreamReturnsTheStreamBody(): void + { + $stream = new ChunkStream(['{"ok":true}']); + $response = new Response(200, [], $stream); + + $this->assertSame($stream, $response->getStream()); + } + + /** + * Tests that a streamed body is read into a string by getBody. + * + * @return void + */ + public function testStreamedBodyIsReadByGetBody(): void + { + $response = new Response(200, [], new ChunkStream(['{"ok":', 'true}'])); + + $this->assertSame('{"ok":true}', $response->getBody()); + } + + /** + * Tests that a streamed JSON body is decoded by getData. + * + * @return void + */ + public function testStreamedBodyIsDecodedByGetData(): void + { + $response = new Response(200, [], new ChunkStream(['{"ok":true}'])); + + $this->assertSame(['ok' => true], $response->getData()); + } + + /** + * Tests that a buffered body is wrapped in a stream by getStream. + * + * @return void + */ + public function testBufferedBodyIsWrappedInStreamByGetStream(): void + { + $response = new Response(200, [], 'hello'); + + $this->assertSame('hello', (string) $response->getStream()); + } + + /** + * Tests that a seekable stream body is rewound before being read. + * + * @return void + */ + public function testSeekableStreamBodyIsRewoundBeforeRead(): void + { + $stream = Utils::streamFor('{"ok":true}'); + $stream->read(2); + + $response = new Response(200, [], $stream); + + $this->assertSame('{"ok":true}', $response->getBody()); + } + + /** + * Tests that toArray serializes the body, reading a streamed body when needed. + * + * @return void + */ + public function testToArraySerializesStreamedBody(): void + { + $response = new Response(200, ['X-Test' => 'value'], new ChunkStream(['streamed body'])); + + $array = $response->toArray(); + + $this->assertSame(200, $array[Response::KEY_STATUS_CODE]); + $this->assertSame('streamed body', $array[Response::KEY_BODY]); + } } diff --git a/tests/unit/Providers/Http/Exception/ResponseExceptionTest.php b/tests/unit/Providers/Http/Exception/ResponseExceptionTest.php new file mode 100644 index 00000000..f154c96d --- /dev/null +++ b/tests/unit/Providers/Http/Exception/ResponseExceptionTest.php @@ -0,0 +1,70 @@ +assertSame('Error while streaming the OpenAI API response: connection reset', $exception->getMessage()); + } + + /** + * Tests that fromStreamError chains the previous exception when given one. + */ + public function testFromStreamErrorChainsPreviousException(): void + { + $previous = new RuntimeException('underlying'); + + $exception = ResponseException::fromStreamError('OpenAI', 'connection reset', $previous); + + $this->assertSame($previous, $exception->getPrevious()); + } + + /** + * Tests that fromStreamError has no previous exception by default. + */ + public function testFromStreamErrorHasNoPreviousByDefault(): void + { + $exception = ResponseException::fromStreamError('OpenAI', 'connection reset'); + + $this->assertNull($exception->getPrevious()); + } + + /** + * Tests that fromMissingData builds the missing-key message. + */ + public function testFromMissingDataMessage(): void + { + $exception = ResponseException::fromMissingData('OpenAI', 'choices'); + + $this->assertSame('Unexpected OpenAI API response: Missing the "choices" key.', $exception->getMessage()); + } + + /** + * Tests that fromInvalidData builds the invalid-key message. + */ + public function testFromInvalidDataMessage(): void + { + $exception = ResponseException::fromInvalidData('OpenAI', 'usage', 'not an object'); + + $this->assertSame( + 'Unexpected OpenAI API response: Invalid "usage" key: not an object', + $exception->getMessage() + ); + } +} diff --git a/tests/unit/Providers/Http/HttpTransporterTest.php b/tests/unit/Providers/Http/HttpTransporterTest.php index e94073e5..1d975286 100644 --- a/tests/unit/Providers/Http/HttpTransporterTest.php +++ b/tests/unit/Providers/Http/HttpTransporterTest.php @@ -347,4 +347,188 @@ public function testSendMergesOptionsWithParameterPrecedence(): void $this->assertSame(5.0, $lastOptions['connect_timeout']); // From request (not overridden) $this->assertFalse($lastOptions['allow_redirects']); // From parameter (0 = disabled) } + + /** + * Creates a transporter backed by the given Guzzle-like client. + * + * @param GuzzleLikeClient $client The Guzzle-like client. + * @return HttpTransporter + */ + private function createGuzzleTransporter(GuzzleLikeClient $client): HttpTransporter + { + return new HttpTransporter($client, $this->httpFactory, $this->httpFactory); + } + + /** + * Tests that streaming forwards stream:true to Guzzle and passes the body through. + * + * @return void + */ + public function testStreamForwardsStreamOptionToGuzzleAndPassesBodyThrough(): void + { + $psr7Response = new Psr7Response(200, [], 'streamed body'); + $body = $psr7Response->getBody(); + $guzzleClient = new GuzzleLikeClient($psr7Response); + $transporter = $this->createGuzzleTransporter($guzzleClient); + + $options = new RequestOptions(); + $options->setStream(true); + + $request = new Request(HttpMethodEnum::GET(), 'https://api.example.com/stream'); + $result = $transporter->send($request, $options); + + $lastOptions = $guzzleClient->getLastOptions(); + $this->assertIsArray($lastOptions); + $this->assertTrue($lastOptions['stream']); + + $this->assertSame($body, $result->getStream()); + } + + /** + * Tests that a non-streaming Guzzle request omits the stream option and buffers the body. + * + * @return void + */ + public function testNonStreamGuzzleRequestOmitsStreamOptionAndBuffersBody(): void + { + $psr7Response = new Psr7Response(200, [], 'buffered body'); + $guzzleClient = new GuzzleLikeClient($psr7Response); + $transporter = $this->createGuzzleTransporter($guzzleClient); + + $options = new RequestOptions(); + $options->setStream(false); + + $request = new Request(HttpMethodEnum::GET(), 'https://api.example.com/buffered'); + $result = $transporter->send($request, $options); + + $lastOptions = $guzzleClient->getLastOptions(); + $this->assertIsArray($lastOptions); + $this->assertArrayNotHasKey('stream', $lastOptions); + + $this->assertSame('buffered body', $result->getBody()); + } + + /** + * Tests that a request-level stream option streams the response. + * + * @return void + */ + public function testRequestLevelStreamOptionStreamsResponse(): void + { + $psr7Response = new Psr7Response(200, [], 'streamed body'); + $body = $psr7Response->getBody(); + $guzzleClient = new GuzzleLikeClient($psr7Response); + $transporter = $this->createGuzzleTransporter($guzzleClient); + + $requestOptions = new RequestOptions(); + $requestOptions->setStream(true); + + $request = new Request(HttpMethodEnum::GET(), 'https://api.example.com/stream', [], null, $requestOptions); + $result = $transporter->send($request); + + $lastOptions = $guzzleClient->getLastOptions(); + $this->assertIsArray($lastOptions); + $this->assertTrue($lastOptions['stream']); + $this->assertSame($body, $result->getStream()); + } + + /** + * Tests that parameter stream:false overrides request stream:true. + * + * @return void + */ + public function testParameterStreamFalseOverridesRequestStreamTrue(): void + { + $psr7Response = new Psr7Response(200, [], 'buffered body'); + $guzzleClient = new GuzzleLikeClient($psr7Response); + $transporter = $this->createGuzzleTransporter($guzzleClient); + + $requestOptions = new RequestOptions(); + $requestOptions->setStream(true); + $request = new Request(HttpMethodEnum::GET(), 'https://api.example.com/x', [], null, $requestOptions); + + $parameterOptions = new RequestOptions(); + $parameterOptions->setStream(false); + + $result = $transporter->send($request, $parameterOptions); + + $lastOptions = $guzzleClient->getLastOptions(); + $this->assertIsArray($lastOptions); + $this->assertArrayNotHasKey('stream', $lastOptions); + $this->assertSame('buffered body', $result->getBody()); + } + + /** + * Tests that parameter stream:true overrides request stream:false. + * + * @return void + */ + public function testParameterStreamTrueOverridesRequestStreamFalse(): void + { + $psr7Response = new Psr7Response(200, [], 'streamed body'); + $body = $psr7Response->getBody(); + $guzzleClient = new GuzzleLikeClient($psr7Response); + $transporter = $this->createGuzzleTransporter($guzzleClient); + + $requestOptions = new RequestOptions(); + $requestOptions->setStream(false); + $request = new Request(HttpMethodEnum::GET(), 'https://api.example.com/x', [], null, $requestOptions); + + $parameterOptions = new RequestOptions(); + $parameterOptions->setStream(true); + + $result = $transporter->send($request, $parameterOptions); + + $lastOptions = $guzzleClient->getLastOptions(); + $this->assertIsArray($lastOptions); + $this->assertTrue($lastOptions['stream']); + $this->assertSame($body, $result->getStream()); + } + + /** + * Tests that a request-level stream flag is kept when parameter options omit it. + * + * @return void + */ + public function testRequestStreamIsMergedWhenParameterDoesNotSpecifyStream(): void + { + $psr7Response = new Psr7Response(200, [], 'streamed body'); + $body = $psr7Response->getBody(); + $guzzleClient = new GuzzleLikeClient($psr7Response); + $transporter = $this->createGuzzleTransporter($guzzleClient); + + $requestOptions = new RequestOptions(); + $requestOptions->setStream(true); + $request = new Request(HttpMethodEnum::GET(), 'https://api.example.com/x', [], null, $requestOptions); + + $parameterOptions = new RequestOptions(); + $parameterOptions->setTimeout(5.0); + + $result = $transporter->send($request, $parameterOptions); + + $lastOptions = $guzzleClient->getLastOptions(); + $this->assertIsArray($lastOptions); + $this->assertTrue($lastOptions['stream']); + $this->assertSame($body, $result->getStream()); + } + + /** + * Tests that the streamed response is passed through with a non-Guzzle client. + * + * @return void + */ + public function testStreamingResponseIsPassedThroughWithNonGuzzleClient(): void + { + $psr7Response = new Psr7Response(200, [], 'streamed body'); + $body = $psr7Response->getBody(); + $this->mockClient->addResponse($psr7Response); + + $options = new RequestOptions(); + $options->setStream(true); + + $request = new Request(HttpMethodEnum::GET(), 'https://api.example.com/stream'); + $result = $this->transporter->send($request, $options); + + $this->assertSame($body, $result->getStream()); + } } diff --git a/tests/unit/Providers/Http/Streaming/SseEventStreamParserTest.php b/tests/unit/Providers/Http/Streaming/SseEventStreamParserTest.php new file mode 100644 index 00000000..229b9641 --- /dev/null +++ b/tests/unit/Providers/Http/Streaming/SseEventStreamParserTest.php @@ -0,0 +1,549 @@ + $chunks The byte chunks to feed to the parser. + * @return list The parsed events. + */ + private function parse(array $chunks): array + { + return iterator_to_array((new SseEventStreamParser())->parse(new ChunkStream($chunks)), false); + } + + /** + * Parses a single body string into a list of events. + * + * @param string $body The full body. + * @return list The parsed events. + */ + private function parseBody(string $body): array + { + return $this->parse([$body]); + } + + /** + * Maps events to their data payloads. + * + * @param list $events The events. + * @return list The data payloads. + */ + private function dataOf(array $events): array + { + return array_map(static fn (ServerSentEvent $e): string => $e->getData(), $events); + } + + /** + * Tests field parsing across colons, NUL bytes, casing, and line endings. + * + * @return void + */ + public function testFieldParsing(): void + { + $body = "data:\x00\n" + . "data: 2\r" + . "Data:1\n" + . "data\x00:2\n" + . "data:1\r" + . "\x00data:4\n" + . "da-ta:3\r" + . "data_5\n" + . "data:3\r" + . "data:\r\n" + . " data:32\n" + . "data:4\n" + . "\n"; + + $events = $this->parseBody($body); + + $this->assertCount(1, $events); + $this->assertSame("\x00\n 2\n1\n3\n\n4", $events[0]->getData()); + } + + /** + * Tests the data field: empty values, missing colons, and accumulation. + * + * @return void + */ + public function testDataFieldVariants(): void + { + $events = $this->parseBody("data:\n\ndata\ndata\n\ndata:test\n\n"); + + $this->assertSame(['', "\n", 'test'], $this->dataOf($events)); + } + + /** + * Tests that a custom event name is carried, and the default is "message". + * + * @return void + */ + public function testCustomEventName(): void + { + $events = $this->parseBody("event:test\ndata:x\n\ndata:x\n\n"); + + $this->assertCount(2, $events); + $this->assertSame('test', $events[0]->getEvent()); + $this->assertSame('message', $events[1]->getEvent()); + } + + /** + * Tests that an empty event field falls back to the default "message" type. + * + * @return void + */ + public function testEmptyEventNameDefaultsToMessage(): void + { + $events = $this->parseBody("event: \ndata:data\n\n"); + + $this->assertCount(1, $events); + $this->assertSame('message', $events[0]->getEvent()); + $this->assertSame('data', $events[0]->getData()); + } + + /** + * Tests that comment and unknown lines are ignored among data lines. + * + * @return void + */ + public function testCommentsIgnored(): void + { + $long = str_repeat('x', 16); + $body = "data:1\r" + . ":\x00\n" + . ":\r\n" + . "data:2\n" + . ':' . $long . "\r" + . "data:3\n" + . ":data:fail\r" + . ':' . $long . "\n" + . "data:4\n" + . "\n"; + + $events = $this->parseBody($body); + + $this->assertCount(1, $events); + $this->assertSame("1\n2\n3\n4", $events[0]->getData()); + } + + /** + * Tests that unknown fields, leading-space field names, and comments are skipped. + * + * @return void + */ + public function testUnknownFieldsIgnored(): void + { + $body = "data:test\n" + . " data\n" + . "data\n" + . "foobar:xxx\n" + . "justsometext\n" + . ":thisisacommentyay\n" + . "data:test\n" + . "\n"; + + $events = $this->parseBody($body); + + $this->assertCount(1, $events); + $this->assertSame("test\n\ntest", $events[0]->getData()); + } + + /** + * Tests that only one leading space is stripped, a tab is kept, and CR ends a line. + * + * @return void + */ + public function testLeadingSpaceStrippedOnce(): void + { + $events = $this->parseBody("data:\ttest\rdata: \ndata:test\n\n"); + + $this->assertCount(1, $events); + $this->assertSame("\ttest\n\ntest", $events[0]->getData()); + } + + /** + * Tests CRLF, LF, and a lone CR are all treated as line terminators. + * + * @return void + */ + public function testNewlineVariants(): void + { + $events = $this->parseBody("data:test\r\ndata\ndata:test\r\n\r\n"); + + $this->assertCount(1, $events); + $this->assertSame("test\n\ntest", $events[0]->getData()); + } + + /** + * Tests that a NUL byte is preserved in the data payload. + * + * @return void + */ + public function testNullCharacterInData(): void + { + $events = $this->parseBody("data:\x00\n\n"); + + $this->assertCount(1, $events); + $this->assertSame("\x00", $events[0]->getData()); + } + + /** + * Tests that multi-byte UTF-8 data is passed through unchanged. + * + * @return void + */ + public function testUtf8DataPreserved(): void + { + $events = $this->parseBody("data:ok\xE2\x80\xA6\n\n"); + + $this->assertCount(1, $events); + $this->assertSame('ok…', $events[0]->getData()); + } + + /** + * Tests that the id field sets the event ID. + * + * @return void + */ + public function testIdFieldSetsId(): void + { + $events = $this->parseBody("id:abc\ndata:x\n\n"); + + $this->assertCount(1, $events); + $this->assertSame('abc', $events[0]->getId()); + } + + /** + * Tests that an id containing a NUL byte is ignored. + * + * @dataProvider provideNulIds + * + * @param string $idValue The id field value. + * @return void + */ + public function testIdWithNulIgnored(string $idValue): void + { + $events = $this->parseBody('id:' . $idValue . "\ndata:hello\n\n"); + + $this->assertCount(1, $events); + $this->assertSame('', $events[0]->getId()); + $this->assertSame('hello', $events[0]->getData()); + } + + /** + * @return array + */ + public function provideNulIds(): array + { + return [ + 'two nulls' => ["\x00\x00"], + 'trailing null' => ["x\x00"], + 'leading null' => ["\x00x"], + 'embedded null' => ["x\x00x"], + 'space then null' => [" \x00"], + ]; + } + + /** + * Tests that the last event ID persists across events and resets on an empty id. + * + * @return void + */ + public function testIdPersistsAndResets(): void + { + $body = "id:1\ndata:1\n\n" + . "data:2\n\n" + . "id\ndata:3\n\n" + . "id:2\ndata:4\n\n"; + + $events = $this->parseBody($body); + + $ids = array_map(static fn (ServerSentEvent $e): string => $e->getId(), $events); + $this->assertSame(['1', '1', '', '2'], $ids); + $this->assertSame(['1', '2', '3', '4'], $this->dataOf($events)); + } + + /** + * Tests retry field parsing, including decimal-not-octal and bogus values. + * + * @dataProvider provideRetry + * + * @param string $body The body. + * @param int|null $expected The expected retry value. + * @return void + */ + public function testRetryField(string $body, ?int $expected): void + { + $events = $this->parseBody($body); + + $this->assertCount(1, $events); + $this->assertSame($expected, $events[0]->getRetry()); + } + + /** + * @return array + */ + public function provideRetry(): array + { + return [ + 'plain' => ["retry:3000\ndata:x\n\n", 3000], + 'leading zero is decimal not octal' => ["retry:03000\ndata:x\n\n", 3000], + 'bogus is ignored' => ["retry:1000x\ndata:x\n\n", null], + 'bogus keeps previous value' => ["retry:3000\nretry:1000x\ndata:x\n\n", 3000], + 'empty retry field' => ["retry\ndata:x\n\n", null], + ]; + } + + /** + * Tests that a leading BOM is stripped once while a mid-stream BOM is literal. + * + * @return void + */ + public function testLeadingBomStrippedOnce(): void + { + $bom = "\xEF\xBB\xBF"; + $events = $this->parseBody($bom . "data:1\n\n" . $bom . "data:2\n\ndata:3\n\n"); + + $this->assertSame(['1', '3'], $this->dataOf($events)); + } + + /** + * Tests that only the first of a double BOM is stripped. + * + * @return void + */ + public function testDoubleBomStripsOnlyOne(): void + { + $bom = "\xEF\xBB\xBF"; + $events = $this->parseBody($bom . $bom . "data:1\n\ndata:2\n\ndata:3\n\n"); + + $this->assertSame(['2', '3'], $this->dataOf($events)); + } + + /** + * Tests that an event left pending at EOF (no final blank line) is discarded. + * + * @return void + */ + public function testIncompleteFinalEventDiscarded(): void + { + $events = $this->parseBody("retry:1000\ndata:test1\n\nid:test\ndata:test2\n"); + + $this->assertCount(1, $events); + $this->assertSame('test1', $events[0]->getData()); + $this->assertSame('', $events[0]->getId()); + $this->assertSame(1000, $events[0]->getRetry()); + } + + /** + * Tests line and data parsing across a fuller mixed stream. + * + * @return void + */ + public function testMixedStreamLinesAndData(): void + { + $body = "data:msg\n" + . "data: msg\n\n" + . ":\n" + . "falsefield:msg\n\n" + . "falsefield:msg\n" + . "Data:data\n\n" + . "data\n\n" + . "data:end\n\n"; + + $events = $this->parseBody($body); + + $this->assertSame(["msg\nmsg", '', 'end'], $this->dataOf($events)); + } + + /** + * Tests that an empty stream yields no events. + * + * @return void + */ + public function testEmptyStream(): void + { + $this->assertSame([], $this->parse([])); + $this->assertSame([], $this->parseBody('')); + } + + /** + * Tests that a frame split across reads is reassembled. + * + * @return void + */ + public function testFrameSplitAcrossChunks(): void + { + $events = $this->parse(['data: hel', 'lo wor', "ld\n", "\n"]); + + $this->assertCount(1, $events); + $this->assertSame('hello world', $events[0]->getData()); + } + + /** + * Tests a CRLF terminator split across two reads (CR ends one chunk, LF starts the next). + * + * @return void + */ + public function testCrlfSplitAcrossChunks(): void + { + $events = $this->parse(["data:x\r", "\n\r\n"]); + + $this->assertCount(1, $events); + $this->assertSame('x', $events[0]->getData()); + } + + /** + * Tests a lone CR separating two data lines within one event. + * + * @return void + */ + public function testLoneCrSeparatesDataLines(): void + { + $events = $this->parseBody("data: a\rdata: b\n\n"); + + $this->assertCount(1, $events); + $this->assertSame("a\nb", $events[0]->getData()); + } + + /** + * Tests a multi-byte UTF-8 character split across reads. + * + * @return void + */ + public function testMultibyteUtf8SplitAcrossChunks(): void + { + $events = $this->parse(["data:a\xE2\x80", "\xA6b\n\n"]); + + $this->assertCount(1, $events); + $this->assertSame('a…b', $events[0]->getData()); + } + + /** + * Tests a BOM split across reads is still stripped. + * + * @return void + */ + public function testBomSplitAcrossChunks(): void + { + $events = $this->parse(["\xEF", "\xBB\xBF", "data:x\n\n"]); + + $this->assertCount(1, $events); + $this->assertSame('x', $events[0]->getData()); + } + + /** + * Tests that a final line without a terminator is parsed then discarded at EOF. + * + * @return void + */ + public function testFinalLineWithoutTerminatorIsDiscarded(): void + { + $events = $this->parseBody("data:x\n\ndata:y"); + + $this->assertCount(1, $events); + $this->assertSame('x', $events[0]->getData()); + } + + /** + * Tests that a final line ending in a lone CR at EOF is parsed then discarded. + * + * @return void + */ + public function testFinalLineEndingWithLoneCrIsDiscarded(): void + { + $events = $this->parseBody("data:x\n\ndata:y\r"); + + $this->assertCount(1, $events); + $this->assertSame('x', $events[0]->getData()); + } + + /** + * Tests that a sole event with no terminating blank line yields nothing. + * + * Some SSE parsers emit the final event even when its terminator is missing. + * The WHATWG spec discards an event left pending at EOF, so a stream whose + * only event is never closed produces no events at all. + * + * @return void + */ + public function testSoleIncompleteEventYieldsNothing(): void + { + $this->assertSame([], $this->parseBody('data:hello')); + } + + /** + * Tests a first read shorter than the BOM that is not a BOM prefix. + * + * @return void + */ + public function testFirstReadShorterThanBomPrefix(): void + { + $events = $this->parse(['d', "ata:x\n\n"]); + + $this->assertCount(1, $events); + $this->assertSame('x', $events[0]->getData()); + } + + /** + * Tests that the stream is closed after it is fully consumed. + * + * @return void + */ + public function testStreamClosedAfterConsumption(): void + { + $stream = new ChunkStream(["data:x\n\n"]); + iterator_to_array((new SseEventStreamParser())->parse($stream), false); + + $this->assertTrue($stream->isClosed()); + } + + /** + * Tests that the stream is closed when iteration stops early. + * + * @return void + */ + public function testStreamClosedOnEarlyAbandon(): void + { + $stream = new ChunkStream(["data:1\n\ndata:2\n\n"]); + + foreach ((new SseEventStreamParser())->parse($stream) as $event) { + break; + } + + $this->assertTrue($stream->isClosed()); + } + + /** + * Tests that events are produced lazily as the stream is read. + * + * @return void + */ + public function testLazyConsumption(): void + { + $stream = new ChunkStream(["data:1\n\n", "data:2\n\n"]); + $generator = (new SseEventStreamParser())->parse($stream); + + $this->assertSame('1', $generator->current()->getData()); + $this->assertSame(1, $stream->getReadCount()); + + $generator->next(); + $this->assertSame('2', $generator->current()->getData()); + $this->assertSame(2, $stream->getReadCount()); + } +} diff --git a/tests/unit/Providers/Http/Streaming/ValueObjects/ServerSentEventTest.php b/tests/unit/Providers/Http/Streaming/ValueObjects/ServerSentEventTest.php new file mode 100644 index 00000000..1bd44272 --- /dev/null +++ b/tests/unit/Providers/Http/Streaming/ValueObjects/ServerSentEventTest.php @@ -0,0 +1,40 @@ +assertSame('completion', $event->getEvent()); + $this->assertSame('{"x":1}', $event->getData()); + $this->assertSame('evt-1', $event->getId()); + $this->assertSame(2000, $event->getRetry()); + } + + /** + * Tests that the id defaults to an empty string and retry to null. + */ + public function testDefaults(): void + { + $event = new ServerSentEvent('message', 'payload'); + + $this->assertSame('message', $event->getEvent()); + $this->assertSame('payload', $event->getData()); + $this->assertSame('', $event->getId()); + $this->assertNull($event->getRetry()); + } +} diff --git a/tests/unit/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModelTest.php b/tests/unit/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModelTest.php index 6c99c75b..7f3b4c27 100644 --- a/tests/unit/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModelTest.php +++ b/tests/unit/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModelTest.php @@ -6,6 +6,7 @@ use InvalidArgumentException; use PHPUnit\Framework\TestCase; +use WordPress\AiClient\Common\Exception\RuntimeException; use WordPress\AiClient\Files\DTO\File; use WordPress\AiClient\Messages\DTO\Message; use WordPress\AiClient\Messages\DTO\MessagePart; @@ -15,6 +16,8 @@ use WordPress\AiClient\Providers\DTO\ProviderMetadata; use WordPress\AiClient\Providers\Http\Contracts\HttpTransporterInterface; use WordPress\AiClient\Providers\Http\Contracts\RequestAuthenticationInterface; +use WordPress\AiClient\Providers\Http\DTO\Request; +use WordPress\AiClient\Providers\Http\DTO\RequestOptions; use WordPress\AiClient\Providers\Http\DTO\Response; use WordPress\AiClient\Providers\Http\Exception\ClientException; use WordPress\AiClient\Providers\Http\Exception\ResponseException; @@ -22,7 +25,13 @@ use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; use WordPress\AiClient\Results\DTO\Candidate; use WordPress\AiClient\Results\DTO\GenerativeAiResult; +use WordPress\AiClient\Results\DTO\TokenUsage; use WordPress\AiClient\Results\Enums\FinishReasonEnum; +use WordPress\AiClient\Results\StreamedGenerativeAiResult; +use WordPress\AiClient\Results\ValueObjects\CandidateDelta; +use WordPress\AiClient\Results\ValueObjects\GenerativeAiResultChunk; +use WordPress\AiClient\Tests\mocks\ChunkStream; +use WordPress\AiClient\Tests\mocks\FailingChunkStream; use WordPress\AiClient\Tools\DTO\FunctionCall; use WordPress\AiClient\Tools\DTO\FunctionDeclaration; use WordPress\AiClient\Tools\DTO\FunctionResponse; @@ -165,6 +174,38 @@ public function testGenerateTextResultApiFailure(): void $model->generateTextResult($prompt); } + /** + * Tests that token usage defaults to zero when the response omits usage. + * + * @return void + */ + public function testGenerateTextResultDefaultsTokenUsageWhenUsageAbsent(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Hello')])]; + $response = new Response( + 200, + [], + json_encode([ + 'id' => 'chatcmpl-123', + 'choices' => [ + [ + 'message' => ['role' => 'assistant', 'content' => 'Hi there!'], + 'finish_reason' => 'stop', + ], + ], + ]) + ); + + $this->mockRequestAuthentication->method('authenticateRequest')->willReturnArgument(0); + $this->mockHttpTransporter->method('send')->willReturn($response); + + $result = $this->createModel()->generateTextResult($prompt); + + $this->assertSame(0, $result->getTokenUsage()->getPromptTokens()); + $this->assertSame(0, $result->getTokenUsage()->getCompletionTokens()); + $this->assertSame(0, $result->getTokenUsage()->getTotalTokens()); + } + /** * Tests prepareGenerateTextParams() with basic text prompt. * @@ -1328,4 +1369,838 @@ public function testGetMessagePartContentDataThoughtPart(): void // Should be skipped because OpenAI API doesn't support receiving thoughts. $this->assertNull($data); } + + /** + * @return list + */ + private function createStreamPrompt(): array + { + return [new Message(MessageRoleEnum::user(), [new MessagePart('Hello')])]; + } + + /** + * Builds one SSE "data:" frame for the given decoded payload. + * + * @param array $payload + */ + private function createSseDataLine(array $payload): string + { + return 'data: ' . json_encode($payload) . "\n\n"; + } + + /** + * Wraps SSE frames in a streamed Response, one frame per read. + * + * @param list $sseFrames + */ + private function createStreamResponse(array $sseFrames, int $statusCode = 200): Response + { + return new Response($statusCode, [], new ChunkStream($sseFrames)); + } + + /** + * Configures auth passthrough and the transporter to return the given response. + */ + private function givenStreamResponse(Response $response): void + { + $this->mockRequestAuthentication->method('authenticateRequest')->willReturnArgument(0); + $this->mockHttpTransporter->method('send')->willReturn($response); + } + + /** + * Drains a handle into a list of chunks. + * + * @return list + */ + private function consumeChunks(StreamedGenerativeAiResult $handle): array + { + return array_values(iterator_to_array($handle, false)); + } + + /** + * Tests that creating the handle does not perform the request. + */ + public function testStreamGenerateTextResultReturnsHandleWithoutPerformingRequest(): void + { + $this->mockHttpTransporter->expects($this->never())->method('send'); + + $handle = $this->createModel()->streamGenerateTextResult($this->createStreamPrompt()); + + $this->assertInstanceOf(StreamedGenerativeAiResult::class, $handle); + } + + /** + * Tests that the streamed request enables streaming and usage reporting. + */ + public function testStreamGenerateTextResultEnablesStreamingOnTheRequest(): void + { + $this->mockRequestAuthentication + ->expects($this->once()) + ->method('authenticateRequest') + ->willReturnArgument(0); + + $this->mockHttpTransporter + ->expects($this->once()) + ->method('send') + ->willReturnCallback(function (Request $request, ?RequestOptions $options = null) { + $params = $request->getData() ?? []; + $this->assertTrue($params['stream'] ?? null); + $this->assertSame(['include_usage' => true], $params['stream_options'] ?? null); + $this->assertInstanceOf(RequestOptions::class, $options); + $this->assertTrue($options->isStream()); + + return $this->createStreamResponse(["data: [DONE]\n\n"]); + }); + + $handle = $this->createModel()->streamGenerateTextResult($this->createStreamPrompt()); + + $this->assertSame([], $this->consumeChunks($handle)); + } + + /** + * Tests that content deltas, the finish reason, and usage are assembled. + */ + public function testStreamAssemblesContentFinishReasonAndUsage(): void + { + $this->givenStreamResponse($this->createStreamResponse([ + $this->createSseDataLine([ + 'id' => 'chatcmpl-1', + 'choices' => [['index' => 0, 'delta' => ['role' => 'assistant', 'content' => 'Hel']]], + ]), + $this->createSseDataLine(['choices' => [['index' => 0, 'delta' => ['content' => 'lo']]]]), + $this->createSseDataLine(['choices' => [['index' => 0, 'delta' => [], 'finish_reason' => 'stop']]]), + $this->createSseDataLine([ + 'choices' => [], + 'usage' => ['prompt_tokens' => 10, 'completion_tokens' => 5, 'total_tokens' => 15], + ]), + "data: [DONE]\n\n", + ])); + + $result = $this->createModel() + ->streamGenerateTextResult($this->createStreamPrompt()) + ->getFinalResult(); + + $this->assertSame('chatcmpl-1', $result->getId()); + $this->assertSame('Hello', $result->toText()); + $this->assertTrue($result->getCandidates()[0]->getFinishReason()->is(FinishReasonEnum::stop())); + $this->assertSame(10, $result->getTokenUsage()->getPromptTokens()); + $this->assertSame(5, $result->getTokenUsage()->getCompletionTokens()); + $this->assertSame(15, $result->getTokenUsage()->getTotalTokens()); + } + + /** + * Tests that each event yields a chunk and the [DONE] sentinel is skipped. + */ + public function testStreamYieldsOneChunkPerEventAndSkipsTheDoneSentinel(): void + { + $this->givenStreamResponse($this->createStreamResponse([ + $this->createSseDataLine(['choices' => [['index' => 0, 'delta' => ['content' => 'Hel']]]]), + "data: \n\n", + $this->createSseDataLine(['choices' => [['index' => 0, 'delta' => ['content' => 'lo']]]]), + $this->createSseDataLine([ + 'choices' => [], + 'usage' => ['prompt_tokens' => 1, 'completion_tokens' => 2, 'total_tokens' => 3], + ]), + "data: [DONE]\n\n", + ])); + + $chunks = $this->consumeChunks( + $this->createModel()->streamGenerateTextResult($this->createStreamPrompt()) + ); + + $this->assertCount(3, $chunks); + $this->assertSame('Hel', $chunks[0]->getDeltaText()); + $this->assertSame('lo', $chunks[1]->getDeltaText()); + $this->assertSame('', $chunks[2]->getDeltaText()); + $this->assertInstanceOf(TokenUsage::class, $chunks[2]->getTokenUsage()); + $this->assertSame(3, $chunks[2]->getTokenUsage()->getTotalTokens()); + } + + /** + * Tests that additional data is extracted onto chunks and the result. + */ + public function testStreamExtractsAdditionalDataIntoChunksAndResult(): void + { + $this->givenStreamResponse($this->createStreamResponse([ + $this->createSseDataLine([ + 'id' => 'chatcmpl-1', + 'object' => 'chat.completion.chunk', + 'system_fingerprint' => 'fp_abc', + 'choices' => [['index' => 0, 'delta' => ['content' => 'Hi'], 'finish_reason' => 'stop']], + 'usage' => ['prompt_tokens' => 1, 'completion_tokens' => 1, 'total_tokens' => 2], + ]), + "data: [DONE]\n\n", + ])); + + $handle = $this->createModel()->streamGenerateTextResult($this->createStreamPrompt()); + $chunks = $this->consumeChunks($handle); + + $expected = ['object' => 'chat.completion.chunk', 'system_fingerprint' => 'fp_abc']; + $this->assertSame($expected, $chunks[0]->getAdditionalData()); + $this->assertSame($expected, $handle->getFinalResult()->getAdditionalData()); + } + + /** + * Tests that reasoning deltas route to the thought channel with thought tokens. + */ + public function testStreamRoutesReasoningToThoughtChannelWithThoughtTokens(): void + { + $this->givenStreamResponse($this->createStreamResponse([ + $this->createSseDataLine(['choices' => [['index' => 0, 'delta' => ['reasoning_content' => 'Think']]]]), + $this->createSseDataLine(['choices' => [['index' => 0, 'delta' => ['reasoning' => 'ing']]]]), + $this->createSseDataLine([ + 'choices' => [['index' => 0, 'delta' => ['content' => 'Answer'], 'finish_reason' => 'stop']], + ]), + $this->createSseDataLine([ + 'choices' => [], + 'usage' => [ + 'prompt_tokens' => 15, + 'completion_tokens' => 20, + 'total_tokens' => 35, + 'completion_tokens_details' => ['reasoning_tokens' => 10], + ], + ]), + "data: [DONE]\n\n", + ])); + + $result = $this->createModel() + ->streamGenerateTextResult($this->createStreamPrompt()) + ->getFinalResult(); + + $thought = ''; + $content = ''; + foreach ($result->getCandidates()[0]->getMessage()->getParts() as $part) { + if ($part->getText() === null) { + continue; + } + if ($part->getChannel()->is(MessagePartChannelEnum::thought())) { + $thought .= $part->getText(); + } else { + $content .= $part->getText(); + } + } + + $this->assertSame('Thinking', $thought); + $this->assertSame('Answer', $content); + $this->assertSame(10, $result->getTokenUsage()->getThoughtTokens()); + } + + /** + * Builds one streamed-choice SSE frame carrying tool-call deltas. + * + * @param list $toolCalls + */ + private function toolCallFrame(array $toolCalls, ?string $finishReason = null): string + { + $deltas = []; + foreach ($toolCalls as [$index, $id, $name, $arguments]) { + $function = []; + if ($name !== null) { + $function['name'] = $name; + } + if ($arguments !== null) { + $function['arguments'] = $arguments; + } + + $toolCall = ['index' => $index]; + if ($id !== null) { + $toolCall['id'] = $id; + } + $toolCall['function'] = $function; + $deltas[] = $toolCall; + } + + $choice = ['index' => 0, 'delta' => ['tool_calls' => $deltas]]; + if ($finishReason !== null) { + $choice['finish_reason'] = $finishReason; + } + + return $this->createSseDataLine(['choices' => [$choice]]); + } + + /** + * @return array, 1: list}> + */ + public function assembledToolCallProvider(): array + { + return [ + 'arguments split across frames' => [ + [ + $this->toolCallFrame([[0, 'call_1', 'get_weather', '']]), + $this->toolCallFrame([[0, null, null, '{"ci']]), + $this->toolCallFrame([[0, null, null, 'ty": "San']]), + $this->toolCallFrame([[0, null, null, ' Francisco"}']], 'tool_calls'), + ], + [['id' => 'call_1', 'name' => 'get_weather', 'args' => ['city' => 'San Francisco']]], + ], + 'whole tool call in one frame' => [ + [ + $this->toolCallFrame([[0, 'call_1', 'get_weather', '{"city": "London"}']], 'tool_calls'), + ], + [['id' => 'call_1', 'name' => 'get_weather', 'args' => ['city' => 'London']]], + ], + 'missing type field (Azure AI Foundry / Mistral)' => [ + [ + $this->toolCallFrame([[0, 'call_abc', 'test-tool', '{"value"']]), + $this->toolCallFrame([[0, null, null, ':"hello"}']], 'tool_calls'), + ], + [['id' => 'call_abc', 'name' => 'test-tool', 'args' => ['value' => 'hello']]], + ], + 'trailing empty argument frame does not duplicate' => [ + [ + $this->toolCallFrame([[0, 'call_1', 'searchGoogle', null]]), + $this->toolCallFrame([[0, null, null, '{"query": "ai"}']]), + $this->toolCallFrame([[0, null, null, '']], 'tool_calls'), + ], + [['id' => 'call_1', 'name' => 'searchGoogle', 'args' => ['query' => 'ai']]], + ], + 'parallel tool calls reassembled independently' => [ + [ + $this->toolCallFrame([ + [0, 'call_a', 'get_weather', '{"city":'], + [1, 'call_b', 'get_time', '{"tz":'], + ]), + $this->toolCallFrame([[1, null, null, '"UTC"}'], [0, null, null, '"Paris"}']], 'tool_calls'), + ], + [ + ['id' => 'call_a', 'name' => 'get_weather', 'args' => ['city' => 'Paris']], + ['id' => 'call_b', 'name' => 'get_time', 'args' => ['tz' => 'UTC']], + ], + ], + ]; + } + + /** + * Tests that tool-call deltas are reassembled into function calls. + * + * @dataProvider assembledToolCallProvider + * + * @param list $sseFrames + * @param list $expectedCalls + */ + public function testStreamReassemblesToolCalls(array $sseFrames, array $expectedCalls): void + { + $this->givenStreamResponse($this->createStreamResponse(array_merge($sseFrames, ["data: [DONE]\n\n"]))); + + $candidate = $this->createModel() + ->streamGenerateTextResult($this->createStreamPrompt()) + ->getFinalResult() + ->getCandidates()[0]; + + $actualCalls = []; + foreach ($candidate->getMessage()->getParts() as $part) { + $call = $part->getFunctionCall(); + if ($call !== null) { + $actualCalls[] = ['id' => $call->getId(), 'name' => $call->getName(), 'args' => $call->getArgs()]; + } + } + + $this->assertEquals($expectedCalls, $actualCalls); + } + + /** + * Tests that choices at different indices become separate candidates. + */ + public function testStreamSeparatesMultipleCandidatesByIndex(): void + { + $this->givenStreamResponse($this->createStreamResponse([ + $this->createSseDataLine(['choices' => [ + ['index' => 0, 'delta' => ['content' => 'First']], + ['index' => 1, 'delta' => ['content' => 'Second']], + ]]), + $this->createSseDataLine(['choices' => [ + ['index' => 0, 'delta' => [], 'finish_reason' => 'stop'], + ['index' => 1, 'delta' => [], 'finish_reason' => 'stop'], + ]]), + "data: [DONE]\n\n", + ])); + + $candidates = $this->createModel() + ->streamGenerateTextResult($this->createStreamPrompt()) + ->getFinalResult() + ->getCandidates(); + + $this->assertCount(2, $candidates); + $this->assertSame('First', $candidates[0]->getMessage()->getParts()[0]->getText()); + $this->assertSame('Second', $candidates[1]->getMessage()->getParts()[0]->getText()); + } + + /** + * Tests that an unknown finish reason defaults to stop. + */ + public function testStreamUnknownFinishReasonDefaultsToStop(): void + { + $this->givenStreamResponse($this->createStreamResponse([ + $this->createSseDataLine([ + 'choices' => [['index' => 0, 'delta' => ['content' => 'Hi'], 'finish_reason' => 'something_new']], + ]), + "data: [DONE]\n\n", + ])); + + $candidate = $this->createModel() + ->streamGenerateTextResult($this->createStreamPrompt()) + ->getFinalResult() + ->getCandidates()[0]; + + $this->assertTrue($candidate->getFinishReason()->is(FinishReasonEnum::stop())); + } + + /** + * Tests that a malformed JSON frame is skipped without aborting the stream. + */ + public function testStreamSkipsUnparsableJsonLineButKeepsValidChunks(): void + { + $this->givenStreamResponse($this->createStreamResponse([ + $this->createSseDataLine(['choices' => [['index' => 0, 'delta' => ['content' => 'Hel']]]]), + "data: {unparsable}\n\n", + $this->createSseDataLine([ + 'choices' => [['index' => 0, 'delta' => ['content' => 'lo'], 'finish_reason' => 'stop']], + ]), + "data: [DONE]\n\n", + ])); + + $handle = $this->createModel()->streamGenerateTextResult($this->createStreamPrompt()); + $chunks = $this->consumeChunks($handle); + + $this->assertCount(2, $chunks); + $this->assertSame('Hello', $handle->getFinalResult()->toText()); + } + + /** + * Tests that a [DONE]-only stream produces no result. + */ + public function testStreamWithOnlyDoneSentinelProducesNoResult(): void + { + $this->givenStreamResponse($this->createStreamResponse(["data: [DONE]\n\n"])); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('no candidates'); + $this->createModel()->streamGenerateTextResult($this->createStreamPrompt())->getFinalResult(); + } + + /** + * Tests that an error frame raises a ResponseException with the provider message. + */ + public function testStreamThrowsResponseExceptionOnErrorEvent(): void + { + $this->givenStreamResponse($this->createStreamResponse([ + $this->createSseDataLine(['error' => ['message' => 'bad request', 'type' => 'provider_error']]), + "data: [DONE]\n\n", + ])); + + $handle = $this->createModel()->streamGenerateTextResult($this->createStreamPrompt()); + + $thrown = null; + try { + $this->consumeChunks($handle); + } catch (ResponseException $e) { + $thrown = $e; + } + + $this->assertInstanceOf(ResponseException::class, $thrown); + $this->assertSame('Error while streaming the TestProvider API response: bad request', $thrown->getMessage()); + } + + /** + * Tests that chunks before a mid-stream error are delivered, then it propagates. + */ + public function testStreamYieldsContentBeforeMidStreamErrorThenThrows(): void + { + $this->givenStreamResponse($this->createStreamResponse([ + $this->createSseDataLine(['choices' => [['index' => 0, 'delta' => ['content' => 'Hello']]]]), + $this->createSseDataLine(['error' => ['message' => 'stream failed after output']]), + "data: [DONE]\n\n", + ])); + + $handle = $this->createModel()->streamGenerateTextResult($this->createStreamPrompt()); + + $collected = []; + $thrown = null; + try { + foreach ($handle as $chunk) { + $collected[] = $chunk; + } + } catch (ResponseException $e) { + $thrown = $e; + } + + $this->assertCount(1, $collected); + $this->assertSame('Hello', $collected[0]->getDeltaText()); + $this->assertInstanceOf(ResponseException::class, $thrown); + $this->assertStringContainsString('stream failed after output', $thrown->getMessage()); + } + + /** + * Tests that a non-successful response is surfaced before streaming begins. + */ + public function testStreamThrowsClientExceptionWhenResponseIsNotSuccessful(): void + { + $this->givenStreamResponse(new Response(400, [], '{"error": "Invalid parameter."}')); + + $handle = $this->createModel()->streamGenerateTextResult($this->createStreamPrompt()); + + $this->expectException(ClientException::class); + $this->consumeChunks($handle); + } + + /** + * Tests that a mid-read failure is wrapped as a ResponseException with its cause. + */ + public function testStreamWrapsMidReadFailureAsResponseException(): void + { + $response = new Response(200, [], new FailingChunkStream( + [$this->createSseDataLine(['choices' => [['index' => 0, 'delta' => ['content' => 'Hello']]]])], + 'Connection reset by peer' + )); + $this->givenStreamResponse($response); + + $handle = $this->createModel()->streamGenerateTextResult($this->createStreamPrompt()); + + $collected = []; + $thrown = null; + try { + foreach ($handle as $chunk) { + $collected[] = $chunk; + } + } catch (ResponseException $e) { + $thrown = $e; + } + + $this->assertCount(1, $collected); + $this->assertSame('Hello', $collected[0]->getDeltaText()); + $this->assertInstanceOf(ResponseException::class, $thrown); + $this->assertStringContainsString('Connection reset by peer', $thrown->getMessage()); + $this->assertInstanceOf(\RuntimeException::class, $thrown->getPrevious()); + } + + /** + * Tests that throwIfStreamError() is a no-op without an error payload. + */ + public function testThrowIfStreamErrorIgnoresEventsWithoutError(): void + { + $model = $this->createModel(); + $model->exposeThrowIfStreamError(['choices' => []]); + + $this->assertTrue(true); + } + + /** + * @return array, 1: string}> + */ + public function streamErrorMessageProvider(): array + { + return [ + 'object error with message' => [ + ['error' => ['message' => 'boom', 'type' => 'server_error']], + 'Error while streaming the TestProvider API response: boom', + ], + 'object error without message' => [ + ['error' => ['type' => 'server_error']], + 'Error while streaming the TestProvider API response: The provider reported an error.', + ], + 'non-array error' => [ + ['error' => 'oops'], + 'Error while streaming the TestProvider API response: The provider reported an error.', + ], + 'non-string message' => [ + ['error' => ['message' => 123]], + 'Error while streaming the TestProvider API response: The provider reported an error.', + ], + ]; + } + + /** + * @dataProvider streamErrorMessageProvider + * + * @param array $event + */ + public function testThrowIfStreamErrorMessage(array $event, string $expectedMessage): void + { + $model = $this->createModel(); + + $this->expectException(ResponseException::class); + $this->expectExceptionMessage($expectedMessage); + $model->exposeThrowIfStreamError($event); + } + + /** + * Tests that parseStreamEvent() returns null for an unusable event. + */ + public function testParseStreamEventReturnsNullWhenEventCarriesNothingUsable(): void + { + $model = $this->createModel(); + + $this->assertNull($model->exposeParseStreamEvent([])); + $this->assertNull($model->exposeParseStreamEvent(['choices' => []])); + } + + /** + * Tests that parseStreamEvent() ignores a non-string id and non-array usage and choices. + */ + public function testParseStreamEventIgnoresNonStringIdAndNonArrayUsageAndChoices(): void + { + $model = $this->createModel(); + + $chunk = $model->exposeParseStreamEvent([ + 'id' => 123, + 'usage' => 'nope', + 'choices' => 'nope', + 'system_fingerprint' => 'fp_1', + ]); + + $this->assertInstanceOf(GenerativeAiResultChunk::class, $chunk); + $this->assertNull($chunk->getId()); + $this->assertNull($chunk->getTokenUsage()); + $this->assertSame([], $chunk->getCandidateDeltas()); + $this->assertSame(['system_fingerprint' => 'fp_1'], $chunk->getAdditionalData()); + } + + /** + * Tests that parseStreamEvent() skips non-array choice entries. + */ + public function testParseStreamEventSkipsNonArrayChoiceEntries(): void + { + $model = $this->createModel(); + + $chunk = $model->exposeParseStreamEvent([ + 'choices' => [ + 'not-an-array', + ['index' => 0, 'delta' => ['content' => 'Hi']], + ], + ]); + + $this->assertInstanceOf(GenerativeAiResultChunk::class, $chunk); + $this->assertCount(1, $chunk->getCandidateDeltas()); + $this->assertSame('Hi', $chunk->getDeltaText()); + } + + /** + * @return array, 1: int, 2: bool}> + */ + public function streamChoiceGuardProvider(): array + { + return [ + 'missing index defaults to 0' => [['delta' => ['content' => 'x']], 0, false], + 'non-int index defaults to 0' => [['index' => '5', 'delta' => ['content' => 'x']], 0, false], + 'non-array delta yields no parts' => [['index' => 2, 'delta' => 'nope'], 2, false], + 'non-string finish reason is dropped' => [['index' => 0, 'delta' => [], 'finish_reason' => 7], 0, false], + 'unknown finish reason is dropped' => [ + ['index' => 0, 'delta' => [], 'finish_reason' => 'mystery'], + 0, + false, + ], + 'known finish reason is kept' => [['index' => 0, 'delta' => [], 'finish_reason' => 'stop'], 0, true], + ]; + } + + /** + * @dataProvider streamChoiceGuardProvider + * + * @param array $choice + */ + public function testParseStreamChoiceGuards(array $choice, int $expectedIndex, bool $hasFinishReason): void + { + $delta = $this->createModel()->exposeParseStreamChoice($choice); + + $this->assertInstanceOf(CandidateDelta::class, $delta); + $this->assertSame($expectedIndex, $delta->getIndex()); + if ($hasFinishReason) { + $this->assertInstanceOf(FinishReasonEnum::class, $delta->getFinishReason()); + } else { + $this->assertNull($delta->getFinishReason()); + } + } + + /** + * Tests that parseStreamDeltaParts() maps channels and ignores non-string values. + */ + public function testParseStreamDeltaPartsMapsChannels(): void + { + $parts = $this->createModel()->exposeParseStreamDeltaParts([ + 'reasoning_content' => 'A', + 'reasoning' => 'B', + 'content' => 'C', + ]); + + $this->assertCount(3, $parts); + $this->assertSame('A', $parts[0]->getText()); + $this->assertTrue($parts[0]->getChannel()->is(MessagePartChannelEnum::thought())); + $this->assertSame('B', $parts[1]->getText()); + $this->assertTrue($parts[1]->getChannel()->is(MessagePartChannelEnum::thought())); + $this->assertSame('C', $parts[2]->getText()); + $this->assertTrue($parts[2]->getChannel()->is(MessagePartChannelEnum::content())); + + $this->assertSame([], $this->createModel()->exposeParseStreamDeltaParts([ + 'reasoning_content' => 1, + 'reasoning' => 2, + 'content' => 3, + ])); + } + + /** + * @return array, 1: array}> + */ + public function streamToolCallDeltaGuardProvider(): array + { + return [ + 'tool_calls not an array' => [ + ['tool_calls' => 'nope'], + ['count' => 0], + ], + 'tool call entry not an array' => [ + ['tool_calls' => ['nope']], + ['count' => 0], + ], + 'missing index falls back to position' => [ + ['tool_calls' => [['id' => 'a', 'function' => ['name' => 'fn', 'arguments' => '{}']]]], + ['count' => 1, 'index' => 0, 'id' => 'a', 'name' => 'fn', 'arguments' => '{}'], + ], + 'non-int index falls back to position' => [ + ['tool_calls' => [['index' => 'x', 'id' => 'a', 'function' => ['name' => 'fn']]]], + ['count' => 1, 'index' => 0, 'id' => 'a', 'name' => 'fn', 'arguments' => ''], + ], + 'non-string id becomes null' => [ + ['tool_calls' => [['index' => 0, 'id' => 9, 'function' => ['name' => 'fn']]]], + ['count' => 1, 'index' => 0, 'id' => null, 'name' => 'fn', 'arguments' => ''], + ], + 'non-array function yields null name and empty arguments' => [ + ['tool_calls' => [['index' => 0, 'id' => 'a', 'function' => 'nope']]], + ['count' => 1, 'index' => 0, 'id' => 'a', 'name' => null, 'arguments' => ''], + ], + 'non-string name and arguments are dropped' => [ + ['tool_calls' => [['index' => 0, 'id' => 'a', 'function' => ['name' => 1, 'arguments' => 2]]]], + ['count' => 1, 'index' => 0, 'id' => 'a', 'name' => null, 'arguments' => ''], + ], + ]; + } + + /** + * @dataProvider streamToolCallDeltaGuardProvider + * + * @param array $delta + * @param array $expected + */ + public function testParseStreamToolCallDeltasGuards(array $delta, array $expected): void + { + $deltas = $this->createModel()->exposeParseStreamToolCallDeltas($delta); + + $this->assertCount($expected['count'], $deltas); + if ($expected['count'] === 0) { + return; + } + + $this->assertSame($expected['index'], $deltas[0]->getIndex()); + $this->assertSame($expected['id'], $deltas[0]->getId()); + $this->assertSame($expected['name'], $deltas[0]->getFunctionName()); + $this->assertSame($expected['arguments'], $deltas[0]->getArgumentsFragment()); + } + + /** + * @return array, 1: int, 2: int, 3: int, 4: int|null}> + */ + public function usageDataProvider(): array + { + return [ + 'full usage with reasoning tokens' => [ + [ + 'prompt_tokens' => 15, + 'completion_tokens' => 20, + 'total_tokens' => 35, + 'completion_tokens_details' => ['reasoning_tokens' => 10], + ], + 15, + 20, + 35, + 10, + ], + 'usage without reasoning tokens' => [ + ['prompt_tokens' => 8, 'completion_tokens' => 4, 'total_tokens' => 12], + 8, + 4, + 12, + null, + ], + 'partial usage defaults missing counts to zero' => [ + ['prompt_tokens' => 20], + 20, + 0, + 0, + null, + ], + 'non-int reasoning tokens are ignored' => [ + [ + 'prompt_tokens' => 1, + 'completion_tokens' => 1, + 'total_tokens' => 2, + 'completion_tokens_details' => ['reasoning_tokens' => 'x'], + ], + 1, + 1, + 2, + null, + ], + 'numeric-string tokens are coerced to int' => [ + ['prompt_tokens' => '15', 'completion_tokens' => '20', 'total_tokens' => '35'], + 15, + 20, + 35, + null, + ], + 'float tokens are coerced to int' => [ + ['prompt_tokens' => 15.0, 'completion_tokens' => 20.0, 'total_tokens' => 35.0], + 15, + 20, + 35, + null, + ], + 'numeric-string reasoning tokens are coerced to int' => [ + [ + 'prompt_tokens' => 1, + 'completion_tokens' => 1, + 'total_tokens' => 2, + 'completion_tokens_details' => ['reasoning_tokens' => '10'], + ], + 1, + 1, + 2, + 10, + ], + ]; + } + + /** + * @dataProvider usageDataProvider + * + * @param array $usage + */ + public function testParseUsageData( + array $usage, + int $prompt, + int $completion, + int $total, + ?int $thought + ): void { + $tokenUsage = $this->createModel()->exposeParseUsageData($usage); + + $this->assertSame($prompt, $tokenUsage->getPromptTokens()); + $this->assertSame($completion, $tokenUsage->getCompletionTokens()); + $this->assertSame($total, $tokenUsage->getTotalTokens()); + $this->assertSame($thought, $tokenUsage->getThoughtTokens()); + } + + /** + * Tests that extractAdditionalData() strips id, choices, and usage. + */ + public function testExtractAdditionalDataStripsKnownKeys(): void + { + $data = $this->createModel()->exposeExtractAdditionalData([ + 'id' => 'x', + 'choices' => [], + 'usage' => [], + 'object' => 'chat.completion.chunk', + 'system_fingerprint' => 'fp_1', + ]); + + $this->assertSame(['object' => 'chat.completion.chunk', 'system_fingerprint' => 'fp_1'], $data); + } } diff --git a/tests/unit/Providers/OpenAiCompatibleImplementation/MockOpenAiCompatibleTextGenerationModel.php b/tests/unit/Providers/OpenAiCompatibleImplementation/MockOpenAiCompatibleTextGenerationModel.php index 4ddd3d7f..9d8edbad 100644 --- a/tests/unit/Providers/OpenAiCompatibleImplementation/MockOpenAiCompatibleTextGenerationModel.php +++ b/tests/unit/Providers/OpenAiCompatibleImplementation/MockOpenAiCompatibleTextGenerationModel.php @@ -17,6 +17,10 @@ use WordPress\AiClient\Providers\OpenAiCompatibleImplementation\AbstractOpenAiCompatibleTextGenerationModel; use WordPress\AiClient\Results\DTO\Candidate; use WordPress\AiClient\Results\DTO\GenerativeAiResult; +use WordPress\AiClient\Results\DTO\TokenUsage; +use WordPress\AiClient\Results\ValueObjects\CandidateDelta; +use WordPress\AiClient\Results\ValueObjects\GenerativeAiResultChunk; +use WordPress\AiClient\Results\ValueObjects\ToolCallDelta; /** * Mock class for testing AbstractOpenAiCompatibleTextGenerationModel. @@ -177,4 +181,54 @@ public function exposeParseResponseChoiceMessageToolCallPart(array $toolCallData { return $this->parseResponseChoiceMessageToolCallPart($toolCallData); } + + public function exposeThrowIfStreamError(array $event): void + { + $this->throwIfStreamError($event); + } + + public function exposeParseStreamEvent(array $data): ?GenerativeAiResultChunk + { + return $this->parseStreamEvent($data); + } + + public function exposeParseStreamChoice(array $choice): CandidateDelta + { + return $this->parseStreamChoice($choice); + } + + /** + * @param array $delta + * @return list + */ + public function exposeParseStreamDeltaParts(array $delta): array + { + return $this->parseStreamDeltaParts($delta); + } + + /** + * @param array $delta + * @return list + */ + public function exposeParseStreamToolCallDeltas(array $delta): array + { + return $this->parseStreamToolCallDeltas($delta); + } + + /** + * @param array $usage + */ + public function exposeParseUsageData(array $usage): TokenUsage + { + return $this->parseUsageData($usage); + } + + /** + * @param array $data + * @return array + */ + public function exposeExtractAdditionalData(array $data): array + { + return $this->extractAdditionalData($data); + } } diff --git a/tests/unit/Results/ChunkAccumulatorTest.php b/tests/unit/Results/ChunkAccumulatorTest.php new file mode 100644 index 00000000..11a7bba0 --- /dev/null +++ b/tests/unit/Results/ChunkAccumulatorTest.php @@ -0,0 +1,541 @@ +createTestProviderMetadata(), $this->createTestModelMetadata()); + } + + /** + * Creates a content-channel message part. + * + * @param string $text The text. + * @param string|null $signature The thought signature. + * @return MessagePart + */ + private function createContentPart(string $text, ?string $signature = null): MessagePart + { + return new MessagePart($text, MessagePartChannelEnum::content(), $signature); + } + + /** + * Creates a thought-channel message part. + * + * @param string $text The text. + * @param string|null $signature The thought signature. + * @return MessagePart + */ + private function createReasoningPart(string $text, ?string $signature = null): MessagePart + { + return new MessagePart($text, MessagePartChannelEnum::thought(), $signature); + } + + /** + * Creates a chunk. + * + * @param string|null $id The result id. + * @param TokenUsage|null $usage The token usage. + * @param array $additionalData The provider metadata. + * @param list $candidateDeltas The candidate deltas. + * @return GenerativeAiResultChunk + */ + private function createChunk( + ?string $id = null, + ?TokenUsage $usage = null, + array $additionalData = [], + array $candidateDeltas = [] + ): GenerativeAiResultChunk { + return new GenerativeAiResultChunk($id, $usage, $additionalData, $candidateDeltas); + } + + /** + * Tests that the id is captured from the first chunk that reports one. + */ + public function testCapturesIdFromFirstChunkThatReportsOne(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [$this->createContentPart('hi')])])); + $acc->add($this->createChunk('first', null, [], [])); + $acc->add($this->createChunk('second', null, [], [])); + + $this->assertSame('first', $acc->build()->getId()); + } + + /** + * Tests that token usage is last-wins. + */ + public function testTokenUsageIsLastWins(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [$this->createContentPart('hi')])])); + $acc->add($this->createChunk(null, new TokenUsage(1, 1, 2), [], [])); + $acc->add($this->createChunk(null, new TokenUsage(10, 20, 30), [], [])); + + $usage = $acc->build()->getTokenUsage(); + $this->assertSame(10, $usage->getPromptTokens()); + $this->assertSame(30, $usage->getTotalTokens()); + } + + /** + * Tests that additional data is merged with later chunks winning. + */ + public function testAdditionalDataMergedWithLaterChunksWinning(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, ['model' => 'x', 'a' => 1], [ + new CandidateDelta(0, [$this->createContentPart('hi')]), + ])); + $acc->add($this->createChunk(null, null, [], [])); + $acc->add($this->createChunk(null, null, ['model' => 'y', 'b' => 2], [])); + + $this->assertSame(['model' => 'y', 'a' => 1, 'b' => 2], $acc->build()->getAdditionalData()); + } + + /** + * Tests that text on the same channel is concatenated into one part. + */ + public function testTextConcatenatedPerChannel(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [$this->createContentPart('Hel')])])); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [$this->createContentPart('lo')])])); + + $parts = $acc->build()->getCandidates()[0]->getMessage()->getParts(); + $this->assertCount(1, $parts); + $this->assertSame('Hello', $parts[0]->getText()); + $this->assertTrue($parts[0]->getChannel()->is(MessagePartChannelEnum::content())); + } + + /** + * Tests that reasoning and content become separate parts in arrival order. + */ + public function testReasoningAndContentAreSeparatePartsInArrivalOrder(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [$this->createReasoningPart('because'), $this->createContentPart('Hi')]), + ])); + + $parts = $acc->build()->getCandidates()[0]->getMessage()->getParts(); + $this->assertCount(2, $parts); + $this->assertTrue($parts[0]->getChannel()->is(MessagePartChannelEnum::thought())); + $this->assertSame('because', $parts[0]->getText()); + $this->assertTrue($parts[1]->getChannel()->is(MessagePartChannelEnum::content())); + $this->assertSame('Hi', $parts[1]->getText()); + } + + /** + * Tests that the thought signature is captured last-wins and a null signature does not clear it. + */ + public function testThoughtSignatureCapturedLastWinsPerChannel(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [ + $this->createReasoningPart('a', 's1'), + $this->createReasoningPart('b', 's2'), + $this->createReasoningPart('c', null), + ]), + ])); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [$this->createContentPart('hi')])])); + + $parts = $acc->build()->getCandidates()[0]->getMessage()->getParts(); + $this->assertSame('abc', $parts[0]->getText()); + $this->assertSame('s2', $parts[0]->getThoughtSignature()); + $this->assertNull($parts[1]->getThoughtSignature()); + } + + /** + * Tests that the finish reason defaults to stop when not reported. + */ + public function testFinishReasonDefaultsToStopWhenNotReported(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [$this->createContentPart('hi')])])); + + $this->assertTrue($acc->build()->getCandidates()[0]->getFinishReason()->is(FinishReasonEnum::stop())); + } + + /** + * Tests that the reported finish reason is used when present. + */ + public function testFinishReasonUsedWhenReported(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [$this->createContentPart('hi')], FinishReasonEnum::length()), + ])); + + $this->assertTrue($acc->build()->getCandidates()[0]->getFinishReason()->is(FinishReasonEnum::length())); + } + + /** + * Tests that a non-text part is kept and placed after the text parts. + */ + public function testNonTextPartIsKeptAndPlacedAfterText(): void + { + $acc = $this->createAccumulator(); + $functionCallPart = new MessagePart(new FunctionCall('id', 'fn', ['k' => 'v'])); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [$this->createContentPart('Hi'), $functionCallPart]), + ])); + + $parts = $acc->build()->getCandidates()[0]->getMessage()->getParts(); + $this->assertCount(2, $parts); + $this->assertSame('Hi', $parts[0]->getText()); + $this->assertNotNull($parts[1]->getFunctionCall()); + $this->assertSame('fn', $parts[1]->getFunctionCall()->getName()); + } + + /** + * Tests that tool-call fragments are stitched by slot into a function call. + */ + public function testToolCallReassembledFromFragments(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [], null, [new ToolCallDelta(0, 'call_1', 'get_weather', '{"loc')]), + ])); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [], null, [new ToolCallDelta(0, null, null, 'ation":"SF"}')]), + ])); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [], FinishReasonEnum::toolCalls())])); + + $parts = $acc->build()->getCandidates()[0]->getMessage()->getParts(); + $this->assertCount(1, $parts); + $fc = $parts[0]->getFunctionCall(); + $this->assertNotNull($fc); + $this->assertSame('call_1', $fc->getId()); + $this->assertSame('get_weather', $fc->getName()); + $this->assertSame(['location' => 'SF'], $fc->getArgs()); + } + + /** + * Tests that tool-call id and name are first-wins across fragments. + */ + public function testToolCallIdAndNameAreFirstWins(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [], null, [new ToolCallDelta(0, 'first-id', null, '')]), + ])); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [], null, [new ToolCallDelta(0, 'ignored-id', 'the_fn', '{}')]), + ])); + + $fc = $acc->build()->getCandidates()[0]->getMessage()->getParts()[0]->getFunctionCall(); + $this->assertSame('first-id', $fc->getId()); + $this->assertSame('the_fn', $fc->getName()); + $this->assertSame([], $fc->getArgs()); + } + + /** + * Tests decoding of tool-call arguments. + * + * @dataProvider toolCallArgumentsProvider + * + * @param string $argumentsFragment The accumulated arguments string. + * @param mixed $expected The expected decoded arguments. + */ + public function testDecodesToolCallArguments(string $argumentsFragment, $expected): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [], null, [new ToolCallDelta(0, 'id', 'fn', $argumentsFragment)]), + ])); + + $fc = $acc->build()->getCandidates()[0]->getMessage()->getParts()[0]->getFunctionCall(); + $this->assertSame($expected, $fc->getArgs()); + } + + /** + * @return array + */ + public function toolCallArgumentsProvider(): array + { + return [ + 'valid JSON decodes to an array' => ['{"city":"SF"}', ['city' => 'SF']], + 'nested JSON decodes recursively' => ['{"a":{"b":1}}', ['a' => ['b' => 1]]], + 'broken JSON is kept as the raw string' => ['{"city":"SF', '{"city":"SF'], + 'empty arguments become null' => ['', null], + ]; + } + + /** + * Tests that parallel tool calls are emitted in slot-index order. + */ + public function testParallelToolCallsAreOrderedBySlot(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [], null, [ + new ToolCallDelta(1, 'call_b', 'fn_b', '{"y":2}'), + new ToolCallDelta(0, 'call_a', 'fn_a', '{"x":1}'), + ]), + ])); + + $parts = $acc->build()->getCandidates()[0]->getMessage()->getParts(); + $this->assertCount(2, $parts); + $this->assertSame('fn_a', $parts[0]->getFunctionCall()->getName()); + $this->assertSame('fn_b', $parts[1]->getFunctionCall()->getName()); + } + + /** + * Tests that a tool-call slot without an id or name is skipped. + */ + public function testToolCallSlotWithoutIdOrNameIsSkipped(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [$this->createContentPart('hi')], null, [ + new ToolCallDelta(0, null, null, '{"x":1}'), + ]), + ])); + + $parts = $acc->build()->getCandidates()[0]->getMessage()->getParts(); + $this->assertCount(1, $parts); + $this->assertSame('hi', $parts[0]->getText()); + } + + /** + * Tests that tool-call fragments without an index stitch into slot 0. + */ + public function testToolCallDeltaWithoutIndexUsesSlotZero(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [], null, [new ToolCallDelta(null, 'id', 'fn', '{"a":')]), + ])); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [], null, [new ToolCallDelta(null, null, null, '1}')]), + ])); + + $parts = $acc->build()->getCandidates()[0]->getMessage()->getParts(); + $this->assertCount(1, $parts); + $this->assertSame(['a' => 1], $parts[0]->getFunctionCall()->getArgs()); + } + + /** + * Tests that candidates are separated and sorted by index. + */ + public function testCandidatesAreSeparatedAndSortedByIndex(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(1, [$this->createContentPart('B')]), + new CandidateDelta(0, [$this->createContentPart('A')]), + ])); + + $candidates = $acc->build()->getCandidates(); + $this->assertCount(2, $candidates); + $this->assertSame('A', $candidates[0]->getMessage()->getParts()[0]->getText()); + $this->assertSame('B', $candidates[1]->getMessage()->getParts()[0]->getText()); + } + + /** + * Tests that hasCandidates() reflects the accumulated state. + */ + public function testHasCandidatesReflectsAccumulatedState(): void + { + $acc = $this->createAccumulator(); + $this->assertFalse($acc->hasCandidates()); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [$this->createContentPart('hi')])])); + $this->assertTrue($acc->hasCandidates()); + } + + /** + * Tests that a metadata-only chunk registers no candidate. + */ + public function testMetadataOnlyChunkRegistersNoCandidate(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk('id', new TokenUsage(1, 1, 2), ['model' => 'x'], [])); + + $this->assertFalse($acc->hasCandidates()); + } + + /** + * Tests that build() throws when no candidate was accumulated. + */ + public function testBuildThrowsWhenNoCandidates(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk('id', new TokenUsage(1, 1, 2), ['model' => 'x'], [])); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('The stream produced no candidates.'); + $acc->build(); + } + + /** + * Tests that build() applies defaults when metadata is absent. + */ + public function testBuildAppliesDefaultsWhenMetadataAbsent(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [$this->createContentPart('hi')])])); + + $result = $acc->build(); + $this->assertSame('', $result->getId()); + $this->assertSame(0, $result->getTokenUsage()->getTotalTokens()); + $this->assertSame([], $result->getAdditionalData()); + } + + /** + * Tests that build() uses the provider and model metadata. + */ + public function testBuildUsesProviderAndModelMetadata(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [$this->createContentPart('hi')])])); + + $result = $acc->build(); + $this->assertSame('test-provider', $result->getProviderMetadata()->getId()); + $this->assertSame('test-model', $result->getModelMetadata()->getId()); + } + + /** + * Tests that a candidate with no parts builds an empty message. + */ + public function testCandidateWithNoPartsBuildsEmptyMessage(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [], FinishReasonEnum::contentFilter())])); + + $candidate = $acc->build()->getCandidates()[0]; + $this->assertCount(0, $candidate->getMessage()->getParts()); + $this->assertTrue($candidate->getFinishReason()->is(FinishReasonEnum::contentFilter())); + } + + /** + * Tests that tool-call arguments reassemble from many small fragments. + */ + public function testToolCallArgumentsReassembleFromManySmallFragments(): void + { + $acc = $this->createAccumulator(); + $pieces = ['{', '"ci', 'ty":', ' "San ', 'Francisco"', '}']; + foreach ($pieces as $i => $piece) { + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [], null, [ + new ToolCallDelta(0, $i === 0 ? 'call_1' : null, $i === 0 ? 'get_weather' : null, $piece), + ]), + ])); + } + + $fc = $acc->build()->getCandidates()[0]->getMessage()->getParts()[0]->getFunctionCall(); + $this->assertSame('get_weather', $fc->getName()); + $this->assertSame(['city' => 'San Francisco'], $fc->getArgs()); + } + + /** + * Tests that parallel tool-call fragments interleaved across deltas stitch per slot. + */ + public function testParallelToolCallFragmentsInterleaveAcrossDeltas(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [], null, [ + new ToolCallDelta(0, 'a', 'fn_a', '{"x":'), + new ToolCallDelta(1, 'b', 'fn_b', '{"y":'), + ]), + ])); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [], null, [new ToolCallDelta(0, null, null, '1}')]), + ])); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [], null, [new ToolCallDelta(1, null, null, '2}')]), + ])); + + $parts = $acc->build()->getCandidates()[0]->getMessage()->getParts(); + $this->assertCount(2, $parts); + $this->assertSame('fn_a', $parts[0]->getFunctionCall()->getName()); + $this->assertSame(['x' => 1], $parts[0]->getFunctionCall()->getArgs()); + $this->assertSame('fn_b', $parts[1]->getFunctionCall()->getName()); + $this->assertSame(['y' => 2], $parts[1]->getFunctionCall()->getArgs()); + } + + /** + * Tests that interleaved reasoning and content concatenate per channel. + */ + public function testReasoningAndContentInterleaveAcrossDeltas(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [$this->createReasoningPart('think ')])])); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [$this->createContentPart('Hi ')])])); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [$this->createReasoningPart('more')])])); + $acc->add($this->createChunk(null, null, [], [new CandidateDelta(0, [$this->createContentPart('there')])])); + + $parts = $acc->build()->getCandidates()[0]->getMessage()->getParts(); + $this->assertCount(2, $parts); + $this->assertTrue($parts[0]->getChannel()->is(MessagePartChannelEnum::thought())); + $this->assertSame('think more', $parts[0]->getText()); + $this->assertTrue($parts[1]->getChannel()->is(MessagePartChannelEnum::content())); + $this->assertSame('Hi there', $parts[1]->getText()); + } + + /** + * Tests that parts are ordered thought-then-content even when content arrives first. + */ + public function testContentArrivingBeforeReasoningStillOrdersThoughtFirst(): void + { + $acc = $this->createAccumulator(); + $acc->add($this->createChunk(null, null, [], [ + new CandidateDelta(0, [$this->createContentPart('answer'), $this->createReasoningPart('reason')]), + ])); + + $parts = $acc->build()->getCandidates()[0]->getMessage()->getParts(); + $this->assertCount(2, $parts); + $this->assertTrue($parts[0]->getChannel()->is(MessagePartChannelEnum::thought())); + $this->assertSame('reason', $parts[0]->getText()); + $this->assertTrue($parts[1]->getChannel()->is(MessagePartChannelEnum::content())); + $this->assertSame('answer', $parts[1]->getText()); + } +} diff --git a/tests/unit/Results/StreamedGenerativeAiResultTest.php b/tests/unit/Results/StreamedGenerativeAiResultTest.php new file mode 100644 index 00000000..e8d0ceb3 --- /dev/null +++ b/tests/unit/Results/StreamedGenerativeAiResultTest.php @@ -0,0 +1,682 @@ + $chunks The chunk stream. + * @return StreamedGenerativeAiResult + */ + private function createHandle(Iterator $chunks): StreamedGenerativeAiResult + { + return new StreamedGenerativeAiResult( + $chunks, + $this->createTestProviderMetadata(), + $this->createTestModelMetadata() + ); + } + + /** + * Creates a handle over a fixed list of chunks. + * + * @param list $chunks The chunks. + * @return StreamedGenerativeAiResult + */ + private function createHandleFromChunks(array $chunks): StreamedGenerativeAiResult + { + return $this->createHandle(new ArrayIterator($chunks)); + } + + /** + * Creates a content chunk for candidate 0. + * + * @param string $text The content text delta. + * @param FinishReasonEnum|null $finishReason The finish reason, if any. + * @return GenerativeAiResultChunk + */ + private function createContentChunk(string $text, ?FinishReasonEnum $finishReason = null): GenerativeAiResultChunk + { + return new GenerativeAiResultChunk(null, null, [], [ + new CandidateDelta(0, [new MessagePart($text)], $finishReason), + ]); + } + + /** + * Creates a metadata-only chunk carrying token usage. + * + * @param TokenUsage $usage The usage. + * @return GenerativeAiResultChunk + */ + private function createUsageChunk(TokenUsage $usage): GenerativeAiResultChunk + { + return new GenerativeAiResultChunk(null, $usage, [], []); + } + + /** + * Yields the given chunks, then throws. + * + * @param list $chunks The chunks to yield before failing. + * @param \Throwable $error The error to throw after the chunks. + * @return Generator + */ + private function createFailingIterator(array $chunks, \Throwable $error): Generator + { + foreach ($chunks as $chunk) { + yield $chunk; + } + throw $error; + } + + /** + * Creates a single-use source that throws if it is read after exhaustion. + * + * @param list $chunks The chunks to yield once. + * @return Iterator + */ + private function createSingleUseSource(array $chunks): Iterator + { + return new class ($chunks) implements Iterator { + /** @var list */ + private array $chunks; + private int $pos = 0; + private bool $exhausted = false; + + /** @param list $chunks */ + public function __construct(array $chunks) + { + $this->chunks = array_values($chunks); + } + + public function current(): GenerativeAiResultChunk + { + $this->guardNotExhausted(); + return $this->chunks[$this->pos]; + } + + public function next(): void + { + $this->guardNotExhausted(); + $this->pos++; + if ($this->pos >= count($this->chunks)) { + $this->exhausted = true; + } + } + + public function key(): int + { + return $this->pos; + } + + public function valid(): bool + { + return $this->pos < count($this->chunks); + } + + public function rewind(): void + { + if ($this->pos !== 0 || $this->exhausted) { + throw new \LogicException('A single-use source cannot be rewound.'); + } + } + + private function guardNotExhausted(): void + { + if ($this->exhausted) { + throw new \LogicException('A single-use source cannot be read after exhaustion.'); + } + } + }; + } + + /** + * Tests that iterating yields all chunks in order. + */ + public function testIteratingYieldsAllChunksInOrder(): void + { + $a = $this->createContentChunk('a'); + $b = $this->createContentChunk('b'); + $c = $this->createContentChunk('c'); + + $collected = []; + foreach ($this->createHandleFromChunks([$a, $b, $c]) as $chunk) { + $collected[] = $chunk; + } + + $this->assertSame([$a, $b, $c], $collected); + } + + /** + * Tests that getFinalResult() assembles the result without iterating. + */ + public function testGetFinalResultAssemblesWithoutIterating(): void + { + $result = $this->createHandleFromChunks([ + $this->createContentChunk('Hel'), + $this->createContentChunk('lo', FinishReasonEnum::stop()), + $this->createUsageChunk(new TokenUsage(3, 5, 8)), + ])->getFinalResult(); + + $this->assertInstanceOf(GenerativeAiResult::class, $result); + $this->assertSame('Hello', $result->toText()); + $this->assertSame(8, $result->getTokenUsage()->getTotalTokens()); + $this->assertTrue($result->getCandidates()[0]->getFinishReason()->is(FinishReasonEnum::stop())); + } + + /** + * Tests that iterating then getFinalResult() returns the same assembled result. + */ + public function testIterateThenGetFinalResultIsConsistent(): void + { + $handle = $this->createHandleFromChunks([ + $this->createContentChunk('Hel'), + $this->createContentChunk('lo', FinishReasonEnum::stop()), + ]); + + foreach ($handle as $chunk) { + // Consume the stream but do nothing with the chunks. + } + + $this->assertSame('Hello', $handle->getFinalResult()->toText()); + } + + /** + * Tests that getFinalResult() returns the same instance on repeated calls. + */ + public function testGetFinalResultIsIdempotent(): void + { + $handle = $this->createHandleFromChunks([$this->createContentChunk('hi', FinishReasonEnum::stop())]); + + $first = $handle->getFinalResult(); + $second = $handle->getFinalResult(); + + $this->assertSame($first, $second); + } + + /** + * Tests that getFinalResult() drains the remainder after an early break. + */ + public function testGetFinalResultAfterEarlyBreakDrainsRemainder(): void + { + $handle = $this->createHandleFromChunks([ + $this->createContentChunk('a'), + $this->createContentChunk('b'), + $this->createContentChunk('c', FinishReasonEnum::stop()), + ]); + + foreach ($handle as $chunk) { + break; + } + + $this->assertSame('abc', $handle->getFinalResult()->toText()); + } + + /** + * Tests that the completion callback fires once with the result after a full iteration. + */ + public function testOnCompleteFiresOnceWithResultOnFullIteration(): void + { + $received = []; + $handle = $this->createHandleFromChunks([$this->createContentChunk('hi', FinishReasonEnum::stop())]); + $handle->onComplete(function (GenerativeAiResult $result) use (&$received) { + $received[] = $result; + }); + + foreach ($handle as $chunk) { + // Consume the stream but do nothing with the chunks. + } + + $this->assertCount(1, $received); + $this->assertSame($handle->getFinalResult(), $received[0]); + } + + /** + * Tests that the completion callback fires on getFinalResult() without iterating. + */ + public function testOnCompleteFiresOnGetFinalResult(): void + { + $count = 0; + $handle = $this->createHandleFromChunks([$this->createContentChunk('hi', FinishReasonEnum::stop())]); + $handle->onComplete(function () use (&$count) { + $count++; + }); + + $handle->getFinalResult(); + + $this->assertSame(1, $count); + } + + /** + * Tests that the completion callback fires only once across iteration and getFinalResult(). + */ + public function testOnCompleteFiresOnlyOnceAcrossIterateAndGetFinalResult(): void + { + $count = 0; + $handle = $this->createHandleFromChunks([$this->createContentChunk('hi', FinishReasonEnum::stop())]); + $handle->onComplete(function () use (&$count) { + $count++; + }); + + foreach ($handle as $chunk) { + // Consume the stream but do nothing with the chunks. + } + $handle->getFinalResult(); + + $this->assertSame(1, $count); + } + + /** + * Tests that multiple completion callbacks fire in registration order. + */ + public function testMultipleOnCompleteCallbacksFireInRegistrationOrder(): void + { + $order = []; + $handle = $this->createHandleFromChunks([$this->createContentChunk('hi', FinishReasonEnum::stop())]); + $handle->onComplete(function () use (&$order) { + $order[] = 'first'; + }); + $handle->onComplete(function () use (&$order) { + $order[] = 'second'; + }); + + $handle->getFinalResult(); + + $this->assertSame(['first', 'second'], $order); + } + + /** + * Tests that the completion callback does not fire on an early break. + */ + public function testOnCompleteNotFiredOnEarlyBreak(): void + { + $count = 0; + $handle = $this->createHandleFromChunks([ + $this->createContentChunk('a'), + $this->createContentChunk('b', FinishReasonEnum::stop()), + ]); + $handle->onComplete(function () use (&$count) { + $count++; + }); + + foreach ($handle as $chunk) { + break; + } + + $this->assertSame(0, $count); + } + + /** + * Tests that an empty stream throws on getFinalResult() and fires no completion callback. + */ + public function testEmptyStreamThrowsAndDoesNotFireOnComplete(): void + { + $count = 0; + $handle = $this->createHandleFromChunks([$this->createUsageChunk(new TokenUsage(1, 1, 2))]); + $handle->onComplete(function () use (&$count) { + $count++; + }); + + try { + $handle->getFinalResult(); + $this->fail('Expected RuntimeException for a stream with no candidates.'); + } catch (RuntimeException $e) { + $this->assertStringContainsString('no candidates', $e->getMessage()); + } + + $this->assertSame(0, $count); + } + + /** + * Tests that getFinalResult() after iterating an empty stream throws without re-reading the source. + */ + public function testGetFinalResultAfterIteratingEmptyStreamThrowsWithoutReReadingSource(): void + { + $count = 0; + $handle = $this->createHandle( + $this->createSingleUseSource([$this->createUsageChunk(new TokenUsage(1, 1, 2))]) + ); + $handle->onComplete(function () use (&$count) { + $count++; + }); + + foreach ($handle as $chunk) { + // Consume the stream but do nothing with the chunks. + } + + try { + $handle->getFinalResult(); + $this->fail('Expected RuntimeException for a stream with no candidates.'); + } catch (RuntimeException $e) { + $this->assertStringContainsString('no candidates', $e->getMessage()); + } + + $this->assertSame(0, $count); + } + + /** + * Tests that a stream without a finish reason defaults to stop. + */ + public function testPartialStreamWithoutFinishReasonDefaultsToStop(): void + { + $result = $this->createHandleFromChunks([$this->createContentChunk('partial')])->getFinalResult(); + + $this->assertSame('partial', $result->toText()); + $this->assertTrue($result->getCandidates()[0]->getFinishReason()->is(FinishReasonEnum::stop())); + } + + /** + * Tests that iteration yields all chunks produced before an error, then propagates it. + */ + public function testIterationYieldsChunksBeforeAnErrorThenPropagates(): void + { + $a = $this->createContentChunk('a'); + $b = $this->createContentChunk('b'); + $handle = $this->createHandle( + $this->createFailingIterator([$a, $b], new RuntimeException('stream failed')) + ); + + $collected = []; + $thrown = null; + try { + foreach ($handle as $chunk) { + $collected[] = $chunk; + } + } catch (RuntimeException $e) { + $thrown = $e; + } + + $this->assertSame([$a, $b], $collected); + $this->assertNotNull($thrown); + $this->assertSame('stream failed', $thrown->getMessage()); + } + + /** + * Tests that getFinalResult() propagates an error raised while draining. + */ + public function testGetFinalResultPropagatesStreamError(): void + { + $handle = $this->createHandle( + $this->createFailingIterator([$this->createContentChunk('a')], new RuntimeException('stream failed')) + ); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('stream failed'); + $handle->getFinalResult(); + } + + /** + * Tests that the completion callback does not fire when the stream errors. + */ + public function testOnCompleteNotFiredWhenStreamErrors(): void + { + $count = 0; + $handle = $this->createHandle( + $this->createFailingIterator([$this->createContentChunk('a')], new RuntimeException('stream failed')) + ); + $handle->onComplete(function () use (&$count) { + $count++; + }); + + try { + foreach ($handle as $chunk) { + // Consume the stream but do nothing with the chunks. + } + } catch (RuntimeException $e) { + // Expected error here. + } + + $this->assertSame(0, $count); + } + + /** + * Tests that re-iterating a consumed stream throws. + */ + public function testReiteratingAConsumedStreamThrows(): void + { + $handle = $this->createHandleFromChunks([$this->createContentChunk('hi', FinishReasonEnum::stop())]); + + foreach ($handle as $chunk) { + // Consume the stream but do nothing with the chunks. + } + + $this->expectException(RuntimeException::class); + foreach ($handle as $chunk) { + } + } + + /** + * Tests that iterating after getFinalResult() throws. + */ + public function testIteratingAfterGetFinalResultThrows(): void + { + $handle = $this->createHandleFromChunks([$this->createContentChunk('hi', FinishReasonEnum::stop())]); + $handle->getFinalResult(); + + $this->expectException(RuntimeException::class); + foreach ($handle as $chunk) { + } + } + + /** + * Tests that the start callback fires exactly once when consumption begins. + */ + public function testOnStartFiresOnceWhenConsumptionBegins(): void + { + $count = 0; + $handle = $this->createHandleFromChunks([$this->createContentChunk('hi', FinishReasonEnum::stop())]); + $handle->onStart(function () use (&$count) { + $count++; + }); + + foreach ($handle as $chunk) { + } + + $this->assertSame(1, $count); + } + + /** + * Tests that the start callback does not fire until the stream is consumed. + */ + public function testOnStartNotFiredUntilConsumed(): void + { + $count = 0; + $handle = $this->createHandleFromChunks([$this->createContentChunk('hi', FinishReasonEnum::stop())]); + $handle->onStart(function () use (&$count) { + $count++; + }); + + $this->assertSame(0, $count); + + $handle->getFinalResult(); + + $this->assertSame(1, $count); + } + + /** + * Tests that the start callback fires even when the caller breaks early. + */ + public function testOnStartFiresOnEarlyBreak(): void + { + $count = 0; + $handle = $this->createHandleFromChunks([ + $this->createContentChunk('a'), + $this->createContentChunk('b', FinishReasonEnum::stop()), + ]); + $handle->onStart(function () use (&$count) { + $count++; + }); + + foreach ($handle as $chunk) { + break; + } + + $this->assertSame(1, $count); + } + + /** + * Tests that the error callback fires once with the error when iterating fails. + */ + public function testOnErrorFiresOnceWhenIterationFails(): void + { + $received = []; + $handle = $this->createHandle( + $this->createFailingIterator([$this->createContentChunk('a')], new RuntimeException('stream failed')) + ); + $handle->onError(function (\Throwable $e) use (&$received) { + $received[] = $e; + }); + + try { + foreach ($handle as $chunk) { + } + } catch (RuntimeException $e) { + } + + $this->assertCount(1, $received); + $this->assertSame('stream failed', $received[0]->getMessage()); + } + + /** + * Tests that the error callback fires when getFinalResult() drains a failing stream. + */ + public function testOnErrorFiresWhenGetFinalResultFails(): void + { + $count = 0; + $handle = $this->createHandle( + $this->createFailingIterator([$this->createContentChunk('a')], new RuntimeException('stream failed')) + ); + $handle->onError(function () use (&$count) { + $count++; + }); + + try { + $handle->getFinalResult(); + } catch (RuntimeException $e) { + } + + $this->assertSame(1, $count); + } + + /** + * Tests that the error callback does not fire on a successful stream. + */ + public function testOnErrorNotFiredOnSuccess(): void + { + $count = 0; + $handle = $this->createHandleFromChunks([$this->createContentChunk('hi', FinishReasonEnum::stop())]); + $handle->onError(function () use (&$count) { + $count++; + }); + + $handle->getFinalResult(); + + $this->assertSame(0, $count); + } + + /** + * Tests that the error callback does not fire when the caller breaks early. + */ + public function testOnErrorNotFiredOnEarlyBreak(): void + { + $count = 0; + $handle = $this->createHandleFromChunks([ + $this->createContentChunk('a'), + $this->createContentChunk('b', FinishReasonEnum::stop()), + ]); + $handle->onError(function () use (&$count) { + $count++; + }); + + foreach ($handle as $chunk) { + break; + } + + $this->assertSame(0, $count); + } + + /** + * Tests that getFinalResult() after a failure re-throws the original stream error. + */ + public function testGetFinalResultAfterFailureRethrowsOriginalError(): void + { + $handle = $this->createHandle( + $this->createFailingIterator([$this->createContentChunk('a')], new RuntimeException('stream failed')) + ); + + try { + foreach ($handle as $chunk) { + } + } catch (RuntimeException $e) { + } + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('stream failed'); + $handle->getFinalResult(); + } + + /** + * Tests that the error callback fires at most once across iteration and getFinalResult(). + */ + public function testOnErrorFiresAtMostOnceAcrossIterationAndGetFinalResult(): void + { + $count = 0; + $handle = $this->createHandle( + $this->createFailingIterator([$this->createContentChunk('a')], new RuntimeException('stream failed')) + ); + $handle->onError(function () use (&$count) { + $count++; + }); + + try { + foreach ($handle as $chunk) { + } + } catch (RuntimeException $e) { + } + + try { + $handle->getFinalResult(); + } catch (RuntimeException $e) { + } + + $this->assertSame(1, $count); + } +} diff --git a/tests/unit/Results/ValueObjects/CandidateDeltaTest.php b/tests/unit/Results/ValueObjects/CandidateDeltaTest.php new file mode 100644 index 00000000..ac665b8a --- /dev/null +++ b/tests/unit/Results/ValueObjects/CandidateDeltaTest.php @@ -0,0 +1,124 @@ +createContentPart('Hi')]; + $toolCallDeltas = [new ToolCallDelta(0, 'call_1', 'fn', '{}')]; + $delta = new CandidateDelta(1, $parts, FinishReasonEnum::stop(), $toolCallDeltas); + + $this->assertSame(1, $delta->getIndex()); + $this->assertSame($parts, $delta->getParts()); + $this->assertTrue($delta->getFinishReason()->is(FinishReasonEnum::stop())); + $this->assertSame($toolCallDeltas, $delta->getToolCallDeltas()); + } + + /** + * Tests that the optional fields default to empty values and a null finish reason. + */ + public function testDefaults(): void + { + $delta = new CandidateDelta(0); + + $this->assertSame([], $delta->getParts()); + $this->assertNull($delta->getFinishReason()); + $this->assertSame([], $delta->getToolCallDeltas()); + } + + /** + * Tests that getDeltaText concatenates only the content-channel parts. + */ + public function testGetDeltaTextConcatenatesContentChannelOnly(): void + { + $delta = new CandidateDelta(0, [$this->createContentPart('Hel'), $this->createContentPart('lo')]); + + $this->assertSame('Hello', $delta->getDeltaText()); + } + + /** + * Tests that getReasoningDeltaText concatenates only the thought-channel parts. + */ + public function testGetReasoningDeltaTextConcatenatesThoughtChannelOnly(): void + { + $delta = new CandidateDelta(0, [$this->createReasoningPart('Think'), $this->createReasoningPart('ing')]); + + $this->assertSame('Thinking', $delta->getReasoningDeltaText()); + } + + /** + * Tests that the content and reasoning channels do not bleed into each other. + */ + public function testContentAndReasoningChannelsDoNotBleed(): void + { + $delta = new CandidateDelta(0, [$this->createReasoningPart('reason'), $this->createContentPart('answer')]); + + $this->assertSame('answer', $delta->getDeltaText()); + $this->assertSame('reason', $delta->getReasoningDeltaText()); + } + + /** + * Tests that a non-text content part is ignored when reading the delta text. + */ + public function testNonTextPartIsIgnoredInDeltaText(): void + { + $delta = new CandidateDelta(0, [ + $this->createContentPart('a'), + new MessagePart(new FunctionCall('id', 'fn', [])), + $this->createContentPart('b'), + ]); + + $this->assertSame('ab', $delta->getDeltaText()); + } + + /** + * Tests that a candidate with no parts yields empty delta text. + */ + public function testEmptyDeltaTextWhenNoParts(): void + { + $delta = new CandidateDelta(0); + + $this->assertSame('', $delta->getDeltaText()); + $this->assertSame('', $delta->getReasoningDeltaText()); + } +} diff --git a/tests/unit/Results/ValueObjects/GenerativeAiResultChunkTest.php b/tests/unit/Results/ValueObjects/GenerativeAiResultChunkTest.php new file mode 100644 index 00000000..9487acf4 --- /dev/null +++ b/tests/unit/Results/ValueObjects/GenerativeAiResultChunkTest.php @@ -0,0 +1,140 @@ +createContentPart('Hi')])]; + $chunk = new GenerativeAiResultChunk('chatcmpl-1', $usage, ['model' => 'x'], $candidateDeltas); + + $this->assertSame('chatcmpl-1', $chunk->getId()); + $this->assertSame($usage, $chunk->getTokenUsage()); + $this->assertSame(['model' => 'x'], $chunk->getAdditionalData()); + $this->assertSame($candidateDeltas, $chunk->getCandidateDeltas()); + } + + /** + * Tests that the optional fields default to null and empty values. + */ + public function testDefaults(): void + { + $chunk = new GenerativeAiResultChunk(); + + $this->assertNull($chunk->getId()); + $this->assertNull($chunk->getTokenUsage()); + $this->assertSame([], $chunk->getAdditionalData()); + $this->assertSame([], $chunk->getCandidateDeltas()); + } + + /** + * Tests that getDeltaText returns the primary candidate's content by default. + */ + public function testGetDeltaTextReturnsPrimaryCandidateByDefault(): void + { + $chunk = new GenerativeAiResultChunk(null, null, [], [ + new CandidateDelta(0, [$this->createContentPart('A')]), + new CandidateDelta(1, [$this->createContentPart('B')]), + ]); + + $this->assertSame('A', $chunk->getDeltaText()); + $this->assertSame('B', $chunk->getDeltaText(1)); + $this->assertSame('', $chunk->getDeltaText(99)); + } + + /** + * Tests that getReasoningDeltaText returns the primary candidate's reasoning by default. + */ + public function testGetReasoningDeltaTextReturnsPrimaryCandidateByDefault(): void + { + $chunk = new GenerativeAiResultChunk(null, null, [], [ + new CandidateDelta(0, [$this->createReasoningPart('think')]), + new CandidateDelta(1, [$this->createReasoningPart('more')]), + ]); + + $this->assertSame('think', $chunk->getReasoningDeltaText()); + $this->assertSame('more', $chunk->getReasoningDeltaText(1)); + $this->assertSame('', $chunk->getReasoningDeltaText(99)); + } + + /** + * Tests that getToolCallDeltas flattens the tool calls across candidate deltas. + */ + public function testGetToolCallDeltasFlattensAcrossCandidateDeltas(): void + { + $chunk = new GenerativeAiResultChunk(null, null, [], [ + new CandidateDelta(0, [], null, [new ToolCallDelta(0, 'a', 'fn_a', '')]), + new CandidateDelta(1, [], null, [new ToolCallDelta(0, 'b', 'fn_b', '')]), + ]); + + $deltas = $chunk->getToolCallDeltas(); + $this->assertCount(2, $deltas); + $this->assertSame('a', $deltas[0]->getId()); + $this->assertSame('b', $deltas[1]->getId()); + } + + /** + * Tests that a metadata-only chunk has empty convenience accessors. + */ + public function testConveniencesEmptyWhenMetadataOnly(): void + { + $chunk = new GenerativeAiResultChunk('id', new TokenUsage(1, 1, 2), ['model' => 'x'], []); + + $this->assertSame('', $chunk->getDeltaText()); + $this->assertSame('', $chunk->getReasoningDeltaText()); + $this->assertSame([], $chunk->getToolCallDeltas()); + } + + /** + * Tests that the content and reasoning channels do not bleed at the chunk level. + */ + public function testContentAndReasoningDoNotBleed(): void + { + $chunk = new GenerativeAiResultChunk(null, null, [], [ + new CandidateDelta(0, [$this->createReasoningPart('reason'), $this->createContentPart('answer')]), + ]); + + $this->assertSame('answer', $chunk->getDeltaText()); + $this->assertSame('reason', $chunk->getReasoningDeltaText()); + } +} diff --git a/tests/unit/Results/ValueObjects/ToolCallDeltaTest.php b/tests/unit/Results/ValueObjects/ToolCallDeltaTest.php new file mode 100644 index 00000000..43a21f3a --- /dev/null +++ b/tests/unit/Results/ValueObjects/ToolCallDeltaTest.php @@ -0,0 +1,50 @@ +assertSame(2, $delta->getIndex()); + $this->assertSame('call_1', $delta->getId()); + $this->assertSame('get_weather', $delta->getFunctionName()); + $this->assertSame('{"city":"SF"}', $delta->getArgumentsFragment()); + } + + /** + * Tests that the optional fields default to null and an empty fragment. + */ + public function testDefaultsWhenOnlyIndexProvided(): void + { + $delta = new ToolCallDelta(0); + + $this->assertSame(0, $delta->getIndex()); + $this->assertNull($delta->getId()); + $this->assertNull($delta->getFunctionName()); + $this->assertSame('', $delta->getArgumentsFragment()); + } + + /** + * Tests that a null index is preserved. + */ + public function testNullIndexIsPreserved(): void + { + $delta = new ToolCallDelta(null, 'call_1', 'fn', '{}'); + + $this->assertNull($delta->getIndex()); + } +}