1- from typing import TYPE_CHECKING , Any , Dict , Iterator , List , Optional , Union
1+ from typing import (
2+ TYPE_CHECKING ,
3+ Any ,
4+ AsyncIterator ,
5+ Dict ,
6+ Iterator ,
7+ List ,
8+ Optional ,
9+ Union ,
10+ )
211
312from typing_extensions import Unpack
413
@@ -59,7 +68,7 @@ async def async_run(
5968 ref : Union ["Model" , "Version" , "ModelVersionIdentifier" , str ],
6069 input : Optional [Dict [str , Any ]] = None ,
6170 ** params : Unpack ["Predictions.CreatePredictionParams" ],
62- ) -> Union [Any , Iterator [Any ]]: # noqa: ANN401
71+ ) -> Union [Any , AsyncIterator [Any ]]: # noqa: ANN401
6372 """
6473 Run a model and wait for its output asynchronously.
6574 """
@@ -82,7 +91,7 @@ async def async_run(
8291 if not version and (owner and name and version_id ):
8392 version = await Versions (client , model = (owner , name )).async_get (version_id )
8493
85- if version and (iterator := _make_output_iterator (version , prediction )):
94+ if version and (iterator := _make_async_output_iterator (version , prediction )):
8695 return iterator
8796
8897 await prediction .async_wait ()
@@ -93,17 +102,32 @@ async def async_run(
93102 return prediction .output
94103
95104
96- def _make_output_iterator (
97- version : Version , prediction : Prediction
98- ) -> Optional [Iterator [Any ]]:
105+ def _has_output_iterator_array_type (version : Version ) -> bool :
99106 schema = make_schema_backwards_compatible (
100107 version .openapi_schema , version .cog_version
101108 )
102- output = schema ["components" ]["schemas" ]["Output" ]
103- if output .get ("type" ) == "array" and output .get ("x-cog-array-type" ) == "iterator" :
109+ output = schema .get ("components" , {}).get ("schemas" , {}).get ("Output" , {})
110+ return (
111+ output .get ("type" ) == "array" and output .get ("x-cog-array-type" ) == "iterator"
112+ )
113+
114+
115+ def _make_output_iterator (
116+ version : Version , prediction : Prediction
117+ ) -> Optional [Iterator [Any ]]:
118+ if _has_output_iterator_array_type (version ):
104119 return prediction .output_iterator ()
105120
106121 return None
107122
108123
124+ def _make_async_output_iterator (
125+ version : Version , prediction : Prediction
126+ ) -> Optional [AsyncIterator [Any ]]:
127+ if _has_output_iterator_array_type (version ):
128+ return prediction .async_output_iterator ()
129+
130+ return None
131+
132+
109133__all__ : List = []
0 commit comments