// NOTE(review): This is a rendered git-diff hunk (old/new line numbers are fused into each
// line; `-` lines are removed, `+` lines are added). The change replaces a stub streaming
// implementation — which yielded a single empty snapshot — with real incremental streaming
// backed by MLX token generation. The enclosing method's declaration starts above this hunk;
// presumably it is the `respond`/stream entry point of MLXLanguageModel — confirm in full file.
@@ -167,17 +167,74 @@ import Foundation
167167 includeSchemaInPrompt: Bool ,
168168 options: GenerationOptions
169169 ) -> sending LanguageModelSession. ResponseStream < Content > where Content: Generable {
170- // For now, only String is supported
// Only String content is supported; any other Generable type crashes here.
// NOTE(review): fatalError on a caller-supplied type parameter is harsh — consider
// surfacing this as a thrown/streamed error instead of a process crash.
171170 guard type == String . self else {
172171 fatalError ( " MLXLanguageModel only supports generating String content " )
173172 }
174173 
// Removed: the old stub returned one empty snapshot instead of streaming.
175- // Streaming API in AnyLanguageModel currently yields once; return an empty snapshot
176- let empty = " "
177- return LanguageModelSession . ResponseStream (
178- content: empty as! Content ,
179- rawContent: GeneratedContent ( empty)
180- )
// Copy model-location config into locals before the @Sendable task below —
// presumably to avoid capturing self across the concurrency boundary; confirm
// whether self is Sendable in the full file.
174+ let modelId = self . modelId
175+ let hub = self . hub
176+ let directory = self . directory
177+
// Bridge MLX's generation into the snapshot stream the session API expects.
178+ let stream : AsyncThrowingStream < LanguageModelSession . ResponseStream < Content > . Snapshot , any Error > = . init { continuation in
179+ let task = Task { @Sendable in
180+ do {
// Model resolution precedence: explicit local directory, then a configured
// hub, then loading by id alone.
181+ let context : ModelContext
182+ if let directory {
183+ context = try await loadModel ( directory: directory)
184+ } else if let hub {
185+ context = try await loadModel ( hub: hub, id: modelId)
186+ } else {
187+ context = try await loadModel ( id: modelId)
188+ }
189+
// Map AnyLanguageModel GenerationOptions onto MLX GenerateParameters.
190+ let generateParameters = toGenerateParameters ( options)
191+
// Build the chat transcript: optional system message from session
// instructions, then the user prompt segments (falling back to the raw
// prompt text when no segments are extracted).
192+ var chat : [ MLXLMCommon . Chat . Message ] = [ ]
193+
194+ if let instructionSegments = extractInstructionSegments ( from: session) {
195+ chat. append ( convertSegmentsToMLXSystemMessage ( instructionSegments) )
196+ }
197+
198+ let userSegments = extractPromptSegments ( from: session, fallbackText: prompt. description)
199+ chat. append ( convertSegmentsToMLXMessage ( userSegments) )
200+
// NOTE(review): image processing is hard-coded to a 512x512 resize and
// tools are disabled — confirm 512x512 is appropriate for every supported
// VLM, and consider deriving it from the model/options.
201+ let userInput = MLXLMCommon . UserInput (
202+ chat: chat,
203+ processing: . init( resize: . init( width: 512 , height: 512 ) ) ,
204+ tools: nil
205+ )
206+ let lmInput = try await context. processor. prepare ( input: userInput)
207+
208+ let mlxStream = try MLXLMCommon . generate (
209+ input: lmInput,
210+ parameters: generateParameters,
211+ context: context
212+ )
213+
// Accumulate chunks and yield a cumulative snapshot per chunk, so each
// snapshot carries the full text generated so far (not just the delta).
214+ var accumulatedText = " "
// Cooperative cancellation: onTermination below cancels this task, and
// this check stops consuming the MLX stream once cancelled.
215+ for await item in mlxStream {
216+ if Task . isCancelled { break }
217+
218+ switch item {
219+ case . chunk( let text) :
220+ accumulatedText += text
221+ let raw = GeneratedContent ( accumulatedText)
// NOTE(review): `as! Content` is only safe because of the
// `type == String.self` guard at the top of the method; if more
// content types are added later this force cast must be revisited.
222+ let content : Content . PartiallyGenerated = ( accumulatedText as! Content ) . asPartiallyGenerated ( )
223+ continuation. yield ( . init( content: content, rawContent: raw) )
// Generation stats and tool-call events are intentionally ignored
// for now (tools are passed as nil above).
224+ case . info, . toolCall:
225+ break
226+ }
227+ }
228+
229+ continuation. finish ( )
230+ } catch {
// Model-loading or generation failures are propagated to the stream
// consumer rather than swallowed.
231+ continuation. finish ( throwing: error)
232+ }
233+ }
// Tie consumer-side termination (cancellation or early break) to the
// producer task so generation stops promptly.
234+ continuation. onTermination = { _ in task. cancel ( ) }
235+ }
236+
237+ return LanguageModelSession . ResponseStream ( stream: stream)
181238 }
182239 }
183240
0 commit comments