Skip to content

Commit 8a15c4d

Browse files
XMLHexagramm and mattt authored
Implement streaming response for MLXLanguageModel (#64)
* feat: Impl stream Response for MLXLanguageModel * Serialize MLX tests * Add minimal test coverage for streaming response --------- Co-authored-by: Mattt Zmuda <mattt@me.com>
1 parent 9700936 commit 8a15c4d

2 files changed

Lines changed: 78 additions & 8 deletions

File tree

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,17 +167,74 @@ import Foundation
167167
includeSchemaInPrompt: Bool,
168168
options: GenerationOptions
169169
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
170-
// For now, only String is supported
171170
guard type == String.self else {
172171
fatalError("MLXLanguageModel only supports generating String content")
173172
}
174173

175-
// Streaming API in AnyLanguageModel currently yields once; return an empty snapshot
176-
let empty = ""
177-
return LanguageModelSession.ResponseStream(
178-
content: empty as! Content,
179-
rawContent: GeneratedContent(empty)
180-
)
174+
let modelId = self.modelId
175+
let hub = self.hub
176+
let directory = self.directory
177+
178+
let stream: AsyncThrowingStream<LanguageModelSession.ResponseStream<Content>.Snapshot, any Error> = .init { continuation in
179+
let task = Task { @Sendable in
180+
do {
181+
let context: ModelContext
182+
if let directory {
183+
context = try await loadModel(directory: directory)
184+
} else if let hub {
185+
context = try await loadModel(hub: hub, id: modelId)
186+
} else {
187+
context = try await loadModel(id: modelId)
188+
}
189+
190+
let generateParameters = toGenerateParameters(options)
191+
192+
var chat: [MLXLMCommon.Chat.Message] = []
193+
194+
if let instructionSegments = extractInstructionSegments(from: session) {
195+
chat.append(convertSegmentsToMLXSystemMessage(instructionSegments))
196+
}
197+
198+
let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description)
199+
chat.append(convertSegmentsToMLXMessage(userSegments))
200+
201+
let userInput = MLXLMCommon.UserInput(
202+
chat: chat,
203+
processing: .init(resize: .init(width: 512, height: 512)),
204+
tools: nil
205+
)
206+
let lmInput = try await context.processor.prepare(input: userInput)
207+
208+
let mlxStream = try MLXLMCommon.generate(
209+
input: lmInput,
210+
parameters: generateParameters,
211+
context: context
212+
)
213+
214+
var accumulatedText = ""
215+
for await item in mlxStream {
216+
if Task.isCancelled { break }
217+
218+
switch item {
219+
case .chunk(let text):
220+
accumulatedText += text
221+
let raw = GeneratedContent(accumulatedText)
222+
let content: Content.PartiallyGenerated = (accumulatedText as! Content).asPartiallyGenerated()
223+
continuation.yield(.init(content: content, rawContent: raw))
224+
case .info, .toolCall:
225+
break
226+
}
227+
}
228+
229+
continuation.finish()
230+
} catch {
231+
continuation.finish(throwing: error)
232+
}
233+
}
234+
continuation.onTermination = { _ in task.cancel() }
235+
}
236+
237+
return LanguageModelSession.ResponseStream(stream: stream)
181238
}
182239
}
183240

Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import Testing
2929
return false
3030
}()
3131

32-
@Suite("MLXLanguageModel", .enabled(if: shouldRunMLXTests))
32+
@Suite("MLXLanguageModel", .enabled(if: shouldRunMLXTests), .serialized)
3333
struct MLXLanguageModelTests {
3434
// Qwen3-0.6B is a small model that supports tool calling
3535
let model = MLXLanguageModel(modelId: "mlx-community/Qwen3-0.6B-4bit")
@@ -42,6 +42,19 @@ import Testing
4242
#expect(!response.content.isEmpty)
4343
}
4444

45+
@Test func streamingResponse() async throws {
46+
let session = LanguageModelSession(model: model)
47+
48+
let stream = session.streamResponse(to: "Count to 5")
49+
var chunks: [String] = []
50+
51+
for try await response in stream {
52+
chunks.append(response.content)
53+
}
54+
55+
#expect(!chunks.isEmpty)
56+
}
57+
4558
@Test func withGenerationOptions() async throws {
4659
let session = LanguageModelSession(model: model)
4760

0 commit comments

Comments (0)