From d65dc4eb5f6e154a27e01a3700165bb64c53dff6 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 7 Feb 2026 20:54:01 -0500 Subject: [PATCH 01/11] add dummy tools --- effectful/handlers/llm/completions.py | 54 ++++++++++++++++++++++++--- effectful/handlers/llm/template.py | 32 ++++++++++++++-- 2 files changed, 77 insertions(+), 9 deletions(-) diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index e7616fbfd..f2fd85f22 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -260,15 +260,19 @@ def call_assistant[T, U]( includes the raw assistant message for retry handling. """ tool_specs = {k: _function_model(t) for k, t in tools.items()} - response_model = pydantic.create_model( - "Response", value=response_format.enc, __config__={"extra": "forbid"} + response_model = ( + response_format.enc + if issubclass(response_format.enc, pydantic.BaseModel) + else pydantic.create_model( + "Response", value=response_format.enc, __config__={"extra": "forbid"} + ) ) messages = list(get_message_sequence().values()) response: litellm.types.utils.ModelResponse = completion( model, messages=list(messages), - response_format=response_model, + response_format=response_model if response_format.enc is not str else None, tools=list(tool_specs.values()), **kwargs, ) @@ -291,7 +295,7 @@ def call_assistant[T, U]( tool_calls.append(decoded_tool_call) result = None - if not tool_calls: + if not tool_calls and response_format.enc is not str: # return response serialized_result = message.get("content") or message.get("reasoning_content") assert isinstance(serialized_result, str), ( @@ -299,9 +303,18 @@ def call_assistant[T, U]( ) try: raw_result = response_model.model_validate_json(serialized_result) - result = response_format.decode(raw_result.value) # type: ignore + result = response_format.decode( + raw_result.value + if not issubclass(response_format.enc, pydantic.BaseModel) + else raw_result + ) # type: ignore except (pydantic.ValidationError, TypeError, ValueError, SyntaxError) as e: raise ResultDecodingError(e, raw_message=raw_message) from e + elif not tool_calls and response_format.enc is str: + # if expecting a string result, return the raw content as the result + content = message.get("content") or message.get("reasoning_content") + assert isinstance(content, str), "Expected content to be a string" + result = content return (raw_message, tool_calls, result) @@ -387,7 +400,36 @@ def flush_text() -> None: @Operation.define def call_system(template: Template) -> collections.abc.Sequence[Message]: """Get system instruction message(s) to prepend to all LLM prompts.""" - return () + + assert inspect.getdoc(type(template)) is not None + + system_prompt = inspect.cleandoc(f""" + You are responsible for implementing the `Template` '{template.__name__}' defined in the module source code below. + + First, as background, here is the class-level documentation for the `Template` class:: + + {inspect.getdoc(type(template))} + """) + + try: + system_prompt += inspect.cleandoc(f""" + Here is the source code of the module defining the `Template` instance '{template.__name__}':: + + {inspect.getsource(inspect.getmodule(template))} + """) + except (TypeError, OSError): + system_prompt += inspect.cleandoc(f""" + The source code for the module defining '{template.__name__}' is not available. + Instead, here are the signature and docstring of '{template.__name__}':: + + {template.__name__} :: {template.__signature__.format()} + + {inspect.cleandoc(template.__prompt_template__)} + """) + + msg = _make_message(dict(role="system", content=system_prompt)) + append_message(msg) + return (msg,) class RetryLLMHandler(ObjectInterpretation): diff --git a/effectful/handlers/llm/template.py b/effectful/handlers/llm/template.py index 1a74b1005..f01018105 100644 --- a/effectful/handlers/llm/template.py +++ b/effectful/handlers/llm/template.py @@ -110,6 +110,20 @@ class _BoundInstance[T]: instance: T +def _make_context_tool[T](name: str, value: T) -> Tool[[], T]: + """Create a synthetic read-only Tool for a lexical variable.""" + from effectful.internals.unification import nested_type + + def reader(): + return value + + reader.__name__ = name + reader.__doc__ = f"Read the value of lexical variable `{name}`" + reader.__annotations__ = {"return": nested_type(value).value} + + return Tool.define(reader) + + class Template[**P, T](Tool[P, T]): """A :class:`Template` is a function that is implemented by a large language model. @@ -187,14 +201,14 @@ def tools(self) -> Mapping[str, Tool]: continue # Collect tools in context - if isinstance(obj, Tool): + elif isinstance(obj, Tool): result[name] = obj - if isinstance(obj, staticmethod) and isinstance(obj.__func__, Tool): + elif isinstance(obj, staticmethod) and isinstance(obj.__func__, Tool): result[name] = obj.__func__ # Collect tools as methods on any bound instances - if isinstance(obj, _BoundInstance): + elif isinstance(obj, _BoundInstance): for instance_name in obj.instance.__dir__(): if instance_name.startswith(INSTANCE_OP_PREFIX): continue @@ -202,6 +216,18 @@ def tools(self) -> Mapping[str, Tool]: if isinstance(instance_obj, Tool): result[instance_name] = instance_obj + # Make tools for lexical variables + elif not ( + name.startswith("__") + or isinstance(obj, Operation) + or inspect.isclass(obj) + or inspect.isbuiltin(obj) + or inspect.ismodule(obj) + or inspect.isroutine(obj) + or inspect.isabstract(obj) + ): + result[name] = _make_context_tool(name, obj) + return result def __get__[S](self, instance: S | None, owner: type[S] | None = None): From fa3fb6bca0dfe460d719f24295d5f803e42a230e Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 8 Feb 2026 02:59:44 -0500 Subject: [PATCH 02/11] Add Agent class --- docs/source/agent.py | 47 ----- docs/source/agent_example.rst | 18 -- docs/source/index.rst | 1 - effectful/handlers/llm/__init__.py | 4 +- effectful/handlers/llm/template.py | 70 ++++++- tests/test_handlers_llm_agent.py | 315 +++++++++++++++++++++++++++++ 6 files changed, 384 insertions(+), 71 deletions(-) delete mode 100644 docs/source/agent.py delete mode 100644 docs/source/agent_example.rst create mode 100644 tests/test_handlers_llm_agent.py diff --git a/docs/source/agent.py b/docs/source/agent.py deleted file mode 100644 index 77d526534..000000000 --- a/docs/source/agent.py +++ /dev/null @@ -1,47 +0,0 @@ -import functools -from collections import OrderedDict - -from effectful.handlers.llm import Template -from effectful.handlers.llm.completions import ( - LiteLLMProvider, - Message, - get_message_sequence, -) -from effectful.ops.semantics import handler -from effectful.ops.types import NotHandled - - -class Agent: - __history__: OrderedDict[str, Message] - - def __init__(self): - self.__history__ = OrderedDict() # persist the list of messages - - def __init_subclass__(cls): - for method_name in dir(cls): - template = getattr(cls, method_name) - if not isinstance(template, Template): - continue - - @functools.wraps(template) - def wrapper(self, *args, **kwargs): - with handler({get_message_sequence: lambda: self.__history__}): - return template(self, *args, **kwargs) - - setattr(cls, method_name, wrapper) - - -if __name__ == "__main__": - - class ChatBot(Agent): - @Template.define - def send(self, user_input: str) -> str: - """User writes: {user_input}""" - raise NotHandled - - provider = LiteLLMProvider() - chatbot = ChatBot() - - with handler(provider): - print(chatbot.send("Hi!, how are you? I am in france.")) - print(chatbot.send("Remind me again, where am I?")) diff --git a/docs/source/agent_example.rst b/docs/source/agent_example.rst deleted file mode 100644 index a9993c568..000000000 --- a/docs/source/agent_example.rst +++ /dev/null @@ -1,18 +0,0 @@ -Contextual LLM Agents -====================== -Here we give an example of using effectful to implement chatbot-style context-aware LLM agents. - -In the code below, we define a helper class :class:`Agent` which wraps its -subclasses' template operations in a wrapper that stores and persists -the history of prior interactions with the LLM: - - :func:`_format_model_input` wraps every prompt sent to the LLM and - stashes the generated API message into a state variable. - - :func:`_compute_response` wraps the response from the LLM provider and - stashes the returned message into the state. - -Using this we can construct an agent which remembers the context of -the conversation: - -.. literalinclude:: ./agent.py - :language: python - diff --git a/docs/source/index.rst b/docs/source/index.rst index 2a5135735..33e7d3cf0 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -18,7 +18,6 @@ Table of Contents lambda_example semi_ring_example beam_search_example - agent_example .. toctree:: :maxdepth: 2 diff --git a/effectful/handlers/llm/__init__.py b/effectful/handlers/llm/__init__.py index a87b481d6..cdda93479 100644 --- a/effectful/handlers/llm/__init__.py +++ b/effectful/handlers/llm/__init__.py @@ -1,3 +1,3 @@ -from .template import Template, Tool +from .template import Agent, Template, Tool -__all__ = ["Template", "Tool"] +__all__ = ["Agent", "Template", "Tool"] diff --git a/effectful/handlers/llm/template.py b/effectful/handlers/llm/template.py index f01018105..2dfab38f7 100644 --- a/effectful/handlers/llm/template.py +++ b/effectful/handlers/llm/template.py @@ -1,11 +1,14 @@ +import abc +import collections +import functools import inspect import types import typing -from collections import ChainMap from collections.abc import Callable, Mapping, MutableMapping from dataclasses import dataclass from typing import Annotated, Any +from effectful.ops.semantics import handler from effectful.ops.types import INSTANCE_OP_PREFIX, Annotation, Operation @@ -183,7 +186,7 @@ class Template[**P, T](Tool[P, T]): """ - __context__: ChainMap[str, Any] + __context__: collections.ChainMap[str, Any] @property def __prompt_template__(self) -> str: @@ -283,7 +286,7 @@ def define[**Q, V]( frame = frame.f_back contexts.append(globals_proxy) - context: ChainMap[str, Any] = ChainMap( + context: collections.ChainMap[str, Any] = collections.ChainMap( *typing.cast(list[MutableMapping[str, Any]], contexts) ) @@ -291,3 +294,64 @@ def define[**Q, V]( op.__context__ = context # type: ignore[attr-defined] return typing.cast(Template[Q, V], op) + + +class Agent(abc.ABC): + """Mixin that gives each instance a persistent LLM message history. + + Subclass and decorate methods with :func:`Template.define`. + Each instance accumulates messages across calls so the LLM sees + prior conversation context. + + Agents compose freely with :func:`dataclasses.dataclass` and other + base classes. Instance attributes are available in template + docstrings via ``{self.attr}``. + + Example:: + + import dataclasses + from effectful.handlers.llm import Agent, Template + from effectful.handlers.llm.completions import LiteLLMProvider + from effectful.ops.semantics import handler + from effectful.ops.types import NotHandled + + @dataclasses.dataclass + class ChatBot(Agent): + bot_name: str = dataclasses.field(default="ChatBot") + + @Template.define + def send(self, user_input: str) -> str: + \"""Friendly bot named {self.bot_name}. User writes: {user_input}\""" + raise NotHandled + + provider = LiteLLMProvider() + chatbot = ChatBot() + + with handler(provider): + chatbot.send("Hi! How are you? I am in France.") + chatbot.send("Remind me again, where am I?") # sees prior context + + """ + + __history__: collections.OrderedDict[str, Any] + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + prop = functools.cached_property(lambda _: collections.OrderedDict()) + prop.__set_name__(cls, "__history__") + cls.__history__ = prop + + for name in list(cls.__dict__): + attr = cls.__dict__[name] + if not isinstance(attr, Template): + continue + _template = attr + + @functools.wraps(_template) + def wrapper(self, *args, _t=_template, **kwargs): + from effectful.handlers.llm.completions import get_message_sequence + + with handler({get_message_sequence: lambda: self.__history__}): + return _t(self, *args, **kwargs) + + setattr(cls, name, wrapper) diff --git a/tests/test_handlers_llm_agent.py b/tests/test_handlers_llm_agent.py new file mode 100644 index 000000000..38c8aacd1 --- /dev/null +++ b/tests/test_handlers_llm_agent.py @@ -0,0 +1,315 @@ +"""Tests for Agent mixin message sequence semantics.""" + +import collections +import dataclasses + +from litellm import ModelResponse + +from effectful.handlers.llm import Agent, Template, Tool +from effectful.handlers.llm.completions import ( + LiteLLMProvider, + RetryLLMHandler, + completion, +) +from effectful.ops.semantics import handler +from effectful.ops.syntax import ObjectInterpretation, implements +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Helpers (same pattern as test_handlers_llm_provider.py) +# --------------------------------------------------------------------------- + + +def make_text_response(content: str) -> ModelResponse: + return ModelResponse( + id="test", + choices=[ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + model="test-model", + ) + + +def make_tool_call_response( + tool_name: str, tool_args: str, tool_call_id: str = "call_1" +) -> ModelResponse: + return ModelResponse( + id="test", + choices=[ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": tool_call_id, + "type": "function", + "function": {"name": tool_name, "arguments": tool_args}, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + model="test-model", + ) + + +class MockCompletionHandler(ObjectInterpretation): + """Returns pre-configured responses and captures messages sent to the LLM.""" + + def __init__(self, responses: list[ModelResponse]): + self.responses = responses + self.call_count = 0 + self.received_messages: list[list] = [] + + @implements(completion) + def _completion(self, model, messages=None, **kwargs): + self.received_messages.append(list(messages) if messages else []) + response = self.responses[min(self.call_count, len(self.responses) - 1)] + self.call_count += 1 + return response + + +# --------------------------------------------------------------------------- +# Agent subclass used by most tests +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class ChatBot(Agent): + """Simple chat agent for testing history accumulation.""" + + bot_name: str = dataclasses.field(default="ChatBot") + + @Template.define + def send(self, user_input: str) -> str: + """A friendly bot named {self.bot_name}. User writes: {user_input}""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestAgentHistoryAccumulation: + """History accumulates across sequential calls on the same instance.""" + + def test_second_call_sees_prior_messages(self): + mock = MockCompletionHandler( + [make_text_response("hi"), make_text_response("good")] + ) + bot = ChatBot() + + with handler(LiteLLMProvider()), handler(mock): + bot.send("hello") + bot.send("how are you") + + # First call: system + user → 2 messages + assert len(mock.received_messages[0]) == 2 + + # Second call: previous system + user + assistant, PLUS new system + user → 5 + assert len(mock.received_messages[1]) > len(mock.received_messages[0]) + + # Verify roles in second call + roles = [m["role"] for m in mock.received_messages[1]] + assert roles.count("assistant") >= 1 + assert roles.count("user") >= 2 + assert roles.count("system") >= 2 + + def test_history_contains_all_messages_after_two_calls(self): + mock = MockCompletionHandler( + [make_text_response("r1"), make_text_response("r2")] + ) + bot = ChatBot() + + with handler(LiteLLMProvider()), handler(mock): + bot.send("a") + bot.send("b") + + # After two complete calls the history should have: + # call 1: system, user, assistant (3) + # call 2: system, user, assistant (3) + assert len(bot.__history__) == 6 + + def test_message_ids_are_unique(self): + mock = MockCompletionHandler( + [make_text_response("r1"), make_text_response("r2")] + ) + bot = ChatBot() + + with handler(LiteLLMProvider()), handler(mock): + bot.send("a") + bot.send("b") + + ids = list(bot.__history__.keys()) + assert len(ids) == len(set(ids)), "message IDs must be unique" + + +class TestAgentIsolation: + """Each agent instance has independent history; non-agent templates are unaffected.""" + + def test_two_agents_have_independent_histories(self): + mock = MockCompletionHandler( + [ + make_text_response("from bot1"), + make_text_response("from bot2"), + ] + ) + bot1 = ChatBot() + bot2 = ChatBot() + + with handler(LiteLLMProvider()), handler(mock): + bot1.send("msg for bot1") + bot2.send("msg for bot2") + + # bot2's call should NOT contain bot1's messages + assert len(mock.received_messages[1]) == 2 # system + user only + + # Each bot has its own history + assert len(bot1.__history__) == 3 # system, user, assistant + assert len(bot2.__history__) == 3 + + # Histories share no message IDs + assert set(bot1.__history__.keys()).isdisjoint(set(bot2.__history__.keys())) + + def test_non_agent_template_gets_fresh_sequence(self): + @Template.define + def standalone(topic: str) -> str: + """Write about {topic}.""" + raise NotHandled + + mock = MockCompletionHandler( + [ + make_text_response("agent reply"), + make_text_response("standalone reply"), + make_text_response("agent reply 2"), + ] + ) + bot = ChatBot() + + with handler(LiteLLMProvider()), handler(mock): + bot.send("hello") + standalone("fish") + bot.send("bye") + + # standalone (call index 1) should see only system + user (fresh sequence) + assert len(mock.received_messages[1]) == 2 + + # bot's third call (call index 2) should see its accumulated history + # but NOT the standalone messages + assert len(mock.received_messages[2]) == 5 # 3 from first call + 2 new + + +class TestAgentCachedProperty: + """__history__ is lazily created per instance without requiring __init__.""" + + def test_no_init_required(self): + class MinimalAgent(Agent): + @Template.define + def greet(self, name: str) -> str: + """Hello {name}.""" + raise NotHandled + + agent = MinimalAgent() + # Should be an OrderedDict, created on first access + assert isinstance(agent.__history__, collections.OrderedDict) + assert len(agent.__history__) == 0 + + def test_subclass_with_own_init(self): + class CustomAgent(Agent): + def __init__(self, name: str): + self.name = name + + @Template.define + def greet(self) -> str: + """Say hello.""" + raise NotHandled + + agent = CustomAgent("Alice") + assert agent.name == "Alice" + assert isinstance(agent.__history__, collections.OrderedDict) + + def test_history_is_per_instance(self): + a = ChatBot() + b = ChatBot() + a.__history__["fake"] = {"id": "fake", "role": "user", "content": "x"} + assert "fake" not in b.__history__ + + +class TestAgentWithToolCalls: + """Agent methods that trigger tool calls maintain correct history.""" + + def test_tool_call_results_appear_in_history(self): + @Tool.define + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + class MathAgent(Agent): + @Template.define + def compute(self, question: str) -> str: + """Answer: {question}""" + raise NotHandled + + mock = MockCompletionHandler( + [ + make_tool_call_response("add", '{"a": 2, "b": 3}'), + make_text_response("The answer is 5"), + ] + ) + agent = MathAgent() + + with handler(LiteLLMProvider()), handler(mock): + result = agent.compute("what is 2+3?") + + assert result == "The answer is 5" + + # History should contain: system, user, assistant (tool_call), + # tool (result), assistant (final) + roles = [m["role"] for m in agent.__history__.values()] + assert "tool" in roles + assert roles.count("assistant") == 2 + + +class TestAgentWithRetryHandler: + """RetryLLMHandler composes correctly with Agent history.""" + + def test_failed_retries_dont_pollute_history(self): + mock = MockCompletionHandler( + [ + # First attempt: invalid result for int + make_text_response('{"value": "not_an_int"}'), + # Retry: valid + make_text_response('{"value": 42}'), + ] + ) + + class NumberAgent(Agent): + @Template.define + def pick_number(self) -> int: + """Pick a number.""" + raise NotHandled + + agent = NumberAgent() + + with ( + handler(LiteLLMProvider()), + handler(RetryLLMHandler(num_retries=3)), + handler(mock), + ): + result = agent.pick_number() + + assert result == 42 + + # The malformed assistant message and error feedback from the retry + # should NOT appear in the agent's history. Only the final successful + # assistant message should be there. + roles = [m["role"] for m in agent.__history__.values()] + assert roles == ["system", "user", "assistant"] From eca5cb8fef22b83fadf267ba04c2dfaec3760c25 Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 8 Feb 2026 13:28:31 -0500 Subject: [PATCH 03/11] stash examples --- docs/source/llm_examples/__init__.py | 0 docs/source/llm_examples/async_concurrency.py | 68 +++++ docs/source/llm_examples/batch_translate.py | 71 +++++ docs/source/llm_examples/chat_memory.py | 131 +++++++++ docs/source/llm_examples/chat_search.py | 113 ++++++++ docs/source/llm_examples/flight_booking.py | 255 ++++++++++++++++++ docs/source/llm_examples/guardrails.py | 74 +++++ .../llm_examples/hanoi_solver_iterative.py | 209 ++++++++++++++ .../llm_examples/hanoi_solver_recursive.py | 195 ++++++++++++++ docs/source/llm_examples/hitl.py | 173 ++++++++++++ docs/source/llm_examples/majority_vote.py | 88 ++++++ docs/source/llm_examples/map_reduce.py | 156 +++++++++++ docs/source/llm_examples/multi_agent.py | 164 +++++++++++ docs/source/llm_examples/rag.py | 193 +++++++++++++ docs/source/llm_examples/research_agent.py | 144 ++++++++++ docs/source/llm_examples/supervisor.py | 178 ++++++++++++ docs/source/llm_examples/tao_agent.py | 185 +++++++++++++ docs/source/llm_examples/text2sql.py | 170 ++++++++++++ docs/source/llm_examples/thinking.py | 113 ++++++++ 19 files changed, 2680 insertions(+) create mode 100644 docs/source/llm_examples/__init__.py create mode 100644 docs/source/llm_examples/async_concurrency.py create mode 100644 docs/source/llm_examples/batch_translate.py create mode 100644 docs/source/llm_examples/chat_memory.py create mode 100644 docs/source/llm_examples/chat_search.py create mode 100644 docs/source/llm_examples/flight_booking.py create mode 100644 docs/source/llm_examples/guardrails.py create mode 100644 docs/source/llm_examples/hanoi_solver_iterative.py create mode 100644 docs/source/llm_examples/hanoi_solver_recursive.py create mode 100644 docs/source/llm_examples/hitl.py create mode 100644 docs/source/llm_examples/majority_vote.py create mode 100644 docs/source/llm_examples/map_reduce.py create mode 100644 docs/source/llm_examples/multi_agent.py create mode 100644 docs/source/llm_examples/rag.py create mode 100644 docs/source/llm_examples/research_agent.py create mode 100644 docs/source/llm_examples/supervisor.py create mode 100644 docs/source/llm_examples/tao_agent.py create mode 100644 docs/source/llm_examples/text2sql.py create mode 100644 docs/source/llm_examples/thinking.py diff --git a/docs/source/llm_examples/__init__.py b/docs/source/llm_examples/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/docs/source/llm_examples/async_concurrency.py b/docs/source/llm_examples/async_concurrency.py new file mode 100644 index 000000000..32ec20c77 --- /dev/null +++ b/docs/source/llm_examples/async_concurrency.py @@ -0,0 +1,68 @@ +"""Fork/join async concurrency with templates. + +Demonstrates: +- Running multiple LLM template calls concurrently with ``asyncio.gather`` +- Using ``asyncio.to_thread`` to run synchronous template calls in parallel +""" + +import argparse +import asyncio +import functools +import os + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Async template +# --------------------------------------------------------------------------- + + +@Template.define +def analyze_average_age(ages: list[int]) -> int: + """Analyze the dataset of ages {ages} and return the average age of + participants. Do not use any tools.""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def main(provider: LiteLLMProvider): + analysis = functools.partial( + asyncio.to_thread, handler(provider)(analyze_average_age) + ) + results = await asyncio.gather( + analysis([25, 30, 35, 40]), + analysis([20, 28, 17, 30]), + analysis([22, 27, 31, 29]), + analysis([24, 26, 32, 38]), + analysis([21, 29, 33, 37]), + ) + for i, result in enumerate(results): + print(f"Group {i}: average age = {result}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Analyze average ages concurrently") + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + provider = LiteLLMProvider(model=args.model) + asyncio.run(main(provider)) diff --git a/docs/source/llm_examples/batch_translate.py b/docs/source/llm_examples/batch_translate.py new file mode 100644 index 000000000..bcdb343d3 --- /dev/null +++ b/docs/source/llm_examples/batch_translate.py @@ -0,0 +1,71 @@ +"""Batch translation with instruction injection. + +Demonstrates: +- ``@Template.define`` for a translation template with injected instructions +""" + +import argparse +import os + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider +from effectful.handlers.llm.evaluation import RestrictedEvalProvider +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Translation template +# --------------------------------------------------------------------------- + + +@Template.define +def translate(target_language: str, instructions: str = "") -> Template[[str], str]: + """ + Write a `Template` that translates a string of English text into {target_language} + If any instructions are provided, include them in the prompt: {instructions} + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Batch translation with instruction injection" + ) + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + parser.add_argument( + "--max-steps", + type=int, + default=5, + help="Maximum number of steps before giving up", + ) + parser.add_argument( + "--num-retries", + type=int, + default=5, + help="Number of retries for malformed LLM output", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + provider = LiteLLMProvider(model=args.model) + + with handler(provider), handler(RestrictedEvalProvider()): + translator = translate( + target_language="french", instructions="Use formal language." + ) + print(translator("hello, how are you? how is your day going?")) diff --git a/docs/source/llm_examples/chat_memory.py b/docs/source/llm_examples/chat_memory.py new file mode 100644 index 000000000..42c8b46ac --- /dev/null +++ b/docs/source/llm_examples/chat_memory.py @@ -0,0 +1,131 @@ +"""Chat agent with embedding-based memory. + +Demonstrates: +- A stateful chat agent that maintains conversation history +- Embedding-based retrieval of relevant past context +- Simple in-memory vector store with L2 distance +""" + +import argparse +import dataclasses +import os + +import litellm +import numpy as np + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Embedding helpers +# --------------------------------------------------------------------------- + + +def get_embedding(text: str) -> np.ndarray: + """Get an embedding vector for the given text using litellm.""" + response = litellm.embedding(model="text-embedding-ada-002", input=text) + return np.array(response.data[0]["embedding"], dtype=np.float32) + + +def find_closest( + index: list[tuple[str, np.ndarray]], phrase: str +) -> tuple[str, float] | None: + """Find the closest entry in the index to the given phrase.""" + if not index: + return None + phrase_embedding = get_embedding(phrase) + + def dist(a: np.ndarray, b: np.ndarray) -> float: + return float(((a - b) ** 2).sum()) + + return min( + ((msg, dist(embedding, phrase_embedding)) for msg, embedding in index), + key=lambda elt: elt[1], + ) + + +# --------------------------------------------------------------------------- +# Chat template +# --------------------------------------------------------------------------- + + +@Template.define +def respond_to_user( + user_message: str, relevant_context: str, prev_messages: str +) -> str: + """Given the user wrote: {user_message} + Continue the conversation. + The last few messages were: {prev_messages} + Older relevant context: {relevant_context}""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Chat agent +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class ChatAgent: + """A chat agent that compresses old messages into an embedding index.""" + + history: list[dict[str, str]] = dataclasses.field(default_factory=list) + index: list[tuple[str, np.ndarray]] = dataclasses.field(default_factory=list) + + def _compress(self): + """Move the oldest pair of messages into the embedding index.""" + oldest_pair, self.history = self.history[:2], self.history[2:] + text = "\n".join(m["content"] for m in oldest_pair) + self.index.append((text, get_embedding(text))) + + def _find_relevant(self, query: str) -> str: + result = find_closest(self.index, query) + return result[0] if result else "No relevant context." + + def chat(self, user_input: str): + relevant = self._find_relevant(user_input) + prev_messages = "\n".join( + f"{m['author']}: {m['content']}" for m in self.history + ) + response = respond_to_user(user_input, relevant, prev_messages) + self.history.append({"author": "user", "content": user_input}) + self.history.append({"author": "agent", "content": response}) + if len(self.history) > 6: + self._compress() + print(f"user: {user_input}") + print(f"agent: {response}") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Chat agent with embedding-based memory" + ) + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + agent = ChatAgent() + + provider = LiteLLMProvider(model=args.model) + with handler(provider): + agent.chat("Hello! How are you doing?") + agent.chat("Lovely! I'm having a great day.") + agent.chat("What is the capital of France?") + agent.chat("I didn't know that! That's amazing!") diff --git a/docs/source/llm_examples/chat_search.py b/docs/source/llm_examples/chat_search.py new file mode 100644 index 000000000..8dcdd1691 --- /dev/null +++ b/docs/source/llm_examples/chat_search.py @@ -0,0 +1,113 @@ +import argparse +import dataclasses +import os +import urllib.parse + +import requests + +from effectful.handlers.llm import Agent, Template, Tool +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + + +@Tool.define +def search_web(query: str) -> str: + """Search Wikipedia for a topic and return a summary. The query can be a topic name or a natural language question.""" + search_url = "https://en.wikipedia.org/w/api.php?" + urllib.parse.urlencode( + { + "action": "query", + "list": "search", + "srsearch": query, + "srlimit": 1, + "format": "json", + } + ) + search_data = requests.get( + search_url, headers={"User-Agent": "effectful-example/1.0"} + ).json() + results = search_data.get("query", {}).get("search", []) + if not results: + raise ValueError(f"No results found for: {query}") + title = results[0]["title"] + + summary_url = "https://en.wikipedia.org/w/api.php?" + urllib.parse.urlencode( + { + "action": "query", + "titles": title, + "prop": "extracts", + "exintro": True, + "explaintext": True, + "format": "json", + } + ) + summary_data = requests.get( + summary_url, headers={"User-Agent": "effectful-example/1.0"} + ).json() + page = next(iter(summary_data["query"]["pages"].values())) + extract = page.get("extract", "No summary available.") + url = f"https://en.wikipedia.org/wiki/{urllib.parse.quote(title.replace(' ', '_'))}" + + return f"# {title}\n\n{extract}\n\nSource: {url}" + + +@dataclasses.dataclass +class ChatBot(Agent): + """Simple chat agent for testing history accumulation.""" + + bot_name: str = dataclasses.field(default="ChatBot") + + @Template.define + def send(self, user_input: str) -> str: + """ + You are a friendly and helpful AI assistant named {self.bot_name}. + If user input contains a question that you're not sure how to answer, + consider using the web search tool to find the answer and include it in your response. + + The user writes: + {user_input} + """ + raise NotHandled + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="LLM-guided research agent with web search" + ) + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + parser.add_argument( + "--name", + type=str, + default="Chatty McChatface", + help="The name of the chatbot", + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Run in interactive mode, allowing multiple back-and-forth messages", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + chatbot = ChatBot(bot_name=args.name) + provider = LiteLLMProvider(model=args.model) + + with handler(provider), handler(RetryLLMHandler(num_retries=3)): + if args.interactive: + while True: + print(chatbot.send(input("You: "))) + else: + print(chatbot.send("Hi! Can you tell me about the Statue of Liberty?")) + print(chatbot.send("Who designed it?")) + print(chatbot.send("What about the speed of light? How fast is it?")) diff --git a/docs/source/llm_examples/flight_booking.py b/docs/source/llm_examples/flight_booking.py new file mode 100644 index 000000000..35cdec35d --- /dev/null +++ b/docs/source/llm_examples/flight_booking.py @@ -0,0 +1,255 @@ +"""Flight booking with multi-agent delegation. + +Demonstrates: +- Multi-agent delegation: a tool that internally calls a separate + ``@Template.define`` (agent-to-agent delegation) +- Programmatic validation of LLM output with retry +- Interactive human-in-the-loop flow +- ``Agent`` history for conversational seat selection +""" + +import argparse +import dataclasses +import datetime +import enum +import os +from typing import Literal + +from effectful.handlers.llm import Agent, Template, Tool +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Structured output types +# --------------------------------------------------------------------------- + + +class Airport(enum.StrEnum): + SFO = "SFO" + ANC = "ANC" + FAI = "FAI" + JNU = "JNU" + NYC = "NYC" + LAX = "LAX" + CHI = "CHI" + MIA = "MIA" + BOS = "BOS" + SEA = "SEA" + DFW = "DFW" + DEN = "DEN" + ATL = "ATL" + HOU = "HOU" + + +@dataclasses.dataclass(frozen=True) +class FlightDetails: + flight_number: str + price: int + origin: Airport # three-letter airport code + destination: Airport # three-letter airport code + date: datetime.date # YYYY-MM-DD + + +@dataclasses.dataclass(frozen=True) +class SeatPreference: + row: int # 1-30 + seat: Literal["A", "B", "C", "D", "E", "F"] + + +# --------------------------------------------------------------------------- +# Sample data (in reality, downloaded from a booking site) +# --------------------------------------------------------------------------- + +FLIGHTS_PAGE = """\ +1. Flight SFO-AK123 - $350 - San Francisco (SFO) to Anchorage (ANC) - 2025-01-10 +2. Flight SFO-AK456 - $370 - San Francisco (SFO) to Fairbanks (FAI) - 2025-01-10 +3. Flight SFO-AK789 - $400 - San Francisco (SFO) to Juneau (JNU) - 2025-01-20 +4. Flight NYC-LA101 - $250 - San Francisco (SFO) to Anchorage (ANC) - 2025-01-10 +5. Flight CHI-MIA202 - $200 - Chicago (ORD) to Miami (MIA) - 2025-01-12 +6. Flight BOS-SEA303 - $120 - Boston (BOS) to Anchorage (ANC) - 2025-01-12 +7. Flight DFW-DEN404 - $150 - Dallas (DFW) to Denver (DEN) - 2025-01-10 +8. Flight ATL-HOU505 - $180 - Atlanta (ATL) to Houston (IAH) - 2025-01-10 +""" + +# --------------------------------------------------------------------------- +# Extraction template (inner "agent") +# --------------------------------------------------------------------------- + + +@Template.define +def extract_flights(web_page_text: str) -> list[FlightDetails]: + """Extract all flight details from the following text. + + {web_page_text} + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Tool that delegates to the extraction template +# --------------------------------------------------------------------------- + +# The tool is defined at module scope so that FlightFinder's template +# captures it via lexical scope (same pattern as search_web in other examples). + + +@Tool.define +def get_available_flights() -> list[FlightDetails]: + """Retrieve all available flights from the booking page.""" + return extract_flights(FLIGHTS_PAGE) + + +# --------------------------------------------------------------------------- +# Flight search agent +# --------------------------------------------------------------------------- + + +class FlightFinder(Agent): + """Agent that finds flights matching user criteria.""" + + @Template.define + def find_flight( + self, origin: Airport, destination: Airport, date: datetime.date + ) -> FlightDetails: + """Find the cheapest flight from {origin} to {destination} on {date}. + + Use the get_available_flights tool to retrieve all flights, then + select the cheapest one that matches the origin, destination, + and date exactly. + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Seat selection agent +# --------------------------------------------------------------------------- + + +class SeatSelector(Agent): + """Agent that extracts seat preferences from natural language.""" + + @Template.define + def select_seat(self, user_input: str) -> SeatPreference: + """Extract the user's seat preference from their message. + + {user_input} + + Seats A and F are window seats. Seats C and D are aisle seats. + Row 1 is the front row with extra legroom. + Rows 14 and 20 also have extra legroom. + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Validation (plain Python, no LLM needed) +# --------------------------------------------------------------------------- + + +def validate_flight( + flight: FlightDetails, origin: Airport, destination: Airport, date: datetime.date +) -> list[str]: + """Check that the selected flight matches the requested criteria.""" + errors = [] + if flight.origin != origin: + errors.append(f"origin should be {origin}, got {flight.origin}") + if flight.destination != destination: + errors.append(f"destination should be {destination}, got {flight.destination}") + if flight.date != date: + errors.append(f"date should be {date}, got {flight.date}") + return errors + + +# --------------------------------------------------------------------------- +# Booking flow +# --------------------------------------------------------------------------- + + +def book_flight( + origin: Airport, + destination: Airport, + date: datetime.date, + interactive: bool = False, + max_retries: int = 3, +) -> None: + """End-to-end flight booking with search, validation, and seat selection.""" + searcher = FlightFinder() + + # --- Search with validation retry --- + flight = None + for attempt in range(max_retries): + candidate = searcher.find_flight(origin, destination, date) + errors = validate_flight(candidate, origin, destination, date) + if errors: + print(f" [attempt {attempt}] Rejected: {'; '.join(errors)}") + continue + flight = candidate + break + + if flight is None: + print("Could not find a valid flight.") + return + + print( + f" Found: {flight.flight_number} ${flight.price} " + f"({flight.origin}->{flight.destination} on {flight.date})" + ) + + # --- User approval (interactive only) --- + if interactive: + if input(" Book this flight? (yes/no): ").strip().lower() != "yes": + print(" Cancelled.") + return + + # --- Seat selection --- + selector = SeatSelector() + seat_requests = ( + [input(" Seat preference: ")] + if interactive + else ["I'd like a window seat with extra legroom please"] + ) + for request in seat_requests: + seat = selector.select_seat(request) + print(f" Seat: row {seat.row}, seat {seat.seat}") + + print(f" Booked {flight.flight_number}, seat {seat.row}{seat.seat}!") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Flight booking with multi-agent delegation" + ) + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Run in interactive mode with user prompts", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + provider = LiteLLMProvider(model=args.model) + + with handler(provider), handler(RetryLLMHandler(num_retries=5)): + book_flight( + origin=Airport.SFO, + destination=Airport.ANC, + date=datetime.date(2025, 1, 10), + interactive=args.interactive, + ) diff --git a/docs/source/llm_examples/guardrails.py b/docs/source/llm_examples/guardrails.py new file mode 100644 index 000000000..304818ffb --- /dev/null +++ b/docs/source/llm_examples/guardrails.py @@ -0,0 +1,74 @@ +"""Travel advisor with input guardrails. + +Demonstrates: +- Using one template to validate/guard input before passing it to another +- Simple control-flow gating based on LLM classification +""" + +import argparse +import os + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Templates +# --------------------------------------------------------------------------- + + +@Template.define +def travel_query(user_query: str) -> str: + """ + Produce a concise (<100 word) answer to: {user_query} + """ + raise NotHandled + + +@Template.define +def is_safe_query(user_query: str) -> bool: + """ + Determine whether the user's query is purely related to travel advice: {user_query} + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Guarded agent +# --------------------------------------------------------------------------- + + +def answer_travel_query(user_query: str) -> str: + """Only answer travel-related queries; reject everything else.""" + if is_safe_query(user_query): + return travel_query(user_query) + else: + return f"Rejected: '{user_query}' is not related to travel advice." + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Analyze average ages concurrently") + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + provider = LiteLLMProvider(model=args.model) + with handler(provider), handler(RetryLLMHandler(num_retries=5)): + print(answer_travel_query("What are great places to check out in NYC?")) + print(answer_travel_query("Should I buy apple stocks?")) diff --git a/docs/source/llm_examples/hanoi_solver_iterative.py b/docs/source/llm_examples/hanoi_solver_iterative.py new file mode 100644 index 000000000..8733c64ac --- /dev/null +++ b/docs/source/llm_examples/hanoi_solver_iterative.py @@ -0,0 +1,209 @@ +"""LLM-guided Towers of Hanoi solver with tool-based validation. + +Adapted from https://github.com/BasisResearch/effectful/pull/404 + +Demonstrates: +- A static Pydantic ``Step`` model for structured output +- ``@Tool.define`` inside a closure to expose game-state validation as a tool +- ``RetryLLMHandler`` to retry on malformed LLM output +- Templates defined inside a function that auto-capture closure-scoped tools +""" + +import argparse +import itertools +import os +from dataclasses import dataclass, field + +from effectful.handlers.llm import Template, Tool +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Step model +# --------------------------------------------------------------------------- + + +@dataclass +class Step: + """A single move: take the top disk from tower ``start`` and place it on + tower ``end``. Tower indices are zero-based.""" + + start: int + end: int + explanation: str = field(default="") # optional reasoning from the LLM + + +# --------------------------------------------------------------------------- +# Game state +# --------------------------------------------------------------------------- + + +@dataclass +class GameState: + """State of a Towers of Hanoi game. + + Higher numbers represent larger disks, so ``(2, 1, 0)`` is a valid + tower (largest on bottom). The goal is to move all disks from the + leftmost tower (index 0) to the rightmost tower (index -1). + + This is a plain ``dataclass`` (not a Pydantic model) so the type checker + can see its methods. + """ + + size: int + towers: tuple[tuple[int, ...], ...] = field(default=()) + + def __post_init__(self): + if self.size > 0 and not self.towers: + self.towers = tuple( + tuple(reversed(range(self.size))) if i == 0 else () + for i in range(self.size) + ) + + def apply(self, step: Step) -> "GameState": + """Apply a move, returning the new state. Raises ``ValueError`` if + the move is invalid.""" + start, end = step.start, step.end + if not (0 <= start < len(self.towers) and 0 <= end < len(self.towers)): + raise ValueError(f"tower index out of range: ({start}, {end})") + if len(self.towers[start]) == 0: + raise ValueError(f"tower {start} is empty") + if len(self.towers[end]) > 0 and self.towers[start][-1] > self.towers[end][-1]: + raise ValueError( + f"cannot place disk {self.towers[start][-1]} on top of " + f"disk {self.towers[end][-1]}" + ) + new_towers = [list(t) for t in self.towers] + disk = new_towers[start].pop() + new_towers[end].append(disk) + return GameState(self.size, tuple(tuple(t) for t in new_towers)) + + def is_done(self) -> bool: + return all(len(t) == 0 for t in self.towers[:-1]) and all( + self.towers[-1][i] > self.towers[-1][i + 1] + for i in range(len(self.towers[-1]) - 1) + ) + + def valid_steps(self) -> list[Step]: + steps = [] + for i, ti in enumerate(self.towers): + for j, tj in enumerate(self.towers): + if i == j or len(ti) == 0: + continue + if len(tj) == 0 or ti[-1] < tj[-1]: + steps.append(Step(i, j)) + return steps + + def __str__(self) -> str: + return " | ".join(str(list(t)) for t in self.towers) + + +# --------------------------------------------------------------------------- +# LLM move predictor +# --------------------------------------------------------------------------- + + +def predict_next_step(state: GameState) -> Step: + """Ask the LLM to predict the next move. + + A ``get_valid_moves`` tool is defined in the closure so the template + can query which moves are legal for the current game state. A + ``validate_move`` tool checks whether a proposed move is legal and + raises ``ValueError`` if not — when wrapped by ``RetryLLMHandler``, + this error is fed back to the LLM so it can correct itself. + """ + valid = state.valid_steps() + + @Tool.define + def get_valid_moves() -> list[Step]: + """Return the list of valid moves for the current game state.""" + return valid + + @Tool.define + def validate_move(proposed: Step) -> bool: + """Check whether moving from tower ``start`` to tower ``end`` is legal.""" + return proposed in state.valid_steps() + + @Template.define + def predict(game_state: GameState) -> Step: + """Given the state of the game of Towers of Hanoi: + + {game_state} + + Predict the next step to complete the game (move all disks to the + rightmost tower). You MUST call get_valid_moves first to see which + moves are legal, then pick the best one. Give a brief reasoning. + """ + raise NotHandled + + return predict(state) + + +# --------------------------------------------------------------------------- +# Solver loop +# --------------------------------------------------------------------------- + + +def solve_hanoi(state: GameState, max_steps: int = 30): + """Solve Towers of Hanoi by repeatedly asking the LLM for the next move.""" + for i in itertools.count(): + print(f"step {i}: {state}") + if state.is_done(): + print("Solved!") + return + if i >= max_steps: + print("Gave up after max steps.") + return + + step: Step = predict_next_step(state) + try: + state = state.apply(step) + print(f" move: {step.start} -> {step.end}") + except ValueError as e: + print(f" attempt {i}: invalid move {step}: {e}") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="LLM-guided Towers of Hanoi solver") + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + parser.add_argument( + "--game-size", + type=int, + default=3, + help="Number of disks in the Towers of Hanoi game", + ) + parser.add_argument( + "--max-steps", + type=int, + default=30, + help="Maximum number of steps before giving up", + ) + parser.add_argument( + "--num-retries", + type=int, + default=5, + help="Number of retries for malformed LLM output", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + provider = LiteLLMProvider(model=args.model) + + with handler(provider), handler(RetryLLMHandler(num_retries=args.num_retries)): + solve_hanoi(GameState(size=args.game_size), max_steps=args.max_steps) diff --git a/docs/source/llm_examples/hanoi_solver_recursive.py b/docs/source/llm_examples/hanoi_solver_recursive.py new file mode 100644 index 000000000..b5da9107a --- /dev/null +++ b/docs/source/llm_examples/hanoi_solver_recursive.py @@ -0,0 +1,195 @@ +"""Recursive LLM-based Towers of Hanoi solver. + +Adapted from https://github.com/BasisResearch/effectful/pull/404 + +Demonstrates: +- ``IsRecursive`` annotation to let a template call itself as a tool +- Recursive problem decomposition via LLM tool calls +- Post-hoc validation of the LLM-generated move sequence + +The classic recursive algorithm for Tower of Hanoi is: + + hanoi(n, source, target, auxiliary): + if n == 1: move disk from source to target + else: + hanoi(n-1, source, auxiliary, target) # move n-1 disks out of the way + move largest disk from source to target # move the bottom disk + hanoi(n-1, auxiliary, target, source) # move n-1 disks to target + +This solver defines a recursive ``Template`` that can call itself as a tool. +The LLM decomposes the n-disk problem into three sub-steps, making recursive +tool calls for the (n-1)-disk sub-problems, and returns the concatenated +list of moves. + +See: https://en.wikipedia.org/wiki/Tower_of_Hanoi +""" + +import argparse +import os +import typing +from dataclasses import dataclass, field + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.handlers.llm.template import IsRecursive +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Step model +# --------------------------------------------------------------------------- + + +@dataclass +class Step: + """A single move: take the top disk from tower ``start`` and place it on + tower ``end``. Tower indices are zero-based.""" + + start: int + end: int + + +# --------------------------------------------------------------------------- +# Game state (for validation only) +# --------------------------------------------------------------------------- + + +@dataclass +class GameState: + """State of a Towers of Hanoi game. + + Higher numbers represent larger disks, so ``(2, 1, 0)`` is a valid + tower (largest on bottom). The goal is to move all disks from the + leftmost tower (index 0) to the rightmost tower (index -1). + """ + + size: int + towers: tuple[tuple[int, ...], ...] = field(default=()) + + def __post_init__(self): + if self.size > 0 and not self.towers: + self.towers = tuple( + tuple(reversed(range(self.size))) if i == 0 else () + for i in range(self.size) + ) + + def apply(self, step: Step) -> "GameState": + """Apply a move, returning the new state. Raises ``ValueError`` if + the move is invalid.""" + start, end = step.start, step.end + if not (0 <= start < len(self.towers) and 0 <= end < len(self.towers)): + raise ValueError(f"tower index out of range: ({start}, {end})") + if len(self.towers[start]) == 0: + raise ValueError(f"tower {start} is empty") + if len(self.towers[end]) > 0 and self.towers[start][-1] > self.towers[end][-1]: + raise ValueError( + f"cannot place disk {self.towers[start][-1]} on top of " + f"disk {self.towers[end][-1]}" + ) + new_towers = [list(t) for t in self.towers] + disk = new_towers[start].pop() + new_towers[end].append(disk) + return GameState(self.size, tuple(tuple(t) for t in new_towers)) + + def is_done(self) -> bool: + return all(len(t) == 0 for t in self.towers[:-1]) and all( + self.towers[-1][i] > self.towers[-1][i + 1] + for i in range(len(self.towers[-1]) - 1) + ) + + def __str__(self) -> str: + return " | ".join(str(list(t)) for t in self.towers) + + +# --------------------------------------------------------------------------- +# Recursive LLM solver +# --------------------------------------------------------------------------- + + +@Template.define +def solve( + n_disks: int, source: int, target: int, auxiliary: int +) -> typing.Annotated[list[Step], IsRecursive]: + """Solve Tower of Hanoi: move {n_disks} disks from tower {source} to + tower {target}, using tower {auxiliary} as temporary storage. + + Recursive strategy: + - Base case (n_disks == 1): return [Step(start=source, end=target)] + - Recursive case (n_disks > 1): + 1. Call solve(n_disks - 1, source, auxiliary, target) to move the + top n_disks-1 disks out of the way onto the auxiliary tower. + 2. Move the largest disk: Step(start=source, end=target). + 3. Call solve(n_disks - 1, auxiliary, target, source) to move the + n_disks-1 disks from auxiliary to the target tower. + 4. Return the concatenated list of all steps from (1), (2), and (3). + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +def validate_solution(size: int, steps: list[Step]) -> bool: + """Apply all steps to the initial state and check that the puzzle is solved.""" + state = GameState(size=size) + print(f" initial: {state}") + for i, step in enumerate(steps): + try: + state = state.apply(step) + print(f" step {i}: move {step.start} -> {step.end} => {state}") + except ValueError as e: + print(f" step {i}: INVALID move {step.start} -> {step.end}: {e}") + return False + if state.is_done(): + print(f" Solved in {len(steps)} moves!") + return True + else: + print(f" Not solved after {len(steps)} moves. Final state: {state}") + return False + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Recursive LLM-based Towers of Hanoi solver" + ) + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + parser.add_argument( + "--game-size", + type=int, + default=3, + help="Number of disks in the Towers of Hanoi game", + ) + parser.add_argument( + "--num-retries", + type=int, + default=5, + help="Number of retries for malformed LLM output", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + provider = LiteLLMProvider(model=args.model) + + with handler(provider), handler(RetryLLMHandler(num_retries=args.num_retries)): + n = args.game_size + print(f"Solving Tower of Hanoi with {n} disks...") + steps = solve(n_disks=n, source=0, target=n - 1, auxiliary=1) + print(f"\nLLM returned {len(steps)} steps. Validating...\n") + validate_solution(n, steps) diff --git a/docs/source/llm_examples/hitl.py b/docs/source/llm_examples/hitl.py new file mode 100644 index 000000000..5b2ebe17c --- /dev/null +++ b/docs/source/llm_examples/hitl.py @@ -0,0 +1,173 @@ +"""Human-in-the-loop task planner. + +Demonstrates: +- An ``Agent`` that proposes a plan of action steps +- Human approval/rejection of each step before execution +- Feedback from rejection is fed back to the agent via history +- ``@Tool.define`` for executing approved actions +- Non-interactive mode for testing (auto-approves all steps) +""" + +import argparse +import dataclasses +import enum +import os + +from effectful.handlers.llm import Agent, Template, Tool +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Structured output +# --------------------------------------------------------------------------- + + +class ActionType(enum.StrEnum): + send_email = "send_email" + create_file = "create_file" + schedule_meeting = "schedule_meeting" + done = "done" + + +@dataclasses.dataclass(frozen=True) +class ProposedAction: + action: ActionType + description: str + details: str + + +# --------------------------------------------------------------------------- +# Simulated action execution +# --------------------------------------------------------------------------- + + +execution_log: list[str] = [] + + +@Tool.define +def execute_action(action: ActionType, details: str) -> str: + """Execute an approved action. Returns a confirmation message.""" + msg = f"[executed] {action}: {details}" + execution_log.append(msg) + return msg + + +# --------------------------------------------------------------------------- +# Planner agent +# --------------------------------------------------------------------------- + + +class Planner(Agent): + """Agent that proposes actions one at a time for human approval.""" + + @Template.define + def propose_next(self, task: str, feedback: str) -> ProposedAction: + """You are a task planner helping the user accomplish a goal. + + Task: {task} + + Feedback from the last step: {feedback} + + Review the conversation history for previously completed actions. + Propose the next action to take. If the task is complete, + set action to "done". + + If a previous proposal was rejected, propose something different + that addresses the feedback. + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Human-in-the-loop execution +# --------------------------------------------------------------------------- + + +def run_with_approval( + task: str, interactive: bool = False, max_steps: int = 5 +) -> list[str]: + """Run a task planner with human approval for each step.""" + planner = Planner() + feedback = "No actions taken yet. Start planning." + + for step in range(max_steps): + proposal = planner.propose_next(task, feedback) + + if proposal.action == ActionType.done: + print(f" [step {step + 1}] Done: {proposal.description}") + break + + print( + f" [step {step + 1}] Proposed: {proposal.action} - {proposal.description}" + ) + print(f" Details: {proposal.details}") + + if interactive: + answer = input(" Approve? (yes/no + reason): ").strip() + approved = answer.lower().startswith("y") + else: + answer = "yes" + approved = True + + if approved: + result = execute_action(proposal.action, proposal.details) + print(f" {result}") + feedback = f"Approved and executed: {result}" + else: + print(f" [rejected] {answer}") + feedback = f"Rejected: {answer}" + + return list(execution_log) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Human-in-the-loop task planner") + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Run in interactive mode with human approval prompts", + ) + parser.add_argument( + "--max-steps", + type=int, + default=5, + help="Maximum number of action steps", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + provider = LiteLLMProvider(model=args.model) + + task = ( + "Organize a team lunch for next Friday. " + "Send an email to the team, create a shared document for " + "restaurant suggestions, and schedule a meeting to finalize plans." + ) + + with handler(provider), handler(RetryLLMHandler(num_retries=3)): + print(f"Task: {task}\n") + log = run_with_approval( + task, + interactive=args.interactive, + max_steps=args.max_steps, + ) + print(f"\nExecution log ({len(log)} actions):") + for entry in log: + print(f" {entry}") diff --git a/docs/source/llm_examples/majority_vote.py b/docs/source/llm_examples/majority_vote.py new file mode 100644 index 000000000..1ad8c6296 --- /dev/null +++ b/docs/source/llm_examples/majority_vote.py @@ -0,0 +1,88 @@ +"""Majority voting ensemble. + +Demonstrates: +- Running the same template multiple times and taking a majority vote +- ``collections.Counter`` for tallying responses +""" + +import argparse +import collections +import collections.abc +import enum +import os + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Template +# --------------------------------------------------------------------------- + + +class Answer(enum.StrEnum): + yes = "yes" + no = "no" + maybe = "maybe" + + +@Template.define +def yes_or_no(question: str) -> Answer: + """ + Answer the following yes/no/maybe question: {question} + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Majority vote +# --------------------------------------------------------------------------- + + +def majority_vote[Q]( + oracle: collections.abc.Callable[[Q], Answer], query: Q, voters: int = 3 +) -> tuple[Answer, int]: + """Call ``oracle(query)`` multiple times and return the most common answer.""" + counter = collections.Counter(oracle(query) for _ in range(voters)) + return counter.most_common(1)[0] + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Majority voting ensemble for yes/no questions" + ) + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + parser.add_argument( + "--num-voters", type=int, default=3, help="Number of voters for majority vote" + ) + parser.add_argument( + "--question", + type=str, + default="Is Paris the capital of France?", + help="Yes/no question to ask", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + provider = LiteLLMProvider(model=args.model) + with handler(provider): + answer, count = majority_vote(yes_or_no, args.question, voters=args.num_voters) + print( + f"Question: {args.question}\nAnswer: {answer} (voted {count}/{args.num_voters})" + ) diff --git a/docs/source/llm_examples/map_reduce.py b/docs/source/llm_examples/map_reduce.py new file mode 100644 index 000000000..f7e00ac1d --- /dev/null +++ b/docs/source/llm_examples/map_reduce.py @@ -0,0 +1,156 @@ +"""Map-reduce resume evaluation. + +Demonstrates: +- Fan-out: evaluating multiple items independently with the same template +- Reduce: aggregating individual results into a summary +- ``asyncio.gather`` with ``asyncio.to_thread`` for parallel LLM calls +- Structured output with dataclasses +""" + +import argparse +import asyncio +import dataclasses +import functools +import os + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Structured output +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass(frozen=True) +class Evaluation: + name: str + qualified: bool + strengths: str + weaknesses: str + score: int # 1-10 + + +# --------------------------------------------------------------------------- +# Templates +# --------------------------------------------------------------------------- + + +@Template.define +def evaluate_resume(resume: str, job_description: str) -> Evaluation: + """You are a hiring manager. Evaluate this resume against the job + description and produce a structured evaluation. + + Job description: {job_description} + + Resume: + {resume} + + Score from 1 (poor fit) to 10 (perfect fit). + """ + raise NotHandled + + +@Template.define +def summarize_evaluations(job_description: str, evaluations_text: str) -> str: + """You are a hiring manager summarizing candidate evaluations. + + Job description: {job_description} + + Individual evaluations: + {evaluations_text} + + Provide a brief summary: rank the candidates from best to worst, + highlight the top candidate, and note any concerns. + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Sample data +# --------------------------------------------------------------------------- + +JOB_DESCRIPTION = ( + "Senior Python Developer: 5+ years Python experience, " + "familiarity with web frameworks (Django/Flask), " + "database design, and cloud deployment (AWS/GCP)." +) + +RESUMES = [ + "Alice Chen - 7 years Python, Django expert, AWS certified, " + "led team of 5, built microservices architecture at FinTech startup.", + "Bob Smith - 3 years Python, 2 years JavaScript, some Flask experience, " + "junior developer at small agency, strong communication skills.", + "Carol Davis - 10 years software engineering, 6 years Python, " + "GCP specialist, PostgreSQL expert, open-source contributor, " + "previously senior engineer at Google.", + "Dave Wilson - 4 years Python, self-taught, built several side projects, " + "no professional experience with web frameworks or cloud platforms.", +] + +# --------------------------------------------------------------------------- +# Map-reduce pipeline +# --------------------------------------------------------------------------- + + +async def map_reduce_evaluate( + provider: LiteLLMProvider, + resumes: list[str], + job_description: str, +) -> str: + """Evaluate resumes in parallel (map), then summarize (reduce).""" + # Map: evaluate each resume concurrently + evaluate = functools.partial( + asyncio.to_thread, + handler(provider)(handler(RetryLLMHandler(num_retries=3))(evaluate_resume)), + ) + evaluations: list[Evaluation] = list( + await asyncio.gather(*(evaluate(resume, job_description) for resume in resumes)) + ) + + # Print individual evaluations + for ev in evaluations: + print(f" {ev.name}: score={ev.score}/10, qualified={ev.qualified}") + print(f" + {ev.strengths}") + print(f" - {ev.weaknesses}") + + # Reduce: summarize all evaluations + evaluations_text = "\n\n".join( + f"Candidate: {ev.name}\n" + f"Score: {ev.score}/10\n" + f"Qualified: {ev.qualified}\n" + f"Strengths: {ev.strengths}\n" + f"Weaknesses: {ev.weaknesses}" + for ev in evaluations + ) + with handler(provider), handler(RetryLLMHandler(num_retries=3)): + return summarize_evaluations(job_description, evaluations_text) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Map-reduce resume evaluation") + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + provider = LiteLLMProvider(model=args.model) + + print(f"Evaluating {len(RESUMES)} resumes for: {JOB_DESCRIPTION}\n") + summary = asyncio.run(map_reduce_evaluate(provider, RESUMES, JOB_DESCRIPTION)) + print(f"\n{summary}") diff --git a/docs/source/llm_examples/multi_agent.py b/docs/source/llm_examples/multi_agent.py new file mode 100644 index 000000000..448cf7a4f --- /dev/null +++ b/docs/source/llm_examples/multi_agent.py @@ -0,0 +1,164 @@ +"""Multi-agent Taboo word guessing game. + +Demonstrates: +- Two ``Agent`` instances with independent conversation histories +- Inter-agent communication via plain function calls +- Each agent has a different persona and goal +- ``Agent.__history__`` keeps each agent's context isolated +""" + +import argparse +import dataclasses +import enum +import os + +from effectful.handlers.llm import Agent, Template, Tool +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Structured output +# --------------------------------------------------------------------------- + + +class Confidence(enum.Enum): + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + + +@dataclasses.dataclass(frozen=True) +class Guess: + guess: str + confidence: Confidence + + +# --------------------------------------------------------------------------- +# Agents +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class Hinter(Agent): + """Agent that gives hints about a secret word without saying it.""" + + secret_word: str = dataclasses.field(default="") + taboo_words: list[str] = dataclasses.field(default_factory=list) + + @Tool.define + def is_taboo(self, hint: str) -> bool: + """Check if the given hint contains any taboo words or the secret word.""" + lowered_hint = hint.lower() + if self.secret_word.lower() in lowered_hint: + return True + for taboo in self.taboo_words: + if taboo.lower() in lowered_hint: + return True + return False + + @Template.define + def give_hint(self, guesser_response: str) -> str: + """You are playing a word guessing game. You must help the guesser + figure out the secret word by giving creative hints. + + RULES: + - You MUST NOT say the secret word: {self.secret_word} + - You MUST NOT use any of these taboo words: {self.taboo_words} + - Give a single, concise hint (one sentence) + - Review conversation history to avoid repeating hints + - Use the is_taboo tool to check if your hint is valid + + The guesser's last response was: {guesser_response} + """ + raise NotHandled + + +class Guesser(Agent): + """Agent that tries to guess the secret word from hints.""" + + @Template.define + def make_guess(self, hint: str) -> Guess: + """You are playing a word guessing game. Based on the hints you've + received, guess the secret word. + + Latest hint: {hint} + + Review the conversation history for all previous hints. + Make your best guess. + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Game loop +# --------------------------------------------------------------------------- + + +def play_taboo( + secret_word: str, + taboo_words: list[str], + max_rounds: int = 5, +) -> bool: + """Play a round of Taboo between a hinter and a guesser.""" + hinter = Hinter(secret_word=secret_word, taboo_words=taboo_words) + guesser = Guesser() + + guesser_response = "I'm ready to guess!" + + for round_num in range(max_rounds): + # Hinter gives a hint + hint = hinter.give_hint(guesser_response) + print(f" [round {round_num}] Hinter: {hint}") + + # Guesser tries to guess + guess = guesser.make_guess(hint) + guesser_response = f"I guessed '{guess.guess}' ({guess.confidence})" + print(f" [round {round_num}] Guesser: {guess.guess} ({guess.confidence})") + + if guess.guess.lower().strip() == secret_word.lower(): + print(f" Correct! Guessed in {round_num} round(s).") + return True + + print(f" Failed to guess '{secret_word}' in {max_rounds} rounds.") + return False + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Multi-agent Taboo word guessing game") + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + parser.add_argument( + "--max-rounds", + type=int, + default=5, + help="Maximum rounds per game", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + games = [ + ("piano", ["music", "keys", "instrument", "play"]), + ("volcano", ["lava", "eruption", "mountain", "hot"]), + ] + + provider = LiteLLMProvider(model=args.model) + + with handler(provider), handler(RetryLLMHandler(num_retries=3)): + for secret, taboo in games: + print(f"\nGame: '{secret}' (taboo: {taboo})") + play_taboo(secret, taboo, max_rounds=args.max_rounds) diff --git a/docs/source/llm_examples/rag.py b/docs/source/llm_examples/rag.py new file mode 100644 index 000000000..166f1dbb4 --- /dev/null +++ b/docs/source/llm_examples/rag.py @@ -0,0 +1,193 @@ +"""Retrieval-augmented generation (RAG). + +Demonstrates: +- Offline: chunking documents, embedding, and indexing +- Online: embedding a query, retrieving relevant chunks, and generating + a grounded answer +- ``@Tool.define`` to expose retrieval as a tool the LLM can call +- Separation of indexing (plain Python) from generation (``@Template.define``) +""" + +import argparse +import dataclasses +import os + +import litellm +import numpy as np + +from effectful.handlers.llm import Template, Tool +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Embedding helpers +# --------------------------------------------------------------------------- + + +def get_embedding(text: str, model: str) -> np.ndarray: + """Get an embedding vector for the given text using litellm.""" + response = litellm.embedding(model=model, input=text) + return np.array(response.data[0]["embedding"], dtype=np.float32) + + +# --------------------------------------------------------------------------- +# Vector index +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class VectorIndex: + """Simple in-memory vector index using L2 distance.""" + + model: str + chunks: list[str] = dataclasses.field(default_factory=list) + embeddings: list[np.ndarray] = dataclasses.field(default_factory=list) + + def add(self, text: str) -> None: + """Add a text chunk to the index.""" + self.chunks.append(text) + self.embeddings.append(get_embedding(text, model=self.model)) + + @Tool.define + def retrieve(self, query: str, top_k: int = 3) -> list[str]: + """Return the top-k most similar chunks to the query.""" + if not self.embeddings: + return [] + query_emb = get_embedding(query, model=self.model) + distances = [float(((emb - query_emb) ** 2).sum()) for emb in self.embeddings] + indices = sorted(range(len(distances)), key=lambda i: distances[i]) + return [self.chunks[i] for i in indices[:top_k]] + + +# --------------------------------------------------------------------------- +# Chunking +# --------------------------------------------------------------------------- + + +def chunk_text(text: str, chunk_size: int = 200, overlap: int = 50) -> list[str]: + """Split text into overlapping word-level chunks.""" + words = text.split() + chunks = [] + start = 0 + while start < len(words): + end = start + chunk_size + chunks.append(" ".join(words[start:end])) + start += chunk_size - overlap + return chunks + + +# --------------------------------------------------------------------------- +# Sample documents +# --------------------------------------------------------------------------- + +DOCUMENTS = [ + """The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars + in Paris, France. It is named after the engineer Gustave Eiffel, whose + company designed and built the tower from 1887 to 1889 as the centerpiece + of the 1889 World's Fair. Although initially criticized by some of France's + leading artists and intellectuals, the tower has become a global icon of + France and one of the most recognizable structures in the world. The tower + is 330 metres tall, about the same height as an 81-storey building, and + is the tallest structure in Paris. It was the first structure in the world + to reach a height of 300 metres.""", + """The Great Wall of China is a series of fortifications that were built + across the historical northern borders of ancient Chinese states and + Imperial China as protection against various nomadic groups. The total + length of all sections ever built is more than 20,000 km. Several walls + were built from as early as the 7th century BC, with selective stretches + later joined together by Qin Shi Huang, the first emperor of China. The + best-preserved sections of the wall date from the Ming dynasty + (1368-1644). The wall's purpose was defensive, and it featured + watchtowers, troop barracks, and signaling capabilities.""", + """The Colosseum, also known as the Flavian Amphitheatre, is an oval + amphitheatre in the centre of the city of Rome, Italy. It is the largest + ancient amphitheatre ever built, and is still the largest standing + amphitheatre in the world, despite its age. Construction began under + the emperor Vespasian in AD 72 and was completed in AD 80 under his + successor and heir, Titus. The Colosseum could hold an estimated 50,000 + to 80,000 spectators at various points in its history, and was used for + gladiatorial contests and public spectacles including animal hunts, + executions, re-enactments of famous battles, and dramas.""", +] + +# --------------------------------------------------------------------------- +# Build the index (offline phase) +# --------------------------------------------------------------------------- + + +def build_index(documents: list[str], embedding_model: str) -> VectorIndex: + """Chunk and index a collection of documents.""" + index = VectorIndex(model=embedding_model) + for doc in documents: + for chunk in chunk_text(doc, chunk_size=60, overlap=15): + index.add(chunk) + print(f"Indexed {len(index.chunks)} chunks from {len(documents)} documents") + return index + + +# --------------------------------------------------------------------------- +# RAG query (online phase) +# --------------------------------------------------------------------------- + + +@Template.define +def answer_question(question: str) -> str: + """You are a helpful assistant. Answer the user's question using ONLY + information retrieved from the knowledge base via the retrieve tool. + + If the retrieved information doesn't contain the answer, say so. + Always cite which document your information comes from. + + Question: {question} + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Retrieval-augmented generation (RAG)") + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + parser.add_argument( + "--embedding-model", + type=str, + default="lm_studio/nomic-ai/nomic-embed-text-v1.5-GGUF", + help="Embedding model to use", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + # Offline: build the index + index = build_index(DOCUMENTS, embedding_model=args.embedding_model) + + # Create the retrieval tool bound to our index + retrieve: Tool = index.retrieve + + # Online: answer questions + questions = [ + "How tall is the Eiffel Tower?", + "When was the Great Wall of China built?", + "How many spectators could the Colosseum hold?", + ] + + provider = LiteLLMProvider(model=args.model) + + with handler(provider), handler(RetryLLMHandler(num_retries=3)): + for question in questions: + print(f"\nQ: {question}") + answer = answer_question(question) + print(f"A: {answer}") diff --git a/docs/source/llm_examples/research_agent.py b/docs/source/llm_examples/research_agent.py new file mode 100644 index 000000000..bc193db1b --- /dev/null +++ b/docs/source/llm_examples/research_agent.py @@ -0,0 +1,144 @@ +"""Research agent with web search. + +Demonstrates: +- ``@defop`` + ``ObjectInterpretation`` to define a pluggable web search effect +- ``@Template.define`` for LLM-implemented answer/refine/judge templates +- Handler composition: stacking a search provider alongside an LLM provider +- Iterative refinement loop: answer → judge → refine → judge → ... +""" + +import argparse +import os +import urllib.parse + +import requests + +from effectful.handlers.llm import Template, Tool +from effectful.handlers.llm.completions import ( + LiteLLMProvider, +) +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Search effect + handler +# --------------------------------------------------------------------------- + + +@Tool.define +def search_web(query: str) -> str: + """Search Wikipedia for a topic and return a summary. The query can be a topic name or a natural language question.""" + search_url = "https://en.wikipedia.org/w/api.php?" + urllib.parse.urlencode( + { + "action": "query", + "list": "search", + "srsearch": query, + "srlimit": 1, + "format": "json", + } + ) + search_data = requests.get( + search_url, headers={"User-Agent": "effectful-example/1.0"} + ).json() + results = search_data.get("query", {}).get("search", []) + if not results: + return f"No results found for: {query}" + title = results[0]["title"] + + summary_url = "https://en.wikipedia.org/w/api.php?" + urllib.parse.urlencode( + { + "action": "query", + "titles": title, + "prop": "extracts", + "exintro": True, + "explaintext": True, + "format": "json", + } + ) + summary_data = requests.get( + summary_url, headers={"User-Agent": "effectful-example/1.0"} + ).json() + page = next(iter(summary_data["query"]["pages"].values())) + extract = page.get("extract", "No summary available.") + url = f"https://en.wikipedia.org/wiki/{urllib.parse.quote(title.replace(' ', '_'))}" + + return f"# {title}\n\n{extract}\n\nSource: {url}" + + +# --------------------------------------------------------------------------- +# Templates (auto-capture `search_web` from lexical scope) +# --------------------------------------------------------------------------- + + +@Template.define +def answer_question(question: str) -> str: + """Acting as a research assistant that can search the web, + construct an answer to the user's question: {question}.""" + raise NotHandled + + +@Template.define +def refine_answer(question: str, answer: str) -> str: + """Acting as a research assistant that can search the web, + given the user's original question ({question}), + refine this previous answer: {answer}.""" + raise NotHandled + + +@Template.define +def is_question_answered(question: str, answer: str) -> bool: + """Acting as a research assistant, decide if the user's question + ({question}) is appropriately answered by: {answer}. + Respond only true or false.""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Agent loop +# --------------------------------------------------------------------------- + + +def research_agent(question: str, max_attempts: int = 3) -> str: + """Answer a question, iteratively refining until satisfactory.""" + answer = answer_question(question) + for _ in range(max_attempts): + if is_question_answered(question, answer): + break + answer = refine_answer(question, answer) + return answer + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="LLM-guided research agent with web search" + ) + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + parser.add_argument( + "--question", + type=str, + default="What is the meaning of life?", + help="The question to research", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + provider = LiteLLMProvider(model=args.model) + + with handler(provider): + result = research_agent(args.question) + print(result) diff --git a/docs/source/llm_examples/supervisor.py b/docs/source/llm_examples/supervisor.py new file mode 100644 index 000000000..1009d6a62 --- /dev/null +++ b/docs/source/llm_examples/supervisor.py @@ -0,0 +1,178 @@ +"""Supervisor quality-control wrapper. + +Demonstrates: +- Wrapping an agent's output with a quality-control check +- Using one ``Template`` to judge another's output +- Retry loop driven by LLM-based evaluation +""" + +import argparse +import dataclasses +import os +import urllib.parse + +import requests + +from effectful.handlers.llm import Agent, Template, Tool +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Search tool +# --------------------------------------------------------------------------- + + +@Tool.define +def search_web(query: str) -> str: + """Search Wikipedia for a topic and return a summary. The query can be a topic name or a natural language question.""" + search_url = "https://en.wikipedia.org/w/api.php?" + urllib.parse.urlencode( + { + "action": "query", + "list": "search", + "srsearch": query, + "srlimit": 1, + "format": "json", + } + ) + search_data = requests.get( + search_url, headers={"User-Agent": "effectful-example/1.0"} + ).json() + results = search_data.get("query", {}).get("search", []) + if not results: + return f"No results found for: {query}" + title = results[0]["title"] + + summary_url = "https://en.wikipedia.org/w/api.php?" + urllib.parse.urlencode( + { + "action": "query", + "titles": title, + "prop": "extracts", + "exintro": True, + "explaintext": True, + "format": "json", + } + ) + summary_data = requests.get( + summary_url, headers={"User-Agent": "effectful-example/1.0"} + ).json() + page = next(iter(summary_data["query"]["pages"].values())) + extract = page.get("extract", "No summary available.") + url = f"https://en.wikipedia.org/wiki/{urllib.parse.quote(title.replace(' ', '_'))}" + + return f"# {title}\n\n{extract}\n\nSource: {url}" + + +# --------------------------------------------------------------------------- +# Structured output for quality judgment +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass(frozen=True) +class QualityJudgment: + is_acceptable: bool + feedback: str + + +# --------------------------------------------------------------------------- +# Research agent +# --------------------------------------------------------------------------- + + +class Researcher(Agent): + """Agent that answers research questions using web search.""" + + @Template.define + def answer(self, question: str) -> str: + """You are a research assistant. Answer the following question using + the search tool to find accurate information. + + Question: {question} + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Supervisor (quality judge) +# --------------------------------------------------------------------------- + + +@Template.define +def judge_quality(question: str, answer: str) -> QualityJudgment: + """You are a strict quality reviewer. Evaluate whether this answer + adequately addresses the question with accurate, specific information. + + Question: {question} + Answer: {answer} + + An answer is acceptable if it contains specific facts (names, dates, + numbers) relevant to the question. Vague or generic answers should + be rejected. + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Supervised agent loop +# --------------------------------------------------------------------------- + + +def supervised_research(question: str, max_retries: int = 3) -> str: + """Answer a question with quality-control supervision. + + The researcher agent answers, the supervisor judges quality, + and if rejected the researcher tries again with feedback. + """ + researcher = Researcher() + + for attempt in range(max_retries + 1): + answer = researcher.answer(question) + judgment = judge_quality(question, answer) + + if judgment.is_acceptable: + print(f"[supervisor] Accepted on attempt {attempt + 1}") + return answer + + print(f"[supervisor] Rejected attempt {attempt + 1}: {judgment.feedback}") + + print("[supervisor] Returning best effort after max retries") + return answer + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Supervised research agent with quality control" + ) + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + parser.add_argument( + "--max-retries", + type=int, + default=3, + help="Maximum number of supervisor rejections before accepting", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + provider = LiteLLMProvider(model=args.model) + + with handler(provider), handler(RetryLLMHandler(num_retries=3)): + result = supervised_research( + "What year was the Eiffel Tower completed and how tall is it?", + max_retries=args.max_retries, + ) + print(f"\nFinal answer: {result}") diff --git a/docs/source/llm_examples/tao_agent.py b/docs/source/llm_examples/tao_agent.py new file mode 100644 index 000000000..8f9fbb0c1 --- /dev/null +++ b/docs/source/llm_examples/tao_agent.py @@ -0,0 +1,185 @@ +"""Think-Act-Observe chain-of-thought agent. + +Demonstrates: +- ``Agent`` mixin for persistent conversation history +- Structured output with Pydantic models (``AgentThought``) +- A think → act → observe reasoning loop +- Pattern matching for action dispatch +""" + +import argparse +import dataclasses +import enum +import os +import urllib.parse + +import requests + +from effectful.handlers.llm import Agent, Template, Tool +from effectful.handlers.llm.completions import ( + LiteLLMProvider, + RetryLLMHandler, +) +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Search tool +# --------------------------------------------------------------------------- + + +@Tool.define +def search_web(query: str) -> str: + """Search Wikipedia for a topic and return a summary. The query can be a topic name or a natural language question.""" + search_url = "https://en.wikipedia.org/w/api.php?" + urllib.parse.urlencode( + { + "action": "query", + "list": "search", + "srsearch": query, + "srlimit": 1, + "format": "json", + } + ) + search_data = requests.get( + search_url, headers={"User-Agent": "effectful-example/1.0"} + ).json() + results = search_data.get("query", {}).get("search", []) + if not results: + return f"No results found for: {query}" + title = results[0]["title"] + + summary_url = "https://en.wikipedia.org/w/api.php?" + urllib.parse.urlencode( + { + "action": "query", + "titles": title, + "prop": "extracts", + "exintro": True, + "explaintext": True, + "format": "json", + } + ) + summary_data = requests.get( + summary_url, headers={"User-Agent": "effectful-example/1.0"} + ).json() + page = next(iter(summary_data["query"]["pages"].values())) + extract = page.get("extract", "No summary available.") + url = f"https://en.wikipedia.org/wiki/{urllib.parse.quote(title.replace(' ', '_'))}" + + return f"# {title}\n\n{extract}\n\nSource: {url}" + + +# --------------------------------------------------------------------------- +# Structured output types +# --------------------------------------------------------------------------- + + +class AgentAction(enum.StrEnum): + search_the_web = "search_the_web" + calculate = "calculate" + answer = "answer" + + +@dataclasses.dataclass(frozen=True) +class AgentThought: + thinking: str + action: AgentAction + action_input: str + is_final: bool + + +# --------------------------------------------------------------------------- +# TAO Agent +# --------------------------------------------------------------------------- + + +class TAOAgent(Agent): + """Think-Act-Observe agent that reasons step by step.""" + + @Template.define + def think(self, query: str) -> AgentThought: + """You are an AI assistant solving a problem. Based on the user's query + ({query}) and prior conversation context, think about what action to + take next. + """ + raise NotHandled + + @Template.define + def observe(self, action: str, action_input: str, action_result: str) -> str: + """You are an observer. Provide a concise, objective observation of this result. + + Action: {action} + Action input: {action_input} + Action result: {action_result} + + + Do not make decisions, just describe what you see. + + """ + raise NotHandled + + def run(self, query: str, max_steps: int = 5) -> str: + result = "" + for _ in range(max_steps): + thought = self.think(query) + result = self._act(thought.action, thought.action_input) + self.observe(str(thought.action), thought.action_input, result) + if thought.is_final: + break + return result + + def _act(self, action: AgentAction, action_input: str) -> str: + match action: + case AgentAction.search_the_web: + return search_web(action_input) + case AgentAction.calculate: + try: + return action_input # eval(action_input)) # noqa: S307 + except Exception as e: + return str(e) + case AgentAction.answer: + return action_input + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="TAO chain-of-thought agent") + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + parser.add_argument( + "--max-steps", + type=int, + default=5, + help="Maximum number of steps before giving up", + ) + parser.add_argument( + "--num-retries", + type=int, + default=5, + help="Number of retries for malformed LLM output", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + provider = LiteLLMProvider(model=args.model) + + agent = TAOAgent() + + with handler(provider), handler(RetryLLMHandler(num_retries=args.num_retries)): + answer = agent.run( + "How many tennis balls would fill an Olympic swimming pool?", + max_steps=args.max_steps, + ) + print("Answer:", answer) diff --git a/docs/source/llm_examples/text2sql.py b/docs/source/llm_examples/text2sql.py new file mode 100644 index 000000000..5d511b5f8 --- /dev/null +++ b/docs/source/llm_examples/text2sql.py @@ -0,0 +1,170 @@ +"""Natural language to SQL with LLM-powered debug loop. + +Demonstrates: +- Generating SQL from natural language using ``@Template.define`` +- Executing SQL against a real SQLite database +- Feeding execution errors back to the LLM for iterative fixing +- ``@Tool.define`` to expose the database schema as a tool +""" + +import argparse +import os +import sqlite3 +import textwrap + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# In-memory database setup +# --------------------------------------------------------------------------- + + +def create_sample_db() -> sqlite3.Connection: + """Create a sample SQLite database with employee data.""" + conn = sqlite3.connect(":memory:") + conn.executescript( + textwrap.dedent("""\ + CREATE TABLE departments ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + budget REAL NOT NULL + ); + CREATE TABLE employees ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + department_id INTEGER REFERENCES departments(id), + salary REAL NOT NULL, + hire_date TEXT NOT NULL + ); + INSERT INTO departments VALUES (1, 'Engineering', 500000); + INSERT INTO departments VALUES (2, 'Marketing', 200000); + INSERT INTO departments VALUES (3, 'Sales', 300000); + INSERT INTO employees VALUES (1, 'Alice', 1, 120000, '2020-01-15'); + INSERT INTO employees VALUES (2, 'Bob', 1, 110000, '2021-03-22'); + INSERT INTO employees VALUES (3, 'Carol', 2, 95000, '2019-07-01'); + INSERT INTO employees VALUES (4, 'Dave', 3, 105000, '2022-11-10'); + INSERT INTO employees VALUES (5, 'Eve', 1, 130000, '2018-05-20'); + INSERT INTO employees VALUES (6, 'Frank', 3, 98000, '2023-01-05'); + """) + ) + return conn + + +def get_schema(conn: sqlite3.Connection) -> str: + """Extract the schema from a SQLite database.""" + cursor = conn.execute( + "SELECT sql FROM sqlite_master WHERE type='table' ORDER BY name" + ) + return "\n\n".join(row[0] for row in cursor if row[0]) + + +# --------------------------------------------------------------------------- +# Templates +# --------------------------------------------------------------------------- + + +@Template.define +def generate_sql(question: str, db_schema: str) -> str: + """You are a SQL expert. Given this database schema: + + {db_schema} + + Write a SQLite query that answers: {question} + + Return ONLY the SQL query, no explanation. + """ + raise NotHandled + + +@Template.define +def fix_sql(question: str, db_schema: str, bad_sql: str, error: str) -> str: + """You are a SQL expert. Your previous query had an error. + + Database schema: + {db_schema} + + Original question: {question} + Failed SQL: {bad_sql} + Error: {error} + + Write a corrected SQLite query. Return ONLY the SQL query. + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Text-to-SQL agent with debug loop +# --------------------------------------------------------------------------- + + +def text_to_sql( + conn: sqlite3.Connection, question: str, max_retries: int = 3 +) -> list[tuple]: + """Convert a natural language question to SQL and execute it. + + If the query fails, feed the error back to the LLM to fix it, + up to ``max_retries`` times. + """ + schema = get_schema(conn) + sql = generate_sql(question, schema) + + for attempt in range(max_retries + 1): + # Strip markdown fences if the LLM wraps the SQL + clean_sql = sql.strip().removeprefix("```sql").removesuffix("```").strip() + print(f" [attempt {attempt + 1}] {clean_sql}") + + try: + cursor = conn.execute(clean_sql) + return cursor.fetchall() + except Exception as e: + if attempt < max_retries: + print(f" [error] {e}") + sql = fix_sql(question, schema, clean_sql, str(e)) + else: + raise + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Natural language to SQL with LLM-powered debug loop" + ) + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + conn = create_sample_db() + provider = LiteLLMProvider(model=args.model) + + questions = [ + "What is the average salary by department?", + "Who is the highest paid employee?", + "How many employees were hired after 2021?", + ] + + with handler(provider), handler(RetryLLMHandler(num_retries=3)): + for question in questions: + print(f"\nQ: {question}") + try: + rows = text_to_sql(conn, question) + for row in rows: + print(f" => {row}") + except Exception as e: + print(f" FAILED: {e}") diff --git a/docs/source/llm_examples/thinking.py b/docs/source/llm_examples/thinking.py new file mode 100644 index 000000000..101d5cb2e --- /dev/null +++ b/docs/source/llm_examples/thinking.py @@ -0,0 +1,113 @@ +"""Chain-of-thought reasoning with structured self-loop. + +Demonstrates: +- Structured output with a ``ThoughtStep`` dataclass +- An ``Agent`` that loops until it decides it has a final answer +- The LLM sees its own prior reasoning via ``Agent.__history__`` +""" + +import argparse +import dataclasses +import os + +from effectful.handlers.llm import Agent, Template +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Structured output +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass(frozen=True) +class ThoughtStep: + reasoning: str + conclusion: str + is_final: bool + + +# --------------------------------------------------------------------------- +# Chain-of-thought agent +# --------------------------------------------------------------------------- + + +class Thinker(Agent): + """Agent that reasons step-by-step until it reaches a final answer.""" + + @Template.define + def think(self, problem: str) -> ThoughtStep: + """You are solving a problem step by step. + + Problem: {problem} + + Review the conversation history for any prior reasoning steps. + Continue from where you left off. Break the problem into small, + logical steps. Set is_final=true only when you have a complete, + well-supported answer. + """ + raise NotHandled + + def solve(self, problem: str, max_steps: int = 10) -> str: + """Solve a problem by iterative chain-of-thought reasoning.""" + for i in range(max_steps): + step = self.think(problem) + print(f" [step {i + 1}] {step.reasoning}") + if step.is_final: + return step.conclusion + + return step.conclusion + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Chain-of-thought reasoning agent") + parser.add_argument( + "--model", + type=str, + default="lm_studio/zai-org/glm-4.7-flash", + help="LLM model to use", + ) + parser.add_argument( + "--max-steps", + type=int, + default=10, + help="Maximum reasoning steps before stopping", + ) + parser.add_argument( + "--problem", + type=str, + default=( + "A farmer has 17 sheep. All but 9 run away. " + "Then he buys 5 more. How many sheep does he have now?" + ), + help="The problem to solve", + ) + args = parser.parse_args() + + if args.model.startswith("lm_studio/"): + assert os.environ.get("LM_STUDIO_API_BASE") + elif args.model.startswith("gpt-"): + assert os.environ.get("OPENAI_API_KEY") + elif args.model.startswith("claude-"): + assert os.environ.get("ANTHROPIC_API_KEY") + + provider = LiteLLMProvider(model=args.model) + + problems = [ + args.problem, + ( + "If you have a 3-gallon jug and a 5-gallon jug, " + "how do you measure exactly 4 gallons of water?" + ), + ] + + with handler(provider), handler(RetryLLMHandler(num_retries=3)): + for problem in problems: + thinker = Thinker() + print(f"\nProblem: {problem}") + answer = thinker.solve(problem, max_steps=args.max_steps) + print(f"Answer: {answer}") From 310e3db636f9409d0b149c375558ef4d20efbd22 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 25 Apr 2026 12:44:56 -0400 Subject: [PATCH 04/11] no changes to library code --- effectful/handlers/llm/completions.py | 472 +++++++------- effectful/handlers/llm/encoding.py | 903 ++++++++++++++------------ effectful/handlers/llm/evaluation.py | 3 +- effectful/handlers/llm/template.py | 223 ++++--- effectful/handlers/numpyro.py | 8 + effectful/internals/unification.py | 80 +++ effectful/ops/semantics.py | 3 +- effectful/ops/types.py | 32 +- 8 files changed, 944 insertions(+), 780 deletions(-) diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index f2fd85f22..2e169d932 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -1,8 +1,10 @@ +import abc import collections import collections.abc import dataclasses import functools import inspect +import json import string import textwrap import traceback @@ -11,22 +13,32 @@ import litellm import pydantic +import tenacity from litellm import ( ChatCompletionFunctionMessage, ChatCompletionMessageToolCall, ChatCompletionTextObject, ChatCompletionToolMessage, - ChatCompletionToolParam, OpenAIChatCompletionAssistantMessage, OpenAIChatCompletionSystemMessage, OpenAIChatCompletionUserMessage, OpenAIMessageContentListBlock, ) -from effectful.handlers.llm.encoding import Encodable -from effectful.handlers.llm.template import Template, Tool +from effectful.handlers.llm.encoding import ( + DecodedToolCall, + Encodable, + to_content_blocks, +) +from effectful.handlers.llm.template import ( + Agent, + Template, + Tool, + _is_recursive_signature, +) +from effectful.internals.unification import nested_type from effectful.ops.semantics import fwd, handler -from effectful.ops.syntax import ObjectInterpretation, defop, implements +from effectful.ops.syntax import ObjectInterpretation, implements from effectful.ops.types import Operation @@ -52,14 +64,29 @@ class UserMessage(OpenAIChatCompletionUserMessage): Message = AssistantMessage | ToolMessage | FunctionMessage | SystemMessage | UserMessage +DEFAULT_SYSTEM_PROMPT = ( + "You are a helpful assistant, you need to follow user's instruction" +) + + +class _NoActiveHistoryException(Exception): + """Raised when there is no active message history to append to.""" -@defop -def get_message_sequence() -> collections.OrderedDict[str, Message]: - return collections.OrderedDict() +@Operation.define +def _get_history() -> collections.OrderedDict[str, Message]: + raise _NoActiveHistoryException( + "No active message history. This operation should only be used within a handler that provides a message history." + ) -def append_message(message: Message): - get_message_sequence()[message["id"]] = message + +def append_message(message: Message, last: bool = True) -> None: + try: + _get_history()[message["id"]] = message + if not last: + _get_history().move_to_end(message["id"], last=False) + except _NoActiveHistoryException: + pass def _make_message(content: dict) -> Message: @@ -68,20 +95,27 @@ def _make_message(content: dict) -> Message: return message -type ToolCallID = str +class DecodingError[E: Exception](abc.ABC, Exception): + """Base class for decoding errors that can occur during LLM response processing.""" + + original_error: E + + @abc.abstractmethod + def to_feedback_message(self, include_traceback: bool) -> Message: + """Convert the decoding error into a feedback message to be sent back to the LLM.""" + raise NotImplementedError @dataclasses.dataclass -class ToolCallDecodingError(Exception): +class ToolCallDecodingError[E: Exception](DecodingError[E]): """Error raised when decoding a tool call fails.""" - tool_name: str - tool_call_id: str - original_error: Exception + original_error: E raw_message: Message + raw_tool_call: ChatCompletionMessageToolCall def __str__(self) -> str: - return f"Error decoding tool call '{self.tool_name}': {self.original_error}. Please provide a valid response and try again." + return f"Error decoding tool call '{self.raw_tool_call.function.name}': {self.original_error}. Please provide a valid response and try again." def to_feedback_message(self, include_traceback: bool) -> Message: error_message = f"{self}" @@ -91,17 +125,17 @@ def to_feedback_message(self, include_traceback: bool) -> Message: return _make_message( { "role": "tool", - "tool_call_id": self.tool_call_id, + "tool_call_id": self.raw_tool_call.id, "content": error_message, }, ) @dataclasses.dataclass -class ResultDecodingError(Exception): +class ResultDecodingError[E: Exception](DecodingError[E]): """Error raised when decoding the LLM response result fails.""" - original_error: Exception + original_error: E raw_message: Message def __str__(self) -> str: @@ -118,15 +152,14 @@ def to_feedback_message(self, include_traceback: bool) -> Message: @dataclasses.dataclass -class ToolCallExecutionError(Exception): +class ToolCallExecutionError[E: Exception, T](DecodingError[E]): """Error raised when a tool execution fails at runtime.""" - tool_name: str - tool_call_id: str - original_error: BaseException + original_error: E + raw_tool_call: DecodedToolCall[T] def __str__(self) -> str: - return f"Tool execution failed: Error executing tool '{self.tool_name}': {self.original_error}" + return f"Tool execution failed: Error executing tool '{self.raw_tool_call.name}': {self.original_error}" def to_feedback_message(self, include_traceback: bool) -> Message: error_message = f"{self}" @@ -136,97 +169,44 @@ def to_feedback_message(self, include_traceback: bool) -> Message: return _make_message( { "role": "tool", - "tool_call_id": self.tool_call_id, + "tool_call_id": self.raw_tool_call.id, "content": error_message, }, ) -class DecodedToolCall[T](typing.NamedTuple): - tool: Tool[..., T] - bound_args: inspect.BoundArguments - id: ToolCallID - - type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None] -@functools.cache -def _param_model(tool: Tool) -> type[pydantic.BaseModel]: - sig = inspect.signature(tool) - return pydantic.create_model( - "Params", - __config__={"extra": "forbid"}, - **{ - name: Encodable.define(param.annotation).enc - for name, param in sig.parameters.items() - }, # type: ignore - ) - - -@functools.cache -def _function_model(tool: Tool) -> ChatCompletionToolParam: - response_format = litellm.utils.type_to_response_format_param(_param_model(tool)) - assert response_format is not None - assert tool.__default__.__doc__ is not None - return { - "type": "function", - "function": { - "name": tool.__name__, - "description": textwrap.dedent(tool.__default__.__doc__), - "parameters": response_format["json_schema"]["schema"], - "strict": True, - }, - } - - -def decode_tool_call( - tool_call: ChatCompletionMessageToolCall, - tools: collections.abc.Mapping[str, Tool], - raw_message: Message, -) -> DecodedToolCall: - """Decode a tool call from the LLM response into a DecodedToolCall. - - Args: - tool_call: The tool call to decode. - tools: Mapping of tool names to Tool objects. - raw_message: Optional raw assistant message for error context. - - Raises: - ToolCallDecodingError: If the tool call cannot be decoded. - """ - tool_name = tool_call.function.name - assert tool_name is not None - - try: - tool = tools[tool_name] - except KeyError as e: - raise ToolCallDecodingError( - tool_name, tool_call.id, e, raw_message=raw_message - ) from e - - json_str = tool_call.function.arguments - sig = inspect.signature(tool) - - try: - # build dict of raw encodable types U - raw_args = _param_model(tool).model_validate_json(json_str) - - # use encoders to decode Us to python types T - bound_sig: inspect.BoundArguments = sig.bind( - **{ - param_name: Encodable.define( - sig.parameters[param_name].annotation, {} - ).decode(getattr(raw_args, param_name)) - for param_name in raw_args.model_fields_set - } - ) - except (pydantic.ValidationError, TypeError, ValueError, SyntaxError) as e: - raise ToolCallDecodingError( - tool_name, tool_call.id, e, raw_message=raw_message - ) from e - - return DecodedToolCall(tool, bound_sig, tool_call.id) +def _collect_tools( + env: collections.abc.Mapping[str, typing.Any], +) -> collections.abc.Mapping[str, Tool]: + """Operations and Templates available as tools. Auto-capture from lexical context.""" + result = {} + + for name, obj in env.items(): + # Collect tools directly in context + if isinstance(obj, Tool | Template): + result[name] = obj + + # Collect tools as methods on Agent instances in context + elif isinstance(obj, Agent): + for cls in type(obj).__mro__: + for attr_name in vars(cls): + if isinstance(getattr(obj, attr_name), Tool): + result[f"{name}__{attr_name}"] = getattr(obj, attr_name) + + # The same Tool can appear under multiple names when it is both + # visible in the enclosing scope *and* discovered via an Agent + # instance's MRO. Since Tools are hashable Operations and + # instance-method Tools are cached per instance, we keep only + # the last name for each unique tool object. + tool2name = {tool: name for name, tool in sorted(result.items())} + for name, tool in tuple(result.items()): + if tool2name[tool] != name: + del result[name] + + return result @Operation.define @@ -241,10 +221,14 @@ def completion(*args, **kwargs) -> typing.Any: return litellm.completion(*args, **kwargs) +class _BoxedResponse[T](pydantic.BaseModel): + value: T + + @Operation.define -def call_assistant[T, U]( - tools: collections.abc.Mapping[str, Tool], - response_format: Encodable[T, U], +def call_assistant[T]( + env: collections.abc.Mapping[str, typing.Any], + response_type: type[T], model: str, **kwargs, ) -> MessageResult[T]: @@ -259,20 +243,28 @@ def call_assistant[T, U]( ResultDecodingError: If the result cannot be decoded. The error includes the raw assistant message for retry handling. """ - tool_specs = {k: _function_model(t) for k, t in tools.items()} - response_model = ( - response_format.enc - if issubclass(response_format.enc, pydantic.BaseModel) - else pydantic.create_model( - "Response", value=response_format.enc, __config__={"extra": "forbid"} - ) + tools = _collect_tools(env) + tool_specs = { + k: typing.cast( + pydantic.TypeAdapter[typing.Any], + pydantic.TypeAdapter(Encodable[type(t)]), # type: ignore[misc] + ).dump_python(t, mode="json", context={k: t}) + for k, t in tools.items() + } + + # The OpenAI API requires a wrapper object for non-object structured output types, + # so we create one on the fly here. Using a Pydantic model offloads JSON schema + # generation and validation logic to litellm, and offers better error messages. + response_format: type[_BoxedResponse[T]] = pydantic.create_model( + "BoxedResponse", + value=Encodable[response_type], # type: ignore[valid-type] + __base__=_BoxedResponse, ) - messages = list(get_message_sequence().values()) response: litellm.types.utils.ModelResponse = completion( model, - messages=list(messages), - response_format=response_model if response_format.enc is not str else None, + messages=list(_get_history().values()), + response_format=None if response_type is str else response_format, tools=list(tool_specs.values()), **kwargs, ) @@ -286,35 +278,35 @@ def call_assistant[T, U]( append_message(raw_message) tool_calls: list[DecodedToolCall] = [] - raw_tool_calls = message.get("tool_calls") or [] - for raw_tool_call in raw_tool_calls: - validated_tool_call = ChatCompletionMessageToolCall.model_validate( - raw_tool_call - ) - decoded_tool_call = decode_tool_call(validated_tool_call, tools, raw_message) - tool_calls.append(decoded_tool_call) + encoding: pydantic.TypeAdapter[DecodedToolCall] = pydantic.TypeAdapter( + Encodable[DecodedToolCall] + ) + for raw_tool_call in message.get("tool_calls") or []: + try: + tool_calls += [encoding.validate_python(raw_tool_call, context=tools)] + except Exception as e: + raise ToolCallDecodingError( + raw_tool_call=raw_tool_call, + original_error=e, + raw_message=raw_message, + ) from e result = None - if not tool_calls and response_format.enc is not str: + if not tool_calls: # return response serialized_result = message.get("content") or message.get("reasoning_content") assert isinstance(serialized_result, str), ( "final response from the model should be a string" ) - try: - raw_result = response_model.model_validate_json(serialized_result) - result = response_format.decode( - raw_result.value - if not issubclass(response_format.enc, pydantic.BaseModel) - else raw_result - ) # type: ignore - except (pydantic.ValidationError, TypeError, ValueError, SyntaxError) as e: - raise ResultDecodingError(e, raw_message=raw_message) from e - elif not tool_calls and response_format.enc is str: - # if expecting a string result, return the raw content as the result - content = message.get("content") or message.get("reasoning_content") - assert isinstance(content, str), "Expected content to be a string" - result = content + if response_type is str: + result = typing.cast(T, serialized_result) + else: + try: + result = response_format.model_validate( + json.loads(serialized_result), context=env + ).value + except Exception as e: + raise ResultDecodingError(e, raw_message=raw_message) from e return (raw_message, tool_calls, result) @@ -327,16 +319,19 @@ def call_tool(tool_call: DecodedToolCall) -> Message: """ # call tool with python types - # call_tool invariant: tool is called in a context with a fresh message sequence - message_sequence: collections.OrderedDict[str, Message] = collections.OrderedDict() - with handler({get_message_sequence: lambda: message_sequence}): + try: result = tool_call.tool( *tool_call.bound_args.args, **tool_call.bound_args.kwargs ) + except Exception as e: + raise ToolCallExecutionError(raw_tool_call=tool_call, original_error=e) from e - # serialize back to U using encoder for return type - return_type = Encodable.define(type(result)) - encoded_result = return_type.serialize(return_type.encode(result)) + return_type: pydantic.TypeAdapter[typing.Any] = pydantic.TypeAdapter( + Encodable[nested_type(result).value] # type: ignore[misc] + ) + encoded_result = to_content_blocks( + return_type.dump_python(result, mode="json", context={}) + ) message = _make_message( dict(role="tool", content=encoded_result, tool_call_id=tool_call.id), ) @@ -372,11 +367,11 @@ def flush_text() -> None: continue obj, _ = formatter.get_field(field_name, (), env) - encoder = Encodable.define(type(obj)) - encoded_obj: typing.Sequence[OpenAIMessageContentListBlock] = encoder.serialize( - encoder.encode(obj) + encoder: pydantic.TypeAdapter[typing.Any] = pydantic.TypeAdapter( + Encodable[nested_type(obj).value] # type: ignore[misc] ) - for part in encoded_obj: + encoded_obj = encoder.dump_python(obj, mode="json", context=env) + for part in to_content_blocks(encoded_obj): if part["type"] == "text": text = ( formatter.convert_field(part["text"], conversion) @@ -398,38 +393,12 @@ def flush_text() -> None: @Operation.define -def call_system(template: Template) -> collections.abc.Sequence[Message]: +def call_system(template: Template) -> Message: """Get system instruction message(s) to prepend to all LLM prompts.""" - - assert inspect.getdoc(type(template)) is not None - - system_prompt = inspect.cleandoc(f""" - You are responsible for implementing the `Template` '{template.__name__}' defined in the module source code below. - - First, as background, here is the class-level documentation for the `Template` class:: - - {inspect.getdoc(type(template))} - """) - - try: - system_prompt += inspect.cleandoc(f""" - Here is the source code of the module defining the `Template` instance '{template.__name__}':: - - {inspect.getsource(inspect.getmodule(template))} - """) - except (TypeError, OSError): - system_prompt += inspect.cleandoc(f""" - The source code for the module defining '{template.__name__}' is not available. - Instead, here are the signature and docstring of '{template.__name__}':: - - {template.__name__} :: {template.__signature__.format()} - - {inspect.cleandoc(template.__prompt_template__)} - """) - - msg = _make_message(dict(role="system", content=system_prompt)) - append_message(msg) - return (msg,) + system_prompt = template.__system_prompt__ or DEFAULT_SYSTEM_PROMPT + message = _make_message(dict(role="system", content=system_prompt)) + append_message(message, last=False) + return message class RetryLLMHandler(ObjectInterpretation): @@ -444,69 +413,68 @@ class RetryLLMHandler(ObjectInterpretation): captured and returned as tool response messages. Args: - num_retries: The maximum number of retries (default: 3). include_traceback: If True, include full traceback in error feedback - for better debugging context (default: False). + for better debugging context (default: True). catch_tool_errors: Exception type(s) to catch during tool execution. Can be a single exception class or a tuple of exception classes. Defaults to Exception (catches all exceptions). + stop: tenacity stop condition for retrying `call_assistant`. Defaults to + `tenacity.stop_after_attempt(4)`, which stops after 4 attempts. + **kwargs: Additional keyword arguments forwarded to `tenacity.Retrying`. """ + call_assistant_retryer: tenacity.Retrying + + _user_before_sleep: collections.abc.Callable[[tenacity.RetryCallState], None] | None + def __init__( self, - num_retries: int = 3, - include_traceback: bool = False, + include_traceback: bool = True, catch_tool_errors: type[BaseException] | tuple[type[BaseException], ...] = Exception, + stop: tenacity.stop.stop_base = tenacity.stop_after_attempt(4), + **kwargs, ): - self.num_retries = num_retries self.include_traceback = include_traceback self.catch_tool_errors = catch_tool_errors + assert "retry" not in kwargs, "Cannot override retry logic of RetryLLMHandler" + assert "reraise" not in kwargs, ( + "Cannot override reraise logic of RetryLLMHandler" + ) + self._user_before_sleep = kwargs.pop("before_sleep", None) + self.call_assistant_retryer = tenacity.Retrying( + retry=tenacity.retry_if_exception_type( + (ToolCallDecodingError, ResultDecodingError) + ), + reraise=True, + before_sleep=self._before_sleep, + stop=stop, + **kwargs, + ) + + def _before_sleep(self, retry_state: tenacity.RetryCallState) -> None: + e = retry_state.outcome.exception() # type: ignore + assert isinstance(e, (ToolCallDecodingError, ResultDecodingError)) + append_message(e.raw_message) + append_message(e.to_feedback_message(self.include_traceback)) + if self._user_before_sleep is not None: + self._user_before_sleep(retry_state) @implements(call_assistant) - def _call_assistant[T, U]( + def _call_assistant[T]( self, - tools: collections.abc.Mapping[str, Tool], - response_format: Encodable[T, U], + env: collections.abc.Mapping[str, typing.Any], + response_type: type[T], model: str, **kwargs, ) -> MessageResult[T]: - message_sequence = get_message_sequence().copy() - last_attempt = self.num_retries - - for attempt in range(self.num_retries + 1): - try: - # call assistant, use saved message_sequence - with handler({get_message_sequence: lambda: message_sequence}): - message, tool_calls, result = fwd( - tools, response_format, model, **kwargs - ) - - # Success! The returned message is the final successful response. - # Malformed messages from retries are only in local message_sequence copy, - # not in the enclosing message sequence. - append_message(message) - return (message, tool_calls, result) - - except (ToolCallDecodingError, ResultDecodingError) as e: - # On last attempt, re-raise to preserve full traceback - if attempt == last_attempt: - raise - - # Add the malformed assistant message - message_sequence[e.raw_message["id"]] = e.raw_message + _message_sequence = _get_history().copy() - # Add error feedback as a tool response - error_feedback: Message = e.to_feedback_message(self.include_traceback) - message_sequence[error_feedback["id"]] = error_feedback + with handler({_get_history: lambda: _message_sequence}): + message, tool_calls, result = self.call_assistant_retryer(fwd) - # Should never reach here - either we return on success or raise on final failure - raise AssertionError("Unreachable: retry loop exited without return or raise") - - @implements(completion) - def _completion(self, *args, **kwargs) -> typing.Any: - """Inject num_retries for litellm's built-in network error handling.""" - return fwd(*args, **({"num_retries": self.num_retries} | kwargs)) + append_message(message) + return (message, tool_calls, result) @implements(call_tool) def _call_tool(self, tool_call: DecodedToolCall) -> Message: @@ -518,11 +486,13 @@ def _call_tool(self, tool_call: DecodedToolCall) -> Message: """ try: return fwd(tool_call) - except self.catch_tool_errors as e: - error = ToolCallExecutionError(tool_call.tool.__name__, tool_call.id, e) - message = error.to_feedback_message(self.include_traceback) - append_message(message) - return message + except ToolCallExecutionError as e: + if isinstance(e.original_error, self.catch_tool_errors): + message = e.to_feedback_message(self.include_traceback) + append_message(message) + return message + else: + raise class LiteLLMProvider(ObjectInterpretation): @@ -540,19 +510,25 @@ def __init__(self, model="gpt-4o", **config): def _call[**P, T]( self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs ) -> T: - message_sequence: collections.OrderedDict[str, Message] = get_message_sequence() - with handler({get_message_sequence: lambda: message_sequence}): - # encode arguments - bound_args = inspect.signature(template).bind(*args, **kwargs) - bound_args.apply_defaults() - env = template.__context__.new_child(bound_args.arguments) - - # Create response_model with env so tools passed as arguments are available - response_model = Encodable.define( - template.__signature__.return_annotation, env - ) - - call_system(template) + # encode arguments + bound_args = inspect.signature(template).bind(*args, **kwargs) + bound_args.apply_defaults() + env = template.__context__.new_child(bound_args.arguments) + + if not _is_recursive_signature(template.__signature__): + env = env.new_child({k: None for k, v in env.items() if v is template}) + + history: collections.OrderedDict[str, Message] = getattr( + template, "__history__", collections.OrderedDict() + ) # type: ignore + history_copy = history.copy() + + with handler({_get_history: lambda: history_copy}): + if ( + not _get_history() + or next(iter(_get_history().values()))["role"] != "system" + ): + call_system(template) message: Message = call_user(template.__prompt_template__, env) @@ -561,12 +537,14 @@ def _call[**P, T]( result: T | None = None while message["role"] != "assistant" or tool_calls: message, tool_calls, result = call_assistant( - template.tools, response_model, **self.config + env, template.__signature__.return_annotation, **self.config ) for tool_call in tool_calls: message = call_tool(tool_call) - assert result is not None, ( - "call_assistant did not produce a result nor tool_calls" - ) - return result + try: + _get_history() + except _NoActiveHistoryException: + history.clear() + history.update(history_copy) + return typing.cast(T, result) diff --git a/effectful/handlers/llm/encoding.py b/effectful/handlers/llm/encoding.py index 66edf6add..cfbb08aec 100644 --- a/effectful/handlers/llm/encoding.py +++ b/effectful/handlers/llm/encoding.py @@ -1,286 +1,347 @@ import ast import base64 +import dataclasses +import functools import inspect import io +import json import textwrap import types import typing -from abc import ABC, abstractmethod -from collections.abc import Callable, Mapping, MutableMapping, Sequence -from dataclasses import dataclass -from types import CodeType +from collections.abc import ( + Callable, + Mapping, + MutableMapping, +) from typing import Any +import litellm import pydantic from litellm import ( - ChatCompletionImageUrlObject, + ChatCompletionImageObject, + ChatCompletionMessageToolCall, + ChatCompletionTextObject, + ChatCompletionToolParam, OpenAIMessageContentListBlock, ) +from openai.lib._pydantic import _ensure_strict_json_schema +from openai.types.chat import ( + ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall, +) from PIL import Image import effectful.handlers.llm.evaluation as evaluation -from effectful.ops.semantics import _simple_type -from effectful.ops.syntax import _CustomSingleDispatchCallable +from effectful.handlers.llm.template import Tool +from effectful.internals.unification import GenericAlias, TypeEvaluator, nested_type from effectful.ops.types import Operation, Term +type ToolCallID = str + +CONTENT_BLOCK_TYPES: frozenset[str] = frozenset( + literal + for member in typing.get_args(OpenAIMessageContentListBlock) + for literal in typing.get_args(typing.get_type_hints(member).get("type", str)) + if isinstance(literal, str) +) -def _pil_image_to_base64_data(pil_image: Image.Image) -> str: - buf = io.BytesIO() - pil_image.save(buf, format="PNG") - return base64.b64encode(buf.getvalue()).decode("utf-8") +@pydantic.validate_call(validate_return=True) +def to_content_blocks(value: typing.Any) -> list[OpenAIMessageContentListBlock]: + """Convert an encoded JSON-compatible value into a flat list of content blocks. -def _pil_image_to_base64_data_uri(pil_image: Image.Image) -> str: - return f"data:image/png;base64,{_pil_image_to_base64_data(pil_image)}" + Walks the value tree, extracting content-block-shaped dicts (identified by + their ``type`` discriminator) and emitting JSON syntax as text around them. + Top-level strings are emitted bare (for natural template rendering). + Inside JSON structures, separators match ``json.dumps`` defaults so that + the linearization law holds for non-string encoded values: + ``linearize(to_content_blocks(v)) == json.dumps(v)``. + """ + if isinstance(value, str): + return [ChatCompletionTextObject(type="text", text=value)] + + buf: list[str] = [] + blocks: list[OpenAIMessageContentListBlock] = [] + + def flush() -> None: + if buf: + blocks.append(ChatCompletionTextObject(type="text", text="".join(buf))) + buf.clear() + + def walk(v: typing.Any) -> None: + if isinstance(v, dict) and v.get("type") in CONTENT_BLOCK_TYPES: + flush() + blocks.append(typing.cast(OpenAIMessageContentListBlock, v)) + elif isinstance(v, dict): + buf.append("{") + for i, (k, val) in enumerate(v.items()): + if i: + buf.append(", ") + buf.append(json.dumps(k) + ": ") + walk(val) + buf.append("}") + elif isinstance(v, list): + buf.append("[") + for i, item in enumerate(v): + if i: + buf.append(", ") + walk(item) + buf.append("]") + else: + buf.append(json.dumps(v)) + + walk(value) + flush() + return blocks -class Encodable[T, U](ABC): - base: type[T] - enc: type[U] - ctx: Mapping[str, Any] - @abstractmethod - def encode(self, value: T) -> U: - raise NotImplementedError +@dataclasses.dataclass(frozen=True, eq=True) +class DecodedToolCall[T]: + """ + Structured representation of a tool call decoded from an LLM response. + """ - @abstractmethod - def decode(self, encoded_value: U) -> T: - raise NotImplementedError + tool: Tool[..., T] + bound_args: inspect.BoundArguments + id: ToolCallID + name: str - @abstractmethod - def serialize(self, encoded_value: U) -> Sequence[OpenAIMessageContentListBlock]: - raise NotImplementedError - # serialize and deserialize have different types reflecting the LLM api chat.completions(list[content]) -> str - @abstractmethod - def deserialize(self, serialized_value: str) -> U: - raise NotImplementedError +if typing.TYPE_CHECKING: + type Encodable[T] = typing.Annotated[T, "encoded"] +else: + + class Encodable: + def __class_getitem__(cls, item): + return TypeToPydanticType().evaluate(item) + + +class TypeToPydanticType(TypeEvaluator): + """Substitute custom types with their Pydantic Annotated equivalents. + + Recursively walks a type annotation tree, replacing leaf types that have + registered Pydantic annotations (e.g., Image.Image -> PydanticImage) and + reconstructing the full generic type. + + The result can be passed to pydantic.TypeAdapter() for automatic + validation and serialization of nested structures. + """ - @typing.final @staticmethod - @_CustomSingleDispatchCallable - def define( - __dispatch: Callable[ - [type[T]], Callable[[type[T], Mapping[str, Any] | None], "Encodable[T, U]"] - ], - t: type[T], - ctx: Mapping[str, Any] | None = None, - ) -> "Encodable[T, U]": - dispatch_ty = _simple_type(t) - return __dispatch(dispatch_ty)(t, ctx) + @functools.singledispatch + def _registry(ty: type): + raise RuntimeError("should not be here!") + + @classmethod + def register(cls, *args, **kwargs): + return cls._registry.register(*args, **kwargs) + + def evaluate(self, ty): + app = super().evaluate(ty) + origin = typing.get_origin(app) + # Only dispatch on regular types. Special forms (Literal, Annotated, + # Union) have non-type origins that singledispatch can't resolve; pass + # them through for Pydantic to handle natively. + if isinstance(app, type | GenericAlias) and ( + origin is None or isinstance(origin, type) + ): + return self._registry.dispatch(origin or app)(app) + else: + return app -@dataclass -class BaseEncodable[T](Encodable[T, T]): - base: type[T] - enc: type[T] - ctx: Mapping[str, Any] - adapter: pydantic.TypeAdapter[T] +@TypeToPydanticType.register(str) +def _pydantic_type_str[T](ty: type[T]) -> type[T]: + return ty - def encode(self, value: T) -> T: - return typing.cast(T, self.adapter.validate_python(value)) - def decode(self, encoded_value: T) -> T: - return typing.cast(T, self.adapter.validate_python(encoded_value)) +@TypeToPydanticType.register(object) +def _pydantic_type_base(ty: type) -> Any: + return ty - def serialize(self, encoded_value: T) -> Sequence[OpenAIMessageContentListBlock]: - json_str = self.adapter.dump_json(encoded_value).decode("utf-8") - return [{"type": "text", "text": json_str}] - def deserialize(self, serialized_value: str) -> T: - # Parse JSON string into the encoded value, validated as `ty`. - return typing.cast(T, self.adapter.validate_json(serialized_value)) +class _ComplexModel(typing.TypedDict): + real: float + imag: float -@dataclass -class StrEncodable(Encodable[str, str]): - base: type[str] - enc: type[str] - ctx: Mapping[str, Any] +@pydantic.validate_call(validate_return=True) +def _validate_complex(value: _ComplexModel) -> complex: + return complex(value["real"], value["imag"]) - def encode(self, value: str) -> str: - return value - def decode(self, encoded_value: str) -> str: - return encoded_value +@pydantic.validate_call(validate_return=True) +def _serialize_complex(value: complex) -> _ComplexModel: + return {"real": value.real, "imag": value.imag} - def serialize(self, encoded_value: str) -> Sequence[OpenAIMessageContentListBlock]: - # Serialize strings without JSON encoding (no extra quotes) - return [{"type": "text", "text": encoded_value}] - def deserialize(self, serialized_value: str) -> str: - return serialized_value +@TypeToPydanticType.register(complex) +def _pydantic_type_complex(ty): + """Encode ``complex`` as ``{"real": float, "imag": float}``.""" + adapted_schema = pydantic.TypeAdapter(_ComplexModel).json_schema() -@dataclass -class PydanticBaseModelEncodable[T: pydantic.BaseModel](Encodable[T, T]): - base: type[T] - enc: type[T] - ctx: Mapping[str, Any] + return typing.Annotated[ + ty, + pydantic.PlainValidator(_validate_complex), + pydantic.PlainSerializer(_serialize_complex), + pydantic.WithJsonSchema({**adapted_schema, "additionalProperties": False}), + ] - def decode(self, encoded_value: T) -> T: - return encoded_value - def encode(self, value: T) -> T: - return value +def _inline_refs(schema: dict) -> dict: + """Inline ``$ref`` pointers so ``WithJsonSchema`` never emits orphan refs. - def serialize(self, encoded_value: T) -> Sequence[OpenAIMessageContentListBlock]: - return [{"type": "text", "text": encoded_value.model_dump_json()}] + Workaround for https://github.com/pydantic/pydantic/issues/12145 — + Pydantic's ``GenerateJsonSchema`` does not merge user-provided ``$defs`` + into its internal ref map, so any ``$ref`` in a ``WithJsonSchema`` value + causes a ``KeyError`` when the annotated type is composed into a model. + """ + defs = schema.get("$defs", {}) - def deserialize(self, serialized_value: str) -> T: - return typing.cast(T, self.base.model_validate_json(serialized_value)) + def _resolve(obj): + if isinstance(obj, dict): + if "$ref" in obj: + ref_name = obj["$ref"].split("/")[-1] + if ref_name in defs: + return _resolve(defs[ref_name]) + return {k: _resolve(v) for k, v in obj.items() if k != "$defs"} + if isinstance(obj, list): + return [_resolve(item) for item in obj] + return obj + return _resolve(schema) -@dataclass -class ImageEncodable(Encodable[Image.Image, ChatCompletionImageUrlObject]): - base: type[Image.Image] - enc: type[ChatCompletionImageUrlObject] - ctx: Mapping[str, Any] - def encode(self, value: Image.Image) -> ChatCompletionImageUrlObject: - return { - "detail": "auto", - "url": _pil_image_to_base64_data_uri(value), - } +@TypeToPydanticType.register(tuple) +def _pydantic_type_tuple(ty): + """Convert finitary tuples to object-based schemas (``properties/required``). - def decode(self, encoded_value: ChatCompletionImageUrlObject) -> Image.Image: - image_url = encoded_value["url"] - if not image_url.startswith("data:image/"): - raise RuntimeError( - f"expected base64 encoded image as data uri, received {image_url}" - ) - data = image_url.split(",")[1] - return Image.open(fp=io.BytesIO(base64.b64decode(data))) - - def serialize( - self, encoded_value: ChatCompletionImageUrlObject - ) -> Sequence[OpenAIMessageContentListBlock]: - return [{"type": "image_url", "image_url": encoded_value}] - - def deserialize(self, serialized_value: str) -> ChatCompletionImageUrlObject: - # Images are serialized as image_url blocks, not text - # This shouldn't be called in normal flow, but provide a fallback - raise NotImplementedError("Image deserialization from string is not supported") - - -@dataclass -class TupleEncodable[T](Encodable[T, typing.Any]): - base: type[T] - enc: type[typing.Any] - ctx: Mapping[str, Any] - has_image: bool - element_encoders: list[Encodable] - - def encode(self, value: T) -> typing.Any: - if not isinstance(value, tuple): - raise TypeError(f"Expected tuple, got {type(value)}") - if len(value) != len(self.element_encoders): - raise ValueError( - f"Tuple length {len(value)} does not match expected length {len(self.element_encoders)}" - ) - return tuple( - [enc.encode(elem) for enc, elem in zip(self.element_encoders, value)] + OpenAI's strict mode rejects the ``prefixItems`` array schema that Pydantic + emits for fixed-length tuples. We convert them to a Pydantic model with + positional ``item_0``, ``item_1``, … fields instead. + + NamedTuples are handled similarly using their field names. + Bare ``tuple`` and variadic ``tuple[T, ...]`` are passed through unchanged. + """ + # NamedTuple subclasses dispatch here via MRO; use field names. + if isinstance(ty, type) and hasattr(ty, "_fields"): + hints = typing.get_type_hints(ty) + nt_fields: list[str] = list(ty._fields) + nt_types = [hints.get(f, typing.Any) for f in nt_fields] + nt_adapters = [pydantic.TypeAdapter(t) for t in nt_types] + nt_model = pydantic.create_model( + ty.__name__, + __config__={"extra": "forbid"}, + **{f: (t, ...) for f, t in zip(nt_fields, nt_types)}, ) - def decode(self, encoded_value: typing.Any) -> T: - if len(encoded_value) != len(self.element_encoders): - raise ValueError( - f"tuple length {len(encoded_value)} does not match expected length {len(self.element_encoders)}" + def _nt_validate(value, info: pydantic.ValidationInfo): + if isinstance(value, tuple | list): + value = dict(zip(nt_fields, value)) + return ty( + **{ + f: nt_adapters[i].validate_python(value[f], context=info.context) + for i, f in enumerate(nt_fields) + } ) - decoded_elements: list[typing.Any] = [ - enc.decode(elem) for enc, elem in zip(self.element_encoders, encoded_value) - ] - return typing.cast(T, tuple(decoded_elements)) - - def serialize( - self, encoded_value: typing.Any - ) -> Sequence[OpenAIMessageContentListBlock]: - if self.has_image: - # If tuple contains images, serialize each element and flatten the results - result: list[OpenAIMessageContentListBlock] = [] - if not isinstance(encoded_value, tuple): - raise TypeError(f"Expected tuple, got {type(encoded_value)}") - if len(encoded_value) != len(self.element_encoders): - raise ValueError( - f"Tuple length {len(encoded_value)} does not match expected length {len(self.element_encoders)}" + + def _nt_serialize(value, info: pydantic.SerializationInfo): + return { + f: nt_adapters[i].dump_python( + getattr(value, f), mode="json", context=info.context ) - for enc, elem in zip(self.element_encoders, encoded_value): - result.extend(enc.serialize(elem)) - return result - else: - # Use base serialization for non-image tuples - adapter: pydantic.TypeAdapter[tuple] = pydantic.TypeAdapter(self.enc) - json_str = adapter.dump_json(encoded_value).decode("utf-8") - return [{"type": "text", "text": json_str}] - - def deserialize(self, serialized_value: str) -> typing.Any: - adapter: pydantic.TypeAdapter[tuple] = pydantic.TypeAdapter(self.enc) - return typing.cast(typing.Any, adapter.validate_json(serialized_value)) - - -@dataclass -class ListEncodable[T](Encodable[list[T], typing.Any]): - base: type[list[T]] - enc: type[typing.Any] - ctx: Mapping[str, Any] - has_image: bool - element_encoder: Encodable[T, typing.Any] - - def encode(self, value: list[T]) -> typing.Any: - if not isinstance(value, list): - raise TypeError(f"Expected list, got {type(value)}") - return [self.element_encoder.encode(elem) for elem in value] - - def decode(self, encoded_value: typing.Any) -> list[T]: - decoded_elements: list[T] = [ - self.element_encoder.decode(elem) for elem in encoded_value + for i, f in enumerate(nt_fields) + } + + return typing.Annotated[ + ty, + pydantic.PlainValidator(_nt_validate), + pydantic.PlainSerializer(_nt_serialize), + pydantic.WithJsonSchema(_inline_refs(nt_model.model_json_schema())), ] - return typing.cast(list[T], decoded_elements) - - def serialize( - self, encoded_value: typing.Any - ) -> Sequence[OpenAIMessageContentListBlock]: - if self.has_image: - # If list contains images, serialize each element and flatten the results - result: list[OpenAIMessageContentListBlock] = [] - if not isinstance(encoded_value, list): - raise TypeError(f"Expected list, got {type(encoded_value)}") - for elem in encoded_value: - result.extend(self.element_encoder.serialize(elem)) - return result - else: - # Use base serialization for non-image lists - adapter = pydantic.TypeAdapter(self.enc) - json_str = adapter.dump_json(encoded_value).decode("utf-8") - return [{"type": "text", "text": json_str}] - def deserialize(self, serialized_value: str) -> typing.Any: - adapter = pydantic.TypeAdapter(self.enc) - return typing.cast(typing.Any, adapter.validate_json(serialized_value)) + args = typing.get_args(ty) + # Bare tuple or tuple[T, ...] — Pydantic's native handling is fine. + # Note: tuple[()] also has get_args() == (), but has origin=tuple. + if (not args and typing.get_origin(ty) is None) or ( + len(args) == 2 and args[1] is Ellipsis + ): + return ty -def _format_callable_type(callable_type: type[Callable]) -> str: - """Format a Callable type annotation as a string for LLM instructions.""" - args = typing.get_args(callable_type) - if not args: - return "Callable" + # tuple[()] (empty args with origin) maps to zero fields; otherwise use args. + effective: list[typing.Any] = list(args) - # Callable[[arg1, arg2, ...], return_type] - if len(args) >= 2: - param_types = args[0] - return_type = args[-1] + adapters = [pydantic.TypeAdapter(a) for a in effective] - if param_types is ...: - params_str = "..." - elif isinstance(param_types, list | tuple): - params_str = ", ".join(getattr(t, "__name__", str(t)) for t in param_types) - else: - params_str = str(param_types) + model = pydantic.create_model( + "TupleItems", + __config__={"extra": "forbid"}, + **{f"item_{i}": (a, ...) for i, a in enumerate(effective)}, + ) - return_str = getattr(return_type, "__name__", str(return_type)) - return f"Callable[[{params_str}], {return_str}]" + def _validate(value, info: pydantic.ValidationInfo): + if isinstance(value, tuple | list): + value = {f"item_{i}": v for i, v in enumerate(value)} + return tuple( + adapters[i].validate_python(value[f"item_{i}"], context=info.context) + for i in range(len(effective)) + ) + + def _serialize(value, info: pydantic.SerializationInfo): + return { + f"item_{i}": adapters[i].dump_python(v, mode="json", context=info.context) + for i, v in enumerate(value) + } - return str(callable_type) + return typing.Annotated[ + ty, + pydantic.PlainValidator(_validate), + pydantic.PlainSerializer(_serialize), + pydantic.WithJsonSchema(_inline_refs(model.model_json_schema())), + ] + + +@TypeToPydanticType.register(Term) +def _pydantic_type_term(ty: type[Term]): + raise TypeError("Terms cannot be converted to Pydantic types.") + + +@TypeToPydanticType.register(Operation) +def _pydantic_type_operation(ty: type[Operation]): + raise TypeError("Operations cannot be converted to Pydantic types.") + + +@pydantic.validate_call(validate_return=False) +def _validate_image(value: ChatCompletionImageObject) -> Image.Image: + value = pydantic.TypeAdapter(ChatCompletionImageObject).validate_python(value) + image_url: litellm.ChatCompletionImageUrlObject | str = value["image_url"] + url: str = image_url["url"] if isinstance(image_url, dict) else image_url + prefix, data = url.split(",") + if not prefix.startswith("data:image/"): + raise ValueError(f"expected base64 encoded image as data uri, received {url}") + return Image.open(fp=io.BytesIO(base64.b64decode(data))) + + +def _serialize_image(value: Image.Image) -> ChatCompletionImageObject: + buf = io.BytesIO() + value.save(buf, format="PNG") + url = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode('utf-8')}" + return pydantic.TypeAdapter(ChatCompletionImageObject).validate_python( + {"type": "image_url", "image_url": {"detail": "auto", "url": url}} + ) + + +@TypeToPydanticType.register(Image.Image) +def _pydantic_type_image(ty: type[Image.Image]): + adapter = pydantic.TypeAdapter(ChatCompletionImageObject) + return typing.Annotated[ + ty, + pydantic.PlainValidator(_validate_image), + pydantic.PlainSerializer(_serialize_image), + pydantic.WithJsonSchema(_inline_refs(adapter.json_schema())), + ] class SynthesizedFunction(pydantic.BaseModel): @@ -303,7 +364,24 @@ def _create_typed_synthesized_function( Uses pydantic.create_model to ensure the description is included in the JSON schema sent to the LLM, informing it of the expected function signature. """ - type_signature = _format_callable_type(callable_type) + if not typing.get_args(callable_type): + type_signature = "Callable" + # Callable[[arg1, arg2, ...], return_type] + elif len(typing.get_args(callable_type)) >= 2: + param_types = typing.get_args(callable_type)[0] + return_type = typing.get_args(callable_type)[-1] + + if param_types is ...: + params_str = "..." + elif isinstance(param_types, list | tuple): + params_str = ", ".join(getattr(t, "__name__", str(t)) for t in param_types) + else: + params_str = str(param_types) + + return_str = getattr(return_type, "__name__", str(return_type)) + type_signature = f"Callable[[{params_str}], {return_str}]" + else: + type_signature = str(callable_type) description = f"""Given the specification above, generate a Python function satisfying the following specification and type signature. @@ -367,70 +445,56 @@ def _validate_signature_callable( ) -@dataclass -class CallableEncodable(Encodable[Callable, SynthesizedFunction]): - base: type[Callable] - enc: type[SynthesizedFunction] - ctx: Mapping[str, Any] - expected_params: list[type] | None = None - expected_return: type | None = None # None means decode is disabled - - def encode(self, t: Callable) -> SynthesizedFunction: - # (https://github.com/python/mypy/issues/14928) - if not isinstance(t, Callable): # type: ignore - raise TypeError(f"Expected callable, got {type(t)}") +@TypeToPydanticType.register(Callable) +def _pydantic_callable(callable_type: Any) -> Any: + """Create a Pydantic-compatible Annotated type for a parameterized Callable. - try: - source = inspect.getsource(t) - except (OSError, TypeError): - source = None + Usage: PydanticCallable(Callable[[int, str], bool]) + """ + type_args = typing.get_args(callable_type) - if source: - return self.enc(module_code=textwrap.dedent(source)) - - # Source not available - create stub from name, signature, and docstring - # This is useful for builtins and C extensions - name = getattr(t, "__name__", None) - if not name: - raise RuntimeError( - f"Cannot encode callable {t}: no source code and no __name__" + if not type_args: + typed_enc = _create_typed_synthesized_function(Callable[..., typing.Any]) # type: ignore[arg-type] + expected_params = None + expected_return = None + else: + if len(type_args) < 2: + raise TypeError( + f"Callable type signature incomplete: {callable_type}. " + "Expected Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType]." ) - - try: - sig = inspect.signature(t) - sig_str = str(sig) - except (ValueError, TypeError): - # Some builtins don't have inspectable signatures - sig_str = "(...)" - - docstring = inspect.getdoc(t) - if not docstring: - raise RuntimeError( - f"Cannot encode callable {t}: no source code and no docstring" + param_types, expected_return = type_args[0], type_args[-1] + typed_enc = _create_typed_synthesized_function(callable_type) + if param_types is not ... and isinstance(param_types, list | tuple): + expected_params = list(param_types) + else: + expected_params = None + + def _validate(value: Any, info: pydantic.ValidationInfo) -> Callable: + if callable(value) and not isinstance(value, dict): + return value + if isinstance(value, SynthesizedFunction): + encoded = value + elif isinstance(value, dict): + encoded = typed_enc.model_validate(value) + elif isinstance(value, str): + encoded = typed_enc.model_validate_json(value) + else: + raise ValueError( + f"Expected callable, SynthesizedFunction dict, or JSON string, " + f"got {type(value)}" ) - # Format as a stub function with docstring - stub_code = f'''def {name}{sig_str}: - """{docstring}""" - ... -''' - return self.enc(module_code=stub_code) - - def decode(self, encoded_value: SynthesizedFunction) -> Callable: - # Decode requires a concrete return type for synthesis - if self.expected_return is None: + if expected_return is None: raise TypeError( "Cannot decode/synthesize callable without a concrete type signature. " "Use Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType] " "with a concrete return type (not Any)." ) - filename = f"" - - module_code = encoded_value.module_code - - # Parse and validate AST before execution - module: ast.AST = evaluation.parse(module_code, filename) + ctx = info.context or {} + filename = f"" + module: ast.AST = evaluation.parse(encoded.module_code, filename) if not isinstance(module, ast.Module) or not module.body: raise ValueError( @@ -444,20 +508,12 @@ def decode(self, encoded_value: SynthesizedFunction) -> Callable: f"got {type(last_stmt).__name__}" ) - # Validate signature from AST before execution - _validate_signature_ast(last_stmt, self.expected_params) - - # Type-check with mypy; pass original module_code so mypy sees exact source - evaluation.type_check( - module, self.ctx, self.expected_params, self.expected_return - ) + _validate_signature_ast(last_stmt, expected_params) + evaluation.type_check(module, ctx, expected_params, expected_return) - # Compile and execute - # https://docs.python.org/3/library/functions.html#exec g: MutableMapping[str, Any] = {} - g.update(self.ctx or {}) - - bytecode: CodeType = evaluation.compile(module, filename) + g.update(ctx) + bytecode: types.CodeType = evaluation.compile(module, filename) evaluation.exec(bytecode, g) func_name = last_stmt.name @@ -472,152 +528,159 @@ def decode(self, encoded_value: SynthesizedFunction) -> Callable: f"decode() expected '{func_name}' to be callable, got {type(result)}" ) - # Validate signature from runtime callable after execution - _validate_signature_callable(result, self.expected_params, self.expected_return) - + _validate_signature_callable(result, expected_params, expected_return) return result - def serialize( - self, encoded_value: SynthesizedFunction - ) -> Sequence[OpenAIMessageContentListBlock]: - return [{"type": "text", "text": encoded_value.model_dump_json()}] - - def deserialize(self, serialized_value: str) -> SynthesizedFunction: - return SynthesizedFunction.model_validate_json(serialized_value) - - -@Encodable.define.register(object) -def _encodable_object[T, U]( - ty: type[T], ctx: Mapping[str, Any] | None -) -> Encodable[T, U]: - adapter = pydantic.TypeAdapter(ty) - ctx = {} if ctx is None else ctx - return typing.cast(Encodable[T, U], BaseEncodable(ty, ty, ctx, adapter)) - - -@Encodable.define.register(str) -def _encodable_str(ty: type[str], ctx: Mapping[str, Any] | None) -> Encodable[str, str]: - """Handler for str type that serializes without JSON encoding.""" - return StrEncodable(ty, ty, ctx or {}) - - -@Encodable.define.register(Term) -def _encodable_term[T: Term, U]( - ty: type[T], ctx: Mapping[str, Any] | None -) -> Encodable[T, U]: - raise TypeError("Terms cannot be encoded or decoded in general.") - - -@Encodable.define.register(Operation) -def _encodable_operation[T: Operation, U]( - ty: type[T], ctx: Mapping[str, Any] | None -) -> Encodable[T, U]: - raise TypeError("Operations cannot be encoded or decoded in general.") - - -@Encodable.define.register(pydantic.BaseModel) -def _encodable_pydantic_base_model[T: pydantic.BaseModel]( - ty: type[T], ctx: Mapping[str, Any] | None -) -> Encodable[T, T]: - return PydanticBaseModelEncodable(ty, ty, ctx or {}) - - -@Encodable.define.register(Image.Image) -def _encodable_image( - ty: type[Image.Image], ctx: Mapping[str, Any] | None -) -> Encodable[Image.Image, ChatCompletionImageUrlObject]: - return ImageEncodable(ty, ChatCompletionImageUrlObject, ctx or {}) + def _serialize(value: Callable) -> dict: + if not callable(value): + raise TypeError(f"Expected callable, got {type(value)}") + try: + source = inspect.getsource(value) + except (OSError, TypeError): + source = None -@Encodable.define.register(tuple) -def _encodable_tuple[T, U]( - ty: type[T], ctx: Mapping[str, Any] | None -) -> Encodable[T, U]: - args = typing.get_args(ty) - ctx = {} if ctx is None else ctx - - # handle namedtuples - origin = typing.get_origin(ty) - if origin is None: - return _encodable_object(ty, ctx) - # Handle empty tuple, or tuple with no args - if not args or args == ((),): - return _encodable_object(ty, ctx) + if source: + return typed_enc(module_code=textwrap.dedent(source)).model_dump() - # Create encoders for each element type - element_encoders = [Encodable.define(arg, ctx) for arg in args] + name = getattr(value, "__name__", None) + docstring = inspect.getdoc(value) + if name is None or docstring is None: + raise ValueError( + f"Cannot encode callable {value}: no source code and no __name__ or docstring" + ) - # Check if any element type is Image.Image - has_image = any(arg is Image.Image for arg in args) + try: + sig = inspect.signature(value) + sig_str = str(sig) + except (ValueError, TypeError): + sig_str = "(...)" - encoded_ty: type[typing.Any] = typing.cast( - type[typing.Any], - tuple[*(enc.enc for enc in element_encoders)], # type: ignore + stub_code = f'''def {name}{sig_str}: + """{docstring}""" + ... +''' + return typed_enc(module_code=stub_code).model_dump() + + return typing.Annotated[ + callable_type, + pydantic.PlainValidator(_validate), + pydantic.PlainSerializer(_serialize), + pydantic.WithJsonSchema( + _inline_refs(pydantic.TypeAdapter(typed_enc).json_schema()) + ), + ] + + +def _validate_tool( + value: ChatCompletionToolParam, info: pydantic.ValidationInfo +) -> Tool: + assert isinstance(info.context, Mapping), "Tool decoding requires context" + value = pydantic.TypeAdapter(ChatCompletionToolParam).validate_python(value) + try: + return info.context[value["function"]["name"]] + except KeyError as e: + raise NotImplementedError(f"Unknown tool: {value['function']['name']}") from e + + +def _serialize_tool(value: Tool) -> ChatCompletionToolParam: + fields: dict[str, Any] = { + name: TypeToPydanticType().evaluate(param.annotation) + for name, param in inspect.signature(value).parameters.items() + } + sig_model = pydantic.create_model( + "Params", + __config__={"extra": "forbid"}, + **fields, ) - - return typing.cast( - Encodable[T, U], - TupleEncodable(ty, encoded_ty, ctx, has_image, element_encoders), + response_format = litellm.utils.type_to_response_format_param(sig_model) + assert response_format is not None + assert value.__default__.__doc__ is not None + return pydantic.TypeAdapter(ChatCompletionToolParam).validate_python( + { + "type": "function", + "function": { + "name": value.__name__, + "description": textwrap.dedent(value.__default__.__doc__), + "parameters": response_format["json_schema"]["schema"], + "strict": True, + }, + } ) -@Encodable.define.register(list) -def _encodable_list[T, U]( - ty: type[list[T]], ctx: Mapping[str, Any] | None -) -> Encodable[T, U]: - args = typing.get_args(ty) - ctx = {} if ctx is None else ctx - - # Handle unparameterized list (list without type args) - if not args: - return _encodable_object(ty, ctx) - - # Get the element type (first type argument) - element_ty = args[0] - element_encoder = Encodable.define(element_ty, ctx) - - # Check if element type is Image.Image - has_image = element_ty is Image.Image - - # Build the encoded type (list of encoded element type) - runtime-created, use Any - encoded_ty: type[typing.Any] = typing.cast( - type[typing.Any], - list[element_encoder.enc], # type: ignore - ) - - return typing.cast( - Encodable[T, U], ListEncodable(ty, encoded_ty, ctx, has_image, element_encoder) +@TypeToPydanticType.register(Tool) +def _pydantic_type_tool(ty: type[Tool]): + schema = _inline_refs(pydantic.TypeAdapter(ChatCompletionToolParam).json_schema()) + schema = _ensure_strict_json_schema(schema, path=(), root={}) + return typing.Annotated[ + ty, + pydantic.PlainValidator(_validate_tool), + pydantic.PlainSerializer(_serialize_tool), + pydantic.WithJsonSchema(schema), + ] + + +def _validate_tool_call( + value: ChatCompletionMessageToolCall, + info: pydantic.ValidationInfo, +) -> DecodedToolCall: + if isinstance(value, dict): + value = OpenAIChatCompletionMessageToolCall.model_validate(value) + ctx = info.context or {} + assert value.function.name is not None + tool = ctx[value.function.name] + assert isinstance(tool, Tool) + sig = inspect.signature(tool) + decoded_args = {} + for name, raw_arg in json.loads(value.function.arguments).items(): + assert name in sig.parameters, ( + f"Unexpected argument {name} for tool {tool.__name__}" + ) + param = sig.parameters[name] + arg_enc: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter( + Encodable[param.annotation] # type: ignore[name-defined] + ) + decoded_args[name] = arg_enc.validate_python(raw_arg, context=ctx) + return DecodedToolCall( + tool=tool, + bound_args=sig.bind(**decoded_args), + id=value.id, + name=value.function.name, ) -@Encodable.define.register(Callable) -def _encodable_callable( - ty: type[Callable], ctx: Mapping[str, Any] | None -) -> Encodable[Callable, SynthesizedFunction]: - ctx = ctx or {} - - type_args = typing.get_args(ty) - - # Bare Callable without type args - allow encoding but disable decode - # this occurs when decoding the result of Tools which return callable (need to Encodable.define(return_type) for return type) - if not type_args: - assert ty is types.FunctionType, f"Callable must have type signatures {ty}" - typed_enc = _create_typed_synthesized_function(Callable[..., typing.Any]) # type: ignore[arg-type] - return CallableEncodable(ty, typed_enc, ctx) - - if len(type_args) < 2: - raise TypeError( - f"Callable type signature incomplete: {ty}. " - "Expected Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType]." +def _serialize_tool_call( + value: DecodedToolCall, info: pydantic.SerializationInfo +) -> dict: + ctx = info.context or {} + encoded_args = {} + for k, v in value.bound_args.arguments.items(): + v_enc: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter( + Encodable[nested_type(v).value] # type: ignore[misc] ) - - param_types, expected_return = type_args[0], type_args[-1] - - typed_enc = _create_typed_synthesized_function(ty) - - # Ellipsis means any params, skip param validation - expected_params: list[type] | None = None - if param_types is not ... and isinstance(param_types, list | tuple): - expected_params = list(param_types) - - return CallableEncodable(ty, typed_enc, ctx, expected_params, expected_return) + encoded_args[k] = v_enc.dump_python(v, mode="json", context=ctx) + return OpenAIChatCompletionMessageToolCall.model_validate( + { + "type": "function", + "id": value.id, + "function": { + "name": value.tool.__name__, + "arguments": json.dumps(encoded_args), + }, + } + ).model_dump(mode="json") + + +@TypeToPydanticType.register(DecodedToolCall) +def _pydantic_type_tool_call(ty: type[DecodedToolCall]): + # Use OpenAI's ChatCompletionMessageToolCall (has actual fields: id, function, + # type) rather than litellm's (empty dict with extra="allow"). + schema = _inline_refs(OpenAIChatCompletionMessageToolCall.model_json_schema()) + schema = _ensure_strict_json_schema(schema, path=(), root={}) + return typing.Annotated[ + ty, + pydantic.PlainValidator(_validate_tool_call), + pydantic.PlainSerializer(_serialize_tool_call), + pydantic.WithJsonSchema(schema), + ] diff --git a/effectful/handlers/llm/evaluation.py b/effectful/handlers/llm/evaluation.py index 07348cc98..b4c4ecf67 100644 --- a/effectful/handlers/llm/evaluation.py +++ b/effectful/handlers/llm/evaluation.py @@ -392,7 +392,7 @@ def signature_to_ast(name: str, sig: inspect.Signature) -> ast.FunctionDef: except TypeError: returns = type_to_ast(typing.Any) - node = ast.FunctionDef( # type: ignore + node = ast.FunctionDef( name=name, args=ast.arguments( posonlyargs=[], @@ -415,6 +415,7 @@ def signature_to_ast(name: str, sig: inspect.Signature) -> ast.FunctionDef: ], decorator_list=[], returns=returns, + type_params=[], ) return ast.fix_missing_locations(node) diff --git a/effectful/handlers/llm/template.py b/effectful/handlers/llm/template.py index 2dfab38f7..f56d6fad7 100644 --- a/effectful/handlers/llm/template.py +++ b/effectful/handlers/llm/template.py @@ -1,15 +1,15 @@ import abc -import collections import functools import inspect +import re +import string import types import typing +from collections import ChainMap, OrderedDict from collections.abc import Callable, Mapping, MutableMapping -from dataclasses import dataclass from typing import Annotated, Any -from effectful.ops.semantics import handler -from effectful.ops.types import INSTANCE_OP_PREFIX, Annotation, Operation +from effectful.ops.types import Annotation, Operation class _IsRecursiveAnnotation(Annotation): @@ -90,13 +90,14 @@ def vacation() -> str: """ - def __init__( - self, signature: inspect.Signature, name: str, default: Callable[P, T] - ): + def __init__(self, default: Callable[P, T], name: str | None = None): if not default.__doc__: raise ValueError("Tools must have docstrings.") - signature = IsRecursive.infer_annotations(signature) - super().__init__(signature, name, default) + super().__init__(default, name=name) + + @property + def __signature__(self): + return IsRecursive.infer_annotations(super().__signature__) @classmethod def define(cls, *args, **kwargs) -> "Tool[P, T]": @@ -108,25 +109,6 @@ def define(cls, *args, **kwargs) -> "Tool[P, T]": return typing.cast("Tool[P, T]", super().define(*args, **kwargs)) -@dataclass -class _BoundInstance[T]: - instance: T - - -def _make_context_tool[T](name: str, value: T) -> Tool[[], T]: - """Create a synthetic read-only Tool for a lexical variable.""" - from effectful.internals.unification import nested_type - - def reader(): - return value - - reader.__name__ = name - reader.__doc__ = f"Read the value of lexical variable `{name}`" - reader.__annotations__ = {"return": nested_type(value).value} - - return Tool.define(reader) - - class Template[**P, T](Tool[P, T]): """A :class:`Template` is a function that is implemented by a large language model. @@ -186,7 +168,45 @@ class Template[**P, T](Tool[P, T]): """ - __context__: collections.ChainMap[str, Any] + __context__: ChainMap[str, Any] + __system_prompt__: str + + @classmethod + def _validate_prompt( + cls, + template: "Template", + context: ChainMap[str, Any], + ) -> None: + """Validate that all format string variables in the docstring + refer to names resolvable at call time. + + Each variable must be either a parameter in the signature + or a name captured in the lexical context. + + :raises TypeError: If any format string variable cannot be resolved. + """ + doc = template.__prompt_template__ + formatter = string.Formatter() + param_names = set(template.__signature__.parameters.keys()) + context_keys = set(context.keys()) + allowed_names = param_names | context_keys + + unresolved: list[str] = [] + for _, field_name, _, _ in formatter.parse(doc): + if field_name is None: + continue + # Extract root identifier from compound names like + match = re.match(r"^(\w+)", field_name) + root = match.group(1) if match else field_name + if root not in allowed_names: + unresolved.append(field_name) + + if unresolved: + raise TypeError( + f"Template '{template.__name__}' docstring references undefined " + f"variables {list(sorted(unresolved))} that are not in the signature " + f"{{{template.__signature__}}} or lexical scope." + ) @property def __prompt_template__(self) -> str: @@ -196,40 +216,17 @@ def __prompt_template__(self) -> str: @property def tools(self) -> Mapping[str, Tool]: """Operations and Templates available as tools. Auto-capture from lexical context.""" - result = {} - is_recursive = _is_recursive_signature(self.__signature__) + from effectful.handlers.llm.completions import _collect_tools - for name, obj in self.__context__.items(): - if obj is self and not is_recursive: - continue + result = _collect_tools(self.__context__) - # Collect tools in context - elif isinstance(obj, Tool): - result[name] = obj - - elif isinstance(obj, staticmethod) and isinstance(obj.__func__, Tool): - result[name] = obj.__func__ - - # Collect tools as methods on any bound instances - elif isinstance(obj, _BoundInstance): - for instance_name in obj.instance.__dir__(): - if instance_name.startswith(INSTANCE_OP_PREFIX): - continue - instance_obj = getattr(obj.instance, instance_name) - if isinstance(instance_obj, Tool): - result[instance_name] = instance_obj - - # Make tools for lexical variables - elif not ( - name.startswith("__") - or isinstance(obj, Operation) - or inspect.isclass(obj) - or inspect.isbuiltin(obj) - or inspect.ismodule(obj) - or inspect.isroutine(obj) - or inspect.isabstract(obj) - ): - result[name] = _make_context_tool(name, obj) + # We remove the template itself from the tool map unless it is explicitly + # marked as recursive (see test_template_method, test_template_method_nested_class). + if not _is_recursive_signature(self.__signature__): + result = dict(result) # copy to allow mutation + for name, tool in tuple(result.items()): + if tool is self: + del result[name] return result @@ -241,8 +238,18 @@ def __get__[S](self, instance: S | None, owner: type[S] | None = None): result = super().__get__(instance, owner) self_param_name = list(self.__signature__.parameters.keys())[0] - self_context = {self_param_name: _BoundInstance(instance)} - result.__context__ = self.__context__.new_child(self_context) + result.__context__ = self.__context__.new_child({self_param_name: instance}) + if isinstance(instance, Agent): + assert isinstance(result, Template) and not hasattr(result, "__history__") + result.__history__ = instance.__history__ # type: ignore[attr-defined] + result.__system_prompt__ = "\n\n".join( + part + for part in ( + getattr(result, "__system_prompt__", ""), + instance.__system_prompt__, + ) + if part + ) return result @classmethod @@ -263,35 +270,56 @@ def define[**Q, V]( frame = frame.f_back assert frame is not None - # Check if we're in a class definition by looking for __qualname__ + # Skip class body frames: in Python, class bodies are not lexical + # scopes for methods, so their locals should not be captured. qualname = frame.f_locals.get("__qualname__") - n_frames = 1 if qualname is not None: - name_components = qualname.split(".") - for name in reversed(name_components): + for name in reversed(qualname.split(".")): if name == "": break - n_frames += 1 - - contexts = [] - for offset in range(n_frames): - assert frame is not None - locals_proxy: types.MappingProxyType[str, Any] = types.MappingProxyType( - frame.f_locals - ) - globals_proxy: types.MappingProxyType[str, Any] = types.MappingProxyType( - frame.f_globals - ) - contexts.append(locals_proxy) - frame = frame.f_back + assert frame is not None + frame = frame.f_back + # Use the qualname of the decorated function to identify which + # frames are *lexical* enclosers (as opposed to dynamic callers). + # A segment preceding "" in the qualname is an enclosing + # function; everything else (class names, the function itself) is not. + assert frame is not None + _fn = default + if isinstance(_fn, staticmethod | classmethod): + _fn = _fn.__func__ + parts = _fn.__qualname__.split(".") + enclosing_fns = [ + parts[i] for i in range(len(parts) - 1) if parts[i + 1] == "" + ] + enclosing_fns.reverse() # innermost first for frame walking + + globals_proxy: types.MappingProxyType[str, Any] = types.MappingProxyType( + frame.f_globals + ) + contexts: list[types.MappingProxyType[str, Any]] = [] + for fn_name in enclosing_fns: + while frame is not None and frame.f_locals is not frame.f_globals: + if frame.f_code.co_name == fn_name: + contexts.append(types.MappingProxyType(frame.f_locals)) + frame = frame.f_back + break + frame = frame.f_back contexts.append(globals_proxy) - context: collections.ChainMap[str, Any] = collections.ChainMap( + context: ChainMap[str, Any] = ChainMap( *typing.cast(list[MutableMapping[str, Any]], contexts) ) - op = super().define(default, *args, **kwargs) op.__context__ = context # type: ignore[attr-defined] + mod = inspect.getmodule(_fn) + op.__system_prompt__ = inspect.getdoc(mod) if mod is not None else "" # type: ignore[attr-defined] + # Keep validation on original define-time callables, but skip the bound wrapper path. + # to avoid dropping `self` from the signature and falsely rejecting valid prompt fields like `{self.name}`. + is_bound_wrapper = ( + isinstance(default, types.MethodType) and default.__self__ is not None + ) + if not isinstance(op, staticmethod | classmethod) and not is_bound_wrapper: + cls._validate_prompt(typing.cast(Template, op), context) return typing.cast(Template[Q, V], op) @@ -333,25 +361,18 @@ def send(self, user_input: str) -> str: """ - __history__: collections.OrderedDict[str, Any] + __history__: OrderedDict[str, Mapping[str, Any]] + __system_prompt__: str def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - prop = functools.cached_property(lambda _: collections.OrderedDict()) - prop.__set_name__(cls, "__history__") - cls.__history__ = prop - - for name in list(cls.__dict__): - attr = cls.__dict__[name] - if not isinstance(attr, Template): - continue - _template = attr - - @functools.wraps(_template) - def wrapper(self, *args, _t=_template, **kwargs): - from effectful.handlers.llm.completions import get_message_sequence - - with handler({get_message_sequence: lambda: self.__history__}): - return _t(self, *args, **kwargs) - - setattr(cls, name, wrapper) + if not hasattr(cls, "__history__"): + prop = functools.cached_property(lambda _: OrderedDict()) + prop.__set_name__(cls, "__history__") + cls.__history__ = prop + if not hasattr(cls, "__system_prompt__"): + sp = functools.cached_property( + lambda self: inspect.getdoc(type(self)) or "" + ) + sp.__set_name__(cls, "__system_prompt__") + cls.__system_prompt__ = sp diff --git a/effectful/handlers/numpyro.py b/effectful/handlers/numpyro.py index 74010de41..f0369d379 100644 --- a/effectful/handlers/numpyro.py +++ b/effectful/handlers/numpyro.py @@ -1,4 +1,5 @@ try: + import numpyro import numpyro.distributions as dist except ImportError: raise ImportError("Numpyro is required to use effectful.handlers.numpyro") @@ -332,6 +333,13 @@ def variance(self) -> jax.Array: except NotImplementedError: raise RuntimeError(f"variance is not implemented for {type(self).__name__}") + @property + @defop + def support(self) -> numpyro.distributions.constraints.Constraint: + if not self._is_eager: + raise NotHandled + return self._pos_base_dist.support + @defop def enumerate_support(self, expand=True) -> jax.Array: if not self._is_eager: diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 71d6583f2..77f5c9613 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -104,6 +104,86 @@ class Box[T]: value: T +class TypeEvaluator(abc.ABC): + """ + Abstract base class for evaluating type expressions. + + This class defines the interface for evaluating type expressions, which may + involve resolving type variables, computing canonical forms of types, or + performing other transformations. Subclasses should implement the evaluate + method to provide specific evaluation logic. + + The TypeEvaluator can be used in contexts where type expressions need to be + processed or normalized before unification or other type operations. + """ + + @functools.singledispatchmethod + def evaluate(self, typ) -> TypeExpressions: + """ + Normalize generic types + """ + raise TypeError(f"Cannot traverse type {typ}.") + + @evaluate.register + def _(self, typ: TypeConstant | TypeVariable): + return typ + + @evaluate.register + def _(self, typ: GenericAlias): + origin, args = typing.get_origin(typ), typing.get_args(typ) + return origin[self.evaluate(args)] # type: ignore[index] + + @evaluate.register + def _(self, typ: UnionType): + ctyp = self.evaluate(typing.get_args(typ)[0]) + for arg in typing.get_args(typ)[1:]: + ctyp = ctyp | self.evaluate(arg) # type: ignore + return ctyp + + @evaluate.register + def _(self, typ: typing._AnnotatedAlias): # type: ignore + return typing.Annotated[ + self.evaluate(typing.get_args(typ)[0]), + typ.__metadata__, + ] + + @evaluate.register + def _(self, typ: typing._LiteralGenericAlias): # type: ignore + return typ + + @evaluate.register + def _(self, typ: typing.ParamSpecArgs | typing.ParamSpecKwargs): + return typ + + @evaluate.register + def _(self, typ: typing._SpecialGenericAlias): # type: ignore + assert not typing.get_args(typ), "Should not have type arguments" + return typ + + @evaluate.register + def _(self, typ: typing._ConcatenateGenericAlias): # type: ignore + return typing.Concatenate[self.evaluate(typing.get_args(typ))] + + @evaluate.register + def _(self, typ: list | tuple): + return type(typ)(self.evaluate(item) for item in typ) + + @evaluate.register + def _(self, typ: typing.NewType): + return typing.NewType(typ.__name__, self.evaluate(typ.__supertype__)) # type: ignore[attr-defined,unused-ignore] + + @evaluate.register + def _(self, typ: typing.TypeAliasType): + return self.evaluate(typ.__value__) + + @evaluate.register + def _(self, typ: typing.ForwardRef): + if typ.__forward_value__ is not None: + return self.evaluate(typ.__forward_value__) + else: + return typ + + @typing.overload def unify( typ: inspect.Signature, diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index f7475dc71..f7678fd24 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -8,7 +8,7 @@ from collections.abc import Callable from typing import Any -from effectful.ops.syntax import _CustomSingleDispatchCallable, defop +from effectful.ops.syntax import _CustomSingleDispatchCallable, defdata, defop from effectful.ops.types import ( Expr, Interpretation, @@ -364,6 +364,7 @@ def _update_fvs(op, *args, **kwargs): assert isinstance(bound_var, Operation) if bound_var in _fvs: _fvs.remove(bound_var) + return defdata(op, *args, **kwargs) with interpreter({apply: _update_fvs}): evaluate(term) diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 43772b4e1..40c1f4af5 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -74,19 +74,30 @@ class Operation[**Q, V]: """ - __signature__: inspect.Signature __name__: str __default__: Callable[Q, V] __apply__: typing.ClassVar["Operation"] - def __init__( - self, signature: inspect.Signature, name: str, default: Callable[Q, V] - ): + def __init__(self, default: Callable[Q, V], name: str | None = None): functools.update_wrapper(self, default) - - self.__signature__ = signature - self.__name__ = name self.__default__ = default + self.__name__ = name or default.__name__ + + @property + def __signature__(self): + # Resolve forward references (e.g. -> "MyClass") using the + # default function's __globals__. This handles module-level + # forward refs; local forward refs will raise NameError. + # Python 3.14's annotationlib.get_annotations(format=FORWARDREF) + # could resolve local refs too via PEP 649 __annotate__ functions. + annots = typing.get_type_hints(self.__default__, include_extras=True) + sig = inspect.signature(self.__default__) + updated_params = [ + p.replace(annotation=annots[p.name]) if p.name in annots else p + for p in sig.parameters.values() + ] + updated_ret = annots.get("return", sig.return_annotation) + return sig.replace(parameters=updated_params, return_annotation=updated_ret) def __eq__(self, other): if not isinstance(other, Operation): @@ -267,8 +278,7 @@ def func(*args, **kwargs): op = cls.define(func, name=name) else: - name = name or t.__name__ - op = cls(inspect.signature(t), name, t) # type: ignore[arg-type] + op = cls(t, name=name) # type: ignore[arg-type] return op # type: ignore[return-value] @@ -441,7 +451,9 @@ def __str__(self): def __set_name__[T](self, owner: type[T], name: str) -> None: if not issubclass(owner, Term): assert not hasattr(self, "_name_on_instance"), "should only be called once" - self._name_on_instance: str = f"{INSTANCE_OP_PREFIX}_{name}" + self._name_on_instance: str = ( + f"{INSTANCE_OP_PREFIX}_{owner.__name__}_{name}" + ) def __get__[T](self, instance: T | None, owner: type[T] | None = None): if hasattr(instance, "__dict__") and hasattr(self, "_name_on_instance"): From 136b3a8dde68b27ffb5421d3521afa9989504624 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 25 Apr 2026 12:45:37 -0400 Subject: [PATCH 05/11] no agent test --- tests/test_handlers_llm_agent.py | 315 ------------------------------- 1 file changed, 315 deletions(-) delete mode 100644 tests/test_handlers_llm_agent.py diff --git a/tests/test_handlers_llm_agent.py b/tests/test_handlers_llm_agent.py deleted file mode 100644 index 38c8aacd1..000000000 --- a/tests/test_handlers_llm_agent.py +++ /dev/null @@ -1,315 +0,0 @@ -"""Tests for Agent mixin message sequence semantics.""" - -import collections -import dataclasses - -from litellm import ModelResponse - -from effectful.handlers.llm import Agent, Template, Tool -from effectful.handlers.llm.completions import ( - LiteLLMProvider, - RetryLLMHandler, - completion, -) -from effectful.ops.semantics import handler -from effectful.ops.syntax import ObjectInterpretation, implements -from effectful.ops.types import NotHandled - -# --------------------------------------------------------------------------- -# Helpers (same pattern as test_handlers_llm_provider.py) -# --------------------------------------------------------------------------- - - -def make_text_response(content: str) -> ModelResponse: - return ModelResponse( - id="test", - choices=[ - { - "index": 0, - "message": {"role": "assistant", "content": content}, - "finish_reason": "stop", - } - ], - model="test-model", - ) - - -def make_tool_call_response( - tool_name: str, tool_args: str, tool_call_id: str = "call_1" -) -> ModelResponse: - return ModelResponse( - id="test", - choices=[ - { - "index": 0, - "message": { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": tool_call_id, - "type": "function", - "function": {"name": tool_name, "arguments": tool_args}, - } - ], - }, - "finish_reason": "tool_calls", - } - ], - model="test-model", - ) - - -class MockCompletionHandler(ObjectInterpretation): - """Returns pre-configured responses and captures messages sent to the LLM.""" - - def __init__(self, responses: list[ModelResponse]): - self.responses = responses - self.call_count = 0 - self.received_messages: list[list] = [] - - @implements(completion) - def _completion(self, model, messages=None, **kwargs): - self.received_messages.append(list(messages) if messages else []) - response = self.responses[min(self.call_count, len(self.responses) - 1)] - self.call_count += 1 - return response - - -# --------------------------------------------------------------------------- -# Agent subclass used by most tests -# --------------------------------------------------------------------------- - - -@dataclasses.dataclass -class ChatBot(Agent): - """Simple chat agent for testing history accumulation.""" - - bot_name: str = dataclasses.field(default="ChatBot") - - @Template.define - def send(self, user_input: str) -> str: - """A friendly bot named {self.bot_name}. User writes: {user_input}""" - raise NotHandled - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -class TestAgentHistoryAccumulation: - """History accumulates across sequential calls on the same instance.""" - - def test_second_call_sees_prior_messages(self): - mock = MockCompletionHandler( - [make_text_response("hi"), make_text_response("good")] - ) - bot = ChatBot() - - with handler(LiteLLMProvider()), handler(mock): - bot.send("hello") - bot.send("how are you") - - # First call: system + user → 2 messages - assert len(mock.received_messages[0]) == 2 - - # Second call: previous system + user + assistant, PLUS new system + user → 5 - assert len(mock.received_messages[1]) > len(mock.received_messages[0]) - - # Verify roles in second call - roles = [m["role"] for m in mock.received_messages[1]] - assert roles.count("assistant") >= 1 - assert roles.count("user") >= 2 - assert roles.count("system") >= 2 - - def test_history_contains_all_messages_after_two_calls(self): - mock = MockCompletionHandler( - [make_text_response("r1"), make_text_response("r2")] - ) - bot = ChatBot() - - with handler(LiteLLMProvider()), handler(mock): - bot.send("a") - bot.send("b") - - # After two complete calls the history should have: - # call 1: system, user, assistant (3) - # call 2: system, user, assistant (3) - assert len(bot.__history__) == 6 - - def test_message_ids_are_unique(self): - mock = MockCompletionHandler( - [make_text_response("r1"), make_text_response("r2")] - ) - bot = ChatBot() - - with handler(LiteLLMProvider()), handler(mock): - bot.send("a") - bot.send("b") - - ids = list(bot.__history__.keys()) - assert len(ids) == len(set(ids)), "message IDs must be unique" - - -class TestAgentIsolation: - """Each agent instance has independent history; non-agent templates are unaffected.""" - - def test_two_agents_have_independent_histories(self): - mock = MockCompletionHandler( - [ - make_text_response("from bot1"), - make_text_response("from bot2"), - ] - ) - bot1 = ChatBot() - bot2 = ChatBot() - - with handler(LiteLLMProvider()), handler(mock): - bot1.send("msg for bot1") - bot2.send("msg for bot2") - - # bot2's call should NOT contain bot1's messages - assert len(mock.received_messages[1]) == 2 # system + user only - - # Each bot has its own history - assert len(bot1.__history__) == 3 # system, user, assistant - assert len(bot2.__history__) == 3 - - # Histories share no message IDs - assert set(bot1.__history__.keys()).isdisjoint(set(bot2.__history__.keys())) - - def test_non_agent_template_gets_fresh_sequence(self): - @Template.define - def standalone(topic: str) -> str: - """Write about {topic}.""" - raise NotHandled - - mock = MockCompletionHandler( - [ - make_text_response("agent reply"), - make_text_response("standalone reply"), - make_text_response("agent reply 2"), - ] - ) - bot = ChatBot() - - with handler(LiteLLMProvider()), handler(mock): - bot.send("hello") - standalone("fish") - bot.send("bye") - - # standalone (call index 1) should see only system + user (fresh sequence) - assert len(mock.received_messages[1]) == 2 - - # bot's third call (call index 2) should see its accumulated history - # but NOT the standalone messages - assert len(mock.received_messages[2]) == 5 # 3 from first call + 2 new - - -class TestAgentCachedProperty: - """__history__ is lazily created per instance without requiring __init__.""" - - def test_no_init_required(self): - class MinimalAgent(Agent): - @Template.define - def greet(self, name: str) -> str: - """Hello {name}.""" - raise NotHandled - - agent = MinimalAgent() - # Should be an OrderedDict, created on first access - assert isinstance(agent.__history__, collections.OrderedDict) - assert len(agent.__history__) == 0 - - def test_subclass_with_own_init(self): - class CustomAgent(Agent): - def __init__(self, name: str): - self.name = name - - @Template.define - def greet(self) -> str: - """Say hello.""" - raise NotHandled - - agent = CustomAgent("Alice") - assert agent.name == "Alice" - assert isinstance(agent.__history__, collections.OrderedDict) - - def test_history_is_per_instance(self): - a = ChatBot() - b = ChatBot() - a.__history__["fake"] = {"id": "fake", "role": "user", "content": "x"} - assert "fake" not in b.__history__ - - -class TestAgentWithToolCalls: - """Agent methods that trigger tool calls maintain correct history.""" - - def test_tool_call_results_appear_in_history(self): - @Tool.define - def add(a: int, b: int) -> int: - """Add two numbers.""" - return a + b - - class MathAgent(Agent): - @Template.define - def compute(self, question: str) -> str: - """Answer: {question}""" - raise NotHandled - - mock = MockCompletionHandler( - [ - make_tool_call_response("add", '{"a": 2, "b": 3}'), - make_text_response("The answer is 5"), - ] - ) - agent = MathAgent() - - with handler(LiteLLMProvider()), handler(mock): - result = agent.compute("what is 2+3?") - - assert result == "The answer is 5" - - # History should contain: system, user, assistant (tool_call), - # tool (result), assistant (final) - roles = [m["role"] for m in agent.__history__.values()] - assert "tool" in roles - assert roles.count("assistant") == 2 - - -class TestAgentWithRetryHandler: - """RetryLLMHandler composes correctly with Agent history.""" - - def test_failed_retries_dont_pollute_history(self): - mock = MockCompletionHandler( - [ - # First attempt: invalid result for int - make_text_response('{"value": "not_an_int"}'), - # Retry: valid - make_text_response('{"value": 42}'), - ] - ) - - class NumberAgent(Agent): - @Template.define - def pick_number(self) -> int: - """Pick a number.""" - raise NotHandled - - agent = NumberAgent() - - with ( - handler(LiteLLMProvider()), - handler(RetryLLMHandler(num_retries=3)), - handler(mock), - ): - result = agent.pick_number() - - assert result == 42 - - # The malformed assistant message and error feedback from the retry - # should NOT appear in the agent's history. Only the final successful - # assistant message should be there. - roles = [m["role"] for m in agent.__history__.values()] - assert roles == ["system", "user", "assistant"] From 52dc30eb83e705c09a31f05aa56999cae74a982b Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 25 Apr 2026 14:25:54 -0400 Subject: [PATCH 06/11] updates --- docs/source/llm_examples/async_concurrency.py | 9 +--- docs/source/llm_examples/batch_translate.py | 19 ++++---- docs/source/llm_examples/chat_memory.py | 9 +--- docs/source/llm_examples/chat_search.py | 21 +++++---- docs/source/llm_examples/flight_booking.py | 22 +++++---- docs/source/llm_examples/guardrails.py | 38 ++++++++------- .../llm_examples/hanoi_solver_iterative.py | 16 +++---- .../llm_examples/hanoi_solver_recursive.py | 16 +++---- docs/source/llm_examples/hitl.py | 22 +++++---- docs/source/llm_examples/majority_vote.py | 9 +--- docs/source/llm_examples/map_reduce.py | 46 ++++++++++--------- docs/source/llm_examples/multi_agent.py | 22 +++++---- docs/source/llm_examples/rag.py | 23 ++++++---- docs/source/llm_examples/research_agent.py | 9 +--- docs/source/llm_examples/supervisor.py | 21 +++++---- docs/source/llm_examples/tao_agent.py | 15 +++--- docs/source/llm_examples/text2sql.py | 22 +++++---- docs/source/llm_examples/thinking.py | 22 +++++---- 18 files changed, 181 insertions(+), 180 deletions(-) diff --git a/docs/source/llm_examples/async_concurrency.py b/docs/source/llm_examples/async_concurrency.py index 32ec20c77..b47170610 100644 --- a/docs/source/llm_examples/async_concurrency.py +++ b/docs/source/llm_examples/async_concurrency.py @@ -52,17 +52,10 @@ async def main(provider: LiteLLMProvider): parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - provider = LiteLLMProvider(model=args.model) asyncio.run(main(provider)) diff --git a/docs/source/llm_examples/batch_translate.py b/docs/source/llm_examples/batch_translate.py index bcdb343d3..66b4999f8 100644 --- a/docs/source/llm_examples/batch_translate.py +++ b/docs/source/llm_examples/batch_translate.py @@ -7,8 +7,10 @@ import argparse import os +from tenacity import stop_after_attempt + from effectful.handlers.llm import Template -from effectful.handlers.llm.completions import LiteLLMProvider +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler from effectful.handlers.llm.evaluation import RestrictedEvalProvider from effectful.ops.semantics import handler from effectful.ops.types import NotHandled @@ -38,7 +40,7 @@ def translate(target_language: str, instructions: str = "") -> Template[[str], s parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) parser.add_argument( @@ -55,16 +57,13 @@ def translate(target_language: str, instructions: str = "") -> Template[[str], s ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - provider = LiteLLMProvider(model=args.model) - with handler(provider), handler(RestrictedEvalProvider()): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + handler(RestrictedEvalProvider()), + ): translator = translate( target_language="french", instructions="Use formal language." ) diff --git a/docs/source/llm_examples/chat_memory.py b/docs/source/llm_examples/chat_memory.py index 42c8b46ac..b926cdff5 100644 --- a/docs/source/llm_examples/chat_memory.py +++ b/docs/source/llm_examples/chat_memory.py @@ -109,18 +109,11 @@ def chat(self, user_input: str): parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - agent = ChatAgent() provider = LiteLLMProvider(model=args.model) diff --git a/docs/source/llm_examples/chat_search.py b/docs/source/llm_examples/chat_search.py index 8dcdd1691..a2d7cecd5 100644 --- a/docs/source/llm_examples/chat_search.py +++ b/docs/source/llm_examples/chat_search.py @@ -4,6 +4,7 @@ import urllib.parse import requests +from tenacity import stop_after_attempt from effectful.handlers.llm import Agent, Template, Tool from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler @@ -77,7 +78,7 @@ def send(self, user_input: str) -> str: parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) parser.add_argument( @@ -91,19 +92,21 @@ def send(self, user_input: str) -> str: action="store_true", help="Run in interactive mode, allowing multiple back-and-forth messages", ) + parser.add_argument( + "--num-retries", + type=int, + default=4, + help="Number of retries for malformed LLM output", + ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - chatbot = ChatBot(bot_name=args.name) provider = LiteLLMProvider(model=args.model) - with handler(provider), handler(RetryLLMHandler(num_retries=3)): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): if args.interactive: while True: print(chatbot.send(input("You: "))) diff --git a/docs/source/llm_examples/flight_booking.py b/docs/source/llm_examples/flight_booking.py index 35cdec35d..d4de2e97a 100644 --- a/docs/source/llm_examples/flight_booking.py +++ b/docs/source/llm_examples/flight_booking.py @@ -15,6 +15,8 @@ import os from typing import Literal +from tenacity import stop_after_attempt + from effectful.handlers.llm import Agent, Template, Tool from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler from effectful.ops.semantics import handler @@ -227,7 +229,7 @@ def book_flight( parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) parser.add_argument( @@ -235,18 +237,20 @@ def book_flight( action="store_true", help="Run in interactive mode with user prompts", ) + parser.add_argument( + "--num-retries", + type=int, + default=4, + help="Number of retries for malformed LLM output", + ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - provider = LiteLLMProvider(model=args.model) - with handler(provider), handler(RetryLLMHandler(num_retries=5)): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): book_flight( origin=Airport.SFO, destination=Airport.ANC, diff --git a/docs/source/llm_examples/guardrails.py b/docs/source/llm_examples/guardrails.py index 304818ffb..6ac800572 100644 --- a/docs/source/llm_examples/guardrails.py +++ b/docs/source/llm_examples/guardrails.py @@ -8,6 +8,8 @@ import argparse import os +from tenacity import stop_after_attempt + from effectful.handlers.llm import Template from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler from effectful.ops.semantics import handler @@ -26,14 +28,6 @@ def travel_query(user_query: str) -> str: raise NotHandled -@Template.define -def is_safe_query(user_query: str) -> bool: - """ - Determine whether the user's query is purely related to travel advice: {user_query} - """ - raise NotHandled - - # --------------------------------------------------------------------------- # Guarded agent # --------------------------------------------------------------------------- @@ -41,6 +35,14 @@ def is_safe_query(user_query: str) -> bool: def answer_travel_query(user_query: str) -> str: """Only answer travel-related queries; reject everything else.""" + + @Template.define + def is_safe_query(user_query: str) -> bool: + """ + Determine whether the user's query is purely related to travel advice: {user_query} + """ + raise NotHandled + if is_safe_query(user_query): return travel_query(user_query) else: @@ -56,19 +58,21 @@ def answer_travel_query(user_query: str) -> str: parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) + parser.add_argument( + "--num-retries", + type=int, + default=4, + help="Number of retries for malformed LLM output", + ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - provider = LiteLLMProvider(model=args.model) - with handler(provider), handler(RetryLLMHandler(num_retries=5)): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): print(answer_travel_query("What are great places to check out in NYC?")) print(answer_travel_query("Should I buy apple stocks?")) diff --git a/docs/source/llm_examples/hanoi_solver_iterative.py b/docs/source/llm_examples/hanoi_solver_iterative.py index 8733c64ac..0a5ecdbab 100644 --- a/docs/source/llm_examples/hanoi_solver_iterative.py +++ b/docs/source/llm_examples/hanoi_solver_iterative.py @@ -14,6 +14,8 @@ import os from dataclasses import dataclass, field +from tenacity import stop_after_attempt + from effectful.handlers.llm import Template, Tool from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler from effectful.ops.semantics import handler @@ -173,7 +175,7 @@ def solve_hanoi(state: GameState, max_steps: int = 30): parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) parser.add_argument( @@ -196,14 +198,10 @@ def solve_hanoi(state: GameState, max_steps: int = 30): ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - provider = LiteLLMProvider(model=args.model) - with handler(provider), handler(RetryLLMHandler(num_retries=args.num_retries)): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): solve_hanoi(GameState(size=args.game_size), max_steps=args.max_steps) diff --git a/docs/source/llm_examples/hanoi_solver_recursive.py b/docs/source/llm_examples/hanoi_solver_recursive.py index b5da9107a..b84387729 100644 --- a/docs/source/llm_examples/hanoi_solver_recursive.py +++ b/docs/source/llm_examples/hanoi_solver_recursive.py @@ -29,6 +29,8 @@ import typing from dataclasses import dataclass, field +from tenacity import stop_after_attempt + from effectful.handlers.llm import Template from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler from effectful.handlers.llm.template import IsRecursive @@ -161,7 +163,7 @@ def validate_solution(size: int, steps: list[Step]) -> bool: parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) parser.add_argument( @@ -178,16 +180,12 @@ def validate_solution(size: int, steps: list[Step]) -> bool: ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - provider = LiteLLMProvider(model=args.model) - with handler(provider), handler(RetryLLMHandler(num_retries=args.num_retries)): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): n = args.game_size print(f"Solving Tower of Hanoi with {n} disks...") steps = solve(n_disks=n, source=0, target=n - 1, auxiliary=1) diff --git a/docs/source/llm_examples/hitl.py b/docs/source/llm_examples/hitl.py index 5b2ebe17c..540fc27c6 100644 --- a/docs/source/llm_examples/hitl.py +++ b/docs/source/llm_examples/hitl.py @@ -13,6 +13,8 @@ import enum import os +from tenacity import stop_after_attempt + from effectful.handlers.llm import Agent, Template, Tool from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler from effectful.ops.semantics import handler @@ -130,7 +132,7 @@ def run_with_approval( parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) parser.add_argument( @@ -144,15 +146,14 @@ def run_with_approval( default=5, help="Maximum number of action steps", ) + parser.add_argument( + "--num-retries", + type=int, + default=3, + help="Number of retries for malformed LLM output", + ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - provider = LiteLLMProvider(model=args.model) task = ( @@ -161,7 +162,10 @@ def run_with_approval( "restaurant suggestions, and schedule a meeting to finalize plans." ) - with handler(provider), handler(RetryLLMHandler(num_retries=3)): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): print(f"Task: {task}\n") log = run_with_approval( task, diff --git a/docs/source/llm_examples/majority_vote.py b/docs/source/llm_examples/majority_vote.py index 1ad8c6296..25696ddcb 100644 --- a/docs/source/llm_examples/majority_vote.py +++ b/docs/source/llm_examples/majority_vote.py @@ -59,7 +59,7 @@ def majority_vote[Q]( parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) parser.add_argument( @@ -73,13 +73,6 @@ def majority_vote[Q]( ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - provider = LiteLLMProvider(model=args.model) with handler(provider): answer, count = majority_vote(yes_or_no, args.question, voters=args.num_voters) diff --git a/docs/source/llm_examples/map_reduce.py b/docs/source/llm_examples/map_reduce.py index f7e00ac1d..70a79c595 100644 --- a/docs/source/llm_examples/map_reduce.py +++ b/docs/source/llm_examples/map_reduce.py @@ -9,10 +9,13 @@ import argparse import asyncio +import collections.abc import dataclasses import functools import os +from tenacity import stop_after_attempt + from effectful.handlers.llm import Template from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler from effectful.ops.semantics import handler @@ -53,13 +56,16 @@ def evaluate_resume(resume: str, job_description: str) -> Evaluation: @Template.define -def summarize_evaluations(job_description: str, evaluations_text: str) -> str: +def summarize_evaluations( + job_description: str, + evaluations: collections.abc.Sequence[Evaluation], +) -> str: """You are a hiring manager summarizing candidate evaluations. Job description: {job_description} Individual evaluations: - {evaluations_text} + {evaluations} Provide a brief summary: rank the candidates from best to worst, highlight the top candidate, and note any concerns. @@ -103,7 +109,11 @@ async def map_reduce_evaluate( # Map: evaluate each resume concurrently evaluate = functools.partial( asyncio.to_thread, - handler(provider)(handler(RetryLLMHandler(num_retries=3))(evaluate_resume)), + handler(provider)( + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries)))( + evaluate_resume + ) + ), ) evaluations: list[Evaluation] = list( await asyncio.gather(*(evaluate(resume, job_description) for resume in resumes)) @@ -116,16 +126,11 @@ async def map_reduce_evaluate( print(f" - {ev.weaknesses}") # Reduce: summarize all evaluations - evaluations_text = "\n\n".join( - f"Candidate: {ev.name}\n" - f"Score: {ev.score}/10\n" - f"Qualified: {ev.qualified}\n" - f"Strengths: {ev.strengths}\n" - f"Weaknesses: {ev.weaknesses}" - for ev in evaluations - ) - with handler(provider), handler(RetryLLMHandler(num_retries=3)): - return summarize_evaluations(job_description, evaluations_text) + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): + return summarize_evaluations(job_description, evaluations) # --------------------------------------------------------------------------- @@ -137,18 +142,17 @@ async def map_reduce_evaluate( parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) + parser.add_argument( + "--num-retries", + type=int, + default=3, + help="Number of retries for malformed LLM output", + ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - provider = LiteLLMProvider(model=args.model) print(f"Evaluating {len(RESUMES)} resumes for: {JOB_DESCRIPTION}\n") diff --git a/docs/source/llm_examples/multi_agent.py b/docs/source/llm_examples/multi_agent.py index 448cf7a4f..0389c6c87 100644 --- a/docs/source/llm_examples/multi_agent.py +++ b/docs/source/llm_examples/multi_agent.py @@ -12,6 +12,8 @@ import enum import os +from tenacity import stop_after_attempt + from effectful.handlers.llm import Agent, Template, Tool from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler from effectful.ops.semantics import handler @@ -133,7 +135,7 @@ def play_taboo( parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) parser.add_argument( @@ -142,15 +144,14 @@ def play_taboo( default=5, help="Maximum rounds per game", ) + parser.add_argument( + "--num-retries", + type=int, + default=3, + help="Number of retries for malformed LLM output", + ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - games = [ ("piano", ["music", "keys", "instrument", "play"]), ("volcano", ["lava", "eruption", "mountain", "hot"]), @@ -158,7 +159,10 @@ def play_taboo( provider = LiteLLMProvider(model=args.model) - with handler(provider), handler(RetryLLMHandler(num_retries=3)): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): for secret, taboo in games: print(f"\nGame: '{secret}' (taboo: {taboo})") play_taboo(secret, taboo, max_rounds=args.max_rounds) diff --git a/docs/source/llm_examples/rag.py b/docs/source/llm_examples/rag.py index 166f1dbb4..eca2b4507 100644 --- a/docs/source/llm_examples/rag.py +++ b/docs/source/llm_examples/rag.py @@ -14,6 +14,7 @@ import litellm import numpy as np +from tenacity import stop_after_attempt from effectful.handlers.llm import Template, Tool from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler @@ -153,24 +154,23 @@ def answer_question(question: str) -> str: parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) parser.add_argument( "--embedding-model", type=str, - default="lm_studio/nomic-ai/nomic-embed-text-v1.5-GGUF", + default="lm_studio/text-embedding-embeddinggemma-300m-qat", help="Embedding model to use", ) + parser.add_argument( + "--num-retries", + type=int, + default=3, + help="Number of retries for malformed LLM output", + ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - # Offline: build the index index = build_index(DOCUMENTS, embedding_model=args.embedding_model) @@ -186,7 +186,10 @@ def answer_question(question: str) -> str: provider = LiteLLMProvider(model=args.model) - with handler(provider), handler(RetryLLMHandler(num_retries=3)): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): for question in questions: print(f"\nQ: {question}") answer = answer_question(question) diff --git a/docs/source/llm_examples/research_agent.py b/docs/source/llm_examples/research_agent.py index bc193db1b..308d2df23 100644 --- a/docs/source/llm_examples/research_agent.py +++ b/docs/source/llm_examples/research_agent.py @@ -119,7 +119,7 @@ def research_agent(question: str, max_attempts: int = 3) -> str: parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) parser.add_argument( @@ -130,13 +130,6 @@ def research_agent(question: str, max_attempts: int = 3) -> str: ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - provider = LiteLLMProvider(model=args.model) with handler(provider): diff --git a/docs/source/llm_examples/supervisor.py b/docs/source/llm_examples/supervisor.py index 1009d6a62..4c07a9f83 100644 --- a/docs/source/llm_examples/supervisor.py +++ b/docs/source/llm_examples/supervisor.py @@ -12,6 +12,7 @@ import urllib.parse import requests +from tenacity import stop_after_attempt from effectful.handlers.llm import Agent, Template, Tool from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler @@ -150,7 +151,7 @@ def supervised_research(question: str, max_retries: int = 3) -> str: parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) parser.add_argument( @@ -159,18 +160,20 @@ def supervised_research(question: str, max_retries: int = 3) -> str: default=3, help="Maximum number of supervisor rejections before accepting", ) + parser.add_argument( + "--num-retries", + type=int, + default=3, + help="Number of retries for malformed LLM output", + ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - provider = LiteLLMProvider(model=args.model) - with handler(provider), handler(RetryLLMHandler(num_retries=3)): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): result = supervised_research( "What year was the Eiffel Tower completed and how tall is it?", max_retries=args.max_retries, diff --git a/docs/source/llm_examples/tao_agent.py b/docs/source/llm_examples/tao_agent.py index 8f9fbb0c1..2a8a14717 100644 --- a/docs/source/llm_examples/tao_agent.py +++ b/docs/source/llm_examples/tao_agent.py @@ -14,6 +14,7 @@ import urllib.parse import requests +from tenacity import stop_after_attempt from effectful.handlers.llm import Agent, Template, Tool from effectful.handlers.llm.completions import ( @@ -149,7 +150,7 @@ def _act(self, action: AgentAction, action_input: str) -> str: parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) parser.add_argument( @@ -166,18 +167,14 @@ def _act(self, action: AgentAction, action_input: str) -> str: ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - provider = LiteLLMProvider(model=args.model) agent = TAOAgent() - with handler(provider), handler(RetryLLMHandler(num_retries=args.num_retries)): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): answer = agent.run( "How many tennis balls would fill an Olympic swimming pool?", max_steps=args.max_steps, diff --git a/docs/source/llm_examples/text2sql.py b/docs/source/llm_examples/text2sql.py index 5d511b5f8..a3e36f933 100644 --- a/docs/source/llm_examples/text2sql.py +++ b/docs/source/llm_examples/text2sql.py @@ -12,6 +12,8 @@ import sqlite3 import textwrap +from tenacity import stop_after_attempt + from effectful.handlers.llm import Template from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler from effectful.ops.semantics import handler @@ -138,18 +140,17 @@ def text_to_sql( parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) + parser.add_argument( + "--num-retries", + type=int, + default=3, + help="Number of retries for malformed LLM output", + ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - conn = create_sample_db() provider = LiteLLMProvider(model=args.model) @@ -159,7 +160,10 @@ def text_to_sql( "How many employees were hired after 2021?", ] - with handler(provider), handler(RetryLLMHandler(num_retries=3)): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): for question in questions: print(f"\nQ: {question}") try: diff --git a/docs/source/llm_examples/thinking.py b/docs/source/llm_examples/thinking.py index 101d5cb2e..058de61ef 100644 --- a/docs/source/llm_examples/thinking.py +++ b/docs/source/llm_examples/thinking.py @@ -10,6 +10,8 @@ import dataclasses import os +from tenacity import stop_after_attempt + from effectful.handlers.llm import Agent, Template from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler from effectful.ops.semantics import handler @@ -68,7 +70,7 @@ def solve(self, problem: str, max_steps: int = 10) -> str: parser.add_argument( "--model", type=str, - default="lm_studio/zai-org/glm-4.7-flash", + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) parser.add_argument( @@ -86,15 +88,14 @@ def solve(self, problem: str, max_steps: int = 10) -> str: ), help="The problem to solve", ) + parser.add_argument( + "--num-retries", + type=int, + default=3, + help="Number of retries for malformed LLM output", + ) args = parser.parse_args() - if args.model.startswith("lm_studio/"): - assert os.environ.get("LM_STUDIO_API_BASE") - elif args.model.startswith("gpt-"): - assert os.environ.get("OPENAI_API_KEY") - elif args.model.startswith("claude-"): - assert os.environ.get("ANTHROPIC_API_KEY") - provider = LiteLLMProvider(model=args.model) problems = [ @@ -105,7 +106,10 @@ def solve(self, problem: str, max_steps: int = 10) -> str: ), ] - with handler(provider), handler(RetryLLMHandler(num_retries=3)): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): for problem in problems: thinker = Thinker() print(f"\nProblem: {problem}") From 2cbb0e0f3cbb1caa639834c94353bef860b74f5c Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 25 Apr 2026 15:10:01 -0400 Subject: [PATCH 07/11] notebook sections as scripts --- docs/source/llm_examples/decode_callable.py | 78 +++++++++++++ .../llm_examples/higher_order_function.py | 105 +++++++++++++++++ docs/source/llm_examples/image_input.py | 63 ++++++++++ docs/source/llm_examples/prompt_templates.py | 88 ++++++++++++++ docs/source/llm_examples/retry_tool_errors.py | 91 +++++++++++++++ docs/source/llm_examples/retry_validation.py | 109 ++++++++++++++++++ docs/source/llm_examples/structured_output.py | 82 +++++++++++++ .../llm_examples/template_composition.py | 76 ++++++++++++ docs/source/llm_examples/tool_calling.py | 65 +++++++++++ 9 files changed, 757 insertions(+) create mode 100644 docs/source/llm_examples/decode_callable.py create mode 100644 docs/source/llm_examples/higher_order_function.py create mode 100644 docs/source/llm_examples/image_input.py create mode 100644 docs/source/llm_examples/prompt_templates.py create mode 100644 docs/source/llm_examples/retry_tool_errors.py create mode 100644 docs/source/llm_examples/retry_validation.py create mode 100644 docs/source/llm_examples/structured_output.py create mode 100644 docs/source/llm_examples/template_composition.py create mode 100644 docs/source/llm_examples/tool_calling.py diff --git a/docs/source/llm_examples/decode_callable.py b/docs/source/llm_examples/decode_callable.py new file mode 100644 index 000000000..05a09212a --- /dev/null +++ b/docs/source/llm_examples/decode_callable.py @@ -0,0 +1,78 @@ +"""Decoding LLM responses into Python objects, including callables. + +Demonstrates: +- Primitive type decoding (``int``) from a template that returns a number +- Synthesizing a Python ``Callable`` from a template, executed via + ``UnsafeEvalProvider`` from ``effectful.handlers.llm.evaluation`` +- ``inspect.getsource`` on the synthesized function +""" + +import argparse +import inspect +import os +from collections.abc import Callable + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider +from effectful.handlers.llm.evaluation import UnsafeEvalProvider +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Templates +# --------------------------------------------------------------------------- + + +@Template.define +def primes(first_digit: int) -> int: + """Give a prime number with {first_digit} as the first digit. Do not use any tools.""" + raise NotHandled + + +@Template.define +def count_char(char: str) -> Callable[[str], int]: + """Write a function which takes a string and counts the occurrances of '{char}'. Do not use any tools.""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Decode LLM responses to Python objects (incl. callables)" + ) + parser.add_argument( + "--model", + type=str, + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), + help="LLM model to use", + ) + parser.add_argument( + "--first-digit", + type=int, + default=6, + help="First digit of the prime to request", + ) + parser.add_argument( + "--char", + type=str, + default="a", + help="Character whose occurrences the synthesized function will count", + ) + args = parser.parse_args() + + provider = LiteLLMProvider(model=args.model) + + with handler(provider), handler(UnsafeEvalProvider()): + prime = primes(args.first_digit) + assert type(prime) is int + print(f"Prime starting with {args.first_digit}: {prime}") + + counter = count_char(args.char) + assert callable(counter) + print("\nGenerated function:") + print(inspect.getsource(counter)) + print(f'counter("banana") == {counter("banana")}') + print(f'counter("cherry") == {counter("cherry")}') diff --git a/docs/source/llm_examples/higher_order_function.py b/docs/source/llm_examples/higher_order_function.py new file mode 100644 index 000000000..da2410959 --- /dev/null +++ b/docs/source/llm_examples/higher_order_function.py @@ -0,0 +1,105 @@ +"""Generating higher-order functions that call other templates. + +Demonstrates: +- A template returning a ``Callable``, evaluated via ``UnsafeEvalProvider`` +- The synthesized function calling sub-templates (``write_chapter``, + ``judge_chapter``) at runtime +- ``RetryLLMHandler`` to recover from transient validation/runtime errors +- ``inspect.getsource`` on the generated function +""" + +import argparse +import inspect +import os +from collections.abc import Callable +from typing import Literal + +from tenacity import stop_after_attempt + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.handlers.llm.evaluation import UnsafeEvalProvider +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Sub-templates the generated function may call +# --------------------------------------------------------------------------- + + +@Template.define +def write_chapter(chapter_number: int, chapter_name: str) -> str: + """Write a short story about {chapter_number}. Do not use any tools.""" + raise NotHandled + + +@Template.define +def judge_chapter(story_so_far: str, chapter_number: int) -> bool: + """Decide if the new chapter is coherent with the story so far. Do not use any tools.""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Orchestrator template returning a callable +# --------------------------------------------------------------------------- + + +@Template.define +def write_multi_chapter_story(style: Literal["moral", "funny"]) -> Callable[[str], str]: + """Generate a function that writes a story in style: {style} about the given topic. + + If you raise an exception, handle it yourself. + The program can use helper functions defined elsewhere (DO NOT REDEFINE THEM): + - write_chapter(chapter_number: int, chapter_name: str) -> str + - judge_chapter(story_so_far: str, chapter_number: int) -> bool + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate a higher-order function that calls sub-templates" + ) + parser.add_argument( + "--model", + type=str, + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), + help="LLM model to use", + ) + parser.add_argument( + "--topic", type=str, default="a curious cat", help="Story topic" + ) + parser.add_argument( + "--style", + type=str, + choices=["moral", "funny"], + default="moral", + help="Story style", + ) + parser.add_argument( + "--num-retries", + type=int, + default=4, + help="Number of retries for malformed LLM output", + ) + args = parser.parse_args() + + provider = LiteLLMProvider(model=args.model) + + print("Sub-templates available to write_multi_chapter_story:") + print(list(write_multi_chapter_story.tools.keys())) + + with ( + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + handler(provider), + handler(UnsafeEvalProvider()), + ): + print(f"\n=== Generating story function (style={args.style}) ===") + story_fn = write_multi_chapter_story(args.style) + print(inspect.getsource(story_fn)) + print(f"\n=== Running generated function on {args.topic!r} ===") + print(story_fn(args.topic)) diff --git a/docs/source/llm_examples/image_input.py b/docs/source/llm_examples/image_input.py new file mode 100644 index 000000000..14375b294 --- /dev/null +++ b/docs/source/llm_examples/image_input.py @@ -0,0 +1,63 @@ +"""Passing PIL images directly to a template. + +Demonstrates: +- Templates accepting ``PIL.Image.Image`` arguments +- Inline base64 image data so the script is self-contained +""" + +import argparse +import base64 +import io +import os + +from PIL import Image + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Inline image (32x32 yellow smiley face) +# --------------------------------------------------------------------------- + +IMAGE_BASE64 = ( + "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAYAAABzenr0AAAAhElEQVR4nO2W4QqA" + "MAiEVXr/VzYWDGoMdk7Cgrt/sUs/DqZTd3EplFU2JwATYAJMoOlAB4bq89s95+Mg" + "+gyAchsKAYplBBBA43hFhfxnUixDjdEUUL8hpr7R0KLdt9qElzcyiu8As+Kr8zQA" + "mgLavAl+kIzFZyCRxtsAmWb/voZvqRzgBE1sIDuVFX4eAAAAAElFTkSuQmCC" +) + + +# --------------------------------------------------------------------------- +# Template +# --------------------------------------------------------------------------- + + +@Template.define +def describe_image(image: Image.Image) -> str: + """Return a short description of the following image. + {image} + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Pass a PIL image to a template") + parser.add_argument( + "--model", + type=str, + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), + help="LLM model to use (must support image inputs)", + ) + args = parser.parse_args() + + image = Image.open(io.BytesIO(base64.b64decode(IMAGE_BASE64))) + + provider = LiteLLMProvider(model=args.model) + with handler(provider): + print(describe_image(image)) diff --git a/docs/source/llm_examples/prompt_templates.py b/docs/source/llm_examples/prompt_templates.py new file mode 100644 index 000000000..56369056a --- /dev/null +++ b/docs/source/llm_examples/prompt_templates.py @@ -0,0 +1,88 @@ +"""Basic prompt templates and deterministic caching. + +Demonstrates: +- ``@Template.define`` for declaring an LLM-backed function +- Non-determinism: calling the same template twice yields different results +- ``functools.cache`` to make a template call deterministic in-process +- ``LiteLLMProvider(caching=True)`` with ``litellm.cache`` for cross-process caching +""" + +import argparse +import functools +import os + +import litellm +from litellm.caching.caching import Cache + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Templates +# --------------------------------------------------------------------------- + + +@Template.define +def limerick(theme: str) -> str: + """Write a limerick on the theme of {theme}. Do not use any tools.""" + raise NotHandled + + +@functools.cache +@Template.define +def haiku(theme: str) -> str: + """Write a haiku on the theme of {theme}. Do not use any tools.""" + raise NotHandled + + +@Template.define +def haiku_no_cache(theme: str) -> str: + """Write a haiku on the theme of {theme}. Do not use any tools.""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Basic prompt templates and deterministic caching" + ) + parser.add_argument( + "--model", + type=str, + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), + help="LLM model to use", + ) + parser.add_argument( + "--theme", type=str, default="fish", help="Theme for the poem" + ) + args = parser.parse_args() + + provider = LiteLLMProvider(model=args.model) + + print("=== Non-deterministic limerick (two independent calls) ===") + with handler(provider): + print(limerick(args.theme)) + print("-" * 40) + print(limerick(args.theme)) + + print("\n=== functools.cache: same result on second call ===") + with handler(provider): + print(haiku(args.theme)) + print("-" * 40) + print(haiku(args.theme)) + + print("\n=== LiteLLMProvider(caching=True): backed by litellm.cache ===") + litellm.cache = Cache() + provider_cached = LiteLLMProvider(model=args.model, caching=True) + try: + with handler(provider_cached): + print(haiku_no_cache(args.theme)) + print("-" * 40) + print(haiku_no_cache(args.theme)) + finally: + litellm.cache = None diff --git a/docs/source/llm_examples/retry_tool_errors.py b/docs/source/llm_examples/retry_tool_errors.py new file mode 100644 index 000000000..7ce0b7b11 --- /dev/null +++ b/docs/source/llm_examples/retry_tool_errors.py @@ -0,0 +1,91 @@ +"""Retrying tool execution failures. + +Demonstrates: +- ``RetryLLMHandler`` surfacing tool exceptions back to the LLM as tool messages +- A flaky tool (``unstable_service``) that succeeds only after multiple attempts +- The contrast between an unhandled failure and a retry-handled success +""" + +import argparse +import os + +from tenacity import stop_after_attempt + +from effectful.handlers.llm import Template, Tool +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Flaky tool +# --------------------------------------------------------------------------- + +call_count = 0 +REQUIRED_RETRIES = 3 + + +@Tool.define +def unstable_service() -> str: + """Fetch data from an unstable external service. May require retries.""" + global call_count + call_count += 1 + if call_count < REQUIRED_RETRIES: + raise ConnectionError( + f"Service unavailable! Attempt {call_count}/{REQUIRED_RETRIES}. Please retry." + ) + return "{ 'status': 'ok', 'data': [1, 2, 3] }" + + +# --------------------------------------------------------------------------- +# Template (unstable_service auto-captured from lexical scope) +# --------------------------------------------------------------------------- + + +@Template.define +def fetch_data() -> str: + """Use the unstable_service tool to fetch data.""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Retry LLM template calls when tools raise exceptions" + ) + parser.add_argument( + "--model", + type=str, + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), + help="LLM model to use", + ) + parser.add_argument( + "--num-retries", + type=int, + default=4, + help="Number of retries for tool/decode failures", + ) + args = parser.parse_args() + + provider = LiteLLMProvider(model=args.model) + + print("=== Without RetryLLMHandler ===") + with handler(provider): + try: + result = fetch_data() + print(f"Result: {result}") + except Exception as e: + print(f"Error: {e}") + + # Reset for the retry-enabled run. + call_count = 0 + + print("\n=== With RetryLLMHandler ===") + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): + result = fetch_data() + print(f"Result: {result} (after {call_count} tool attempts)") diff --git a/docs/source/llm_examples/retry_validation.py b/docs/source/llm_examples/retry_validation.py new file mode 100644 index 000000000..960b3f3fc --- /dev/null +++ b/docs/source/llm_examples/retry_validation.py @@ -0,0 +1,109 @@ +"""Retrying when structured-output validation fails. + +Demonstrates: +- A pydantic dataclass with ``field_validator`` constraints +- ``RetryLLMHandler`` feeding ``PydanticCustomError`` messages back to the LLM + so it can correct its output on a subsequent attempt +""" + +import argparse +import os + +import pydantic +from pydantic import field_validator +from pydantic_core import PydanticCustomError +from tenacity import stop_after_attempt + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Validated structured output +# --------------------------------------------------------------------------- + + +@pydantic.dataclasses.dataclass +class Rating: + score: int + explanation: str + + @field_validator("score") + @classmethod + def check_score(cls, v): + if v < 1 or v > 5: + raise PydanticCustomError( + "invalid_score", + "score must be 1–5, got {v}", + {"v": v}, + ) + return v + + @field_validator("explanation") + @classmethod + def check_explanation_contains_score(cls, v, info): + score = info.data.get("score", None) + if score is not None and str(score) not in v: + raise PydanticCustomError( + "invalid_explanation", + "explanation must mention the score {score}, got '{explanation}'", + {"score": score, "explanation": v}, + ) + return v + + +# --------------------------------------------------------------------------- +# Template +# --------------------------------------------------------------------------- + + +@Template.define +def give_rating_for_movie(movie_name: str) -> Rating: + """Give a rating for {movie_name}. The explanation MUST include the numeric score. Do not use any tools.""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Retry on pydantic validation errors in LLM responses" + ) + parser.add_argument( + "--model", + type=str, + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), + help="LLM model to use", + ) + parser.add_argument( + "--movie", type=str, default="Die Hard", help="Movie to rate" + ) + parser.add_argument( + "--num-retries", + type=int, + default=4, + help="Number of retries for malformed LLM output", + ) + args = parser.parse_args() + + provider = LiteLLMProvider(model=args.model) + + print("=== Without RetryLLMHandler ===") + with handler(provider): + try: + rating = give_rating_for_movie(args.movie) + print(f"Score: {rating.score}/5\nExplanation: {rating.explanation}") + except Exception as e: + print(f"Error: {e}") + + print("\n=== With RetryLLMHandler ===") + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + ): + rating = give_rating_for_movie(args.movie) + print(f"Score: {rating.score}/5") + print(f"Explanation: {rating.explanation}") diff --git a/docs/source/llm_examples/structured_output.py b/docs/source/llm_examples/structured_output.py new file mode 100644 index 000000000..0f6c85f88 --- /dev/null +++ b/docs/source/llm_examples/structured_output.py @@ -0,0 +1,82 @@ +"""Structured output via dataclasses. + +Demonstrates: +- Dataclass return types decoded from constrained LLM generation +- Round-tripping a dataclass: one template produces it, another consumes it +""" + +import argparse +import dataclasses +import os + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Structured output +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class KnockKnockJoke: + whos_there: str + punchline: str + + +# --------------------------------------------------------------------------- +# Templates +# --------------------------------------------------------------------------- + + +@Template.define +def write_joke(theme: str) -> KnockKnockJoke: + """Write a knock-knock joke on the theme of {theme}. Do not use any tools.""" + raise NotHandled + + +@Template.define +def rate_joke(joke: KnockKnockJoke) -> bool: + """Decide if {joke} is funny or not. Do not use any tools.""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + + +def do_comedy(theme: str) -> None: + joke = write_joke(theme) + print("> You are onstage at a comedy club. You tell the following joke:") + print( + f"Knock knock.\nWho's there?\n{joke.whos_there}.\n" + f"{joke.whos_there} who?\n{joke.punchline}" + ) + if rate_joke(joke): + print("> The crowd laughs politely.") + else: + print("> The crowd stares in stony silence.") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Structured output via dataclasses") + parser.add_argument( + "--model", + type=str, + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), + help="LLM model to use", + ) + parser.add_argument( + "--theme", type=str, default="lizards", help="Theme for the joke" + ) + args = parser.parse_args() + + provider = LiteLLMProvider(model=args.model) + with handler(provider): + do_comedy(args.theme) diff --git a/docs/source/llm_examples/template_composition.py b/docs/source/llm_examples/template_composition.py new file mode 100644 index 000000000..5f8803078 --- /dev/null +++ b/docs/source/llm_examples/template_composition.py @@ -0,0 +1,76 @@ +"""Template composition: templates can call other templates. + +Demonstrates: +- Sub-templates auto-captured into an orchestrator template's lexical scope +- Inspecting ``write_story.tools`` to confirm sub-templates are exposed to the LLM +- The orchestrator dispatches to the right sub-template based on a style argument +""" + +import argparse +import os + +from effectful.handlers.llm import Template +from effectful.handlers.llm.completions import LiteLLMProvider +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Sub-templates +# --------------------------------------------------------------------------- + + +@Template.define +def story_with_moral(topic: str) -> str: + """Write a short story about {topic} and end with a moral lesson. Do not use any tools.""" + raise NotHandled + + +@Template.define +def story_funny(topic: str) -> str: + """Write a funny, humorous story about {topic}. Do not use any tools.""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Orchestrator template +# --------------------------------------------------------------------------- + + +@Template.define +def write_story(topic: str, style: str) -> str: + """Write a story about {topic} in the style: {style}. + Available styles: 'moral' for a story with a lesson, 'funny' for humor. + Use story_funny for humor, story_with_moral for a story with a lesson. + """ + raise NotHandled + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Template composition with auto-captured sub-templates" + ) + parser.add_argument( + "--model", + type=str, + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), + help="LLM model to use", + ) + parser.add_argument( + "--topic", type=str, default="a curious cat", help="Story topic" + ) + args = parser.parse_args() + + assert story_with_moral in write_story.tools.values() + assert story_funny in write_story.tools.values() + print("Sub-templates available to write_story:", list(write_story.tools.keys())) + + provider = LiteLLMProvider(model=args.model) + with handler(provider): + print("\n=== Story with moral ===") + print(write_story(args.topic, "moral")) + print("\n=== Funny story ===") + print(write_story(args.topic, "funny")) diff --git a/docs/source/llm_examples/tool_calling.py b/docs/source/llm_examples/tool_calling.py new file mode 100644 index 000000000..f7d9b9f28 --- /dev/null +++ b/docs/source/llm_examples/tool_calling.py @@ -0,0 +1,65 @@ +"""Tool calling: templates invoke Python callables exposed via ``@Tool.define``. + +Demonstrates: +- ``@Tool.define`` for exposing a Python function to the model +- Lexical-scope auto-capture: tools defined alongside a template are made + available to the LLM without explicit registration +- The model chains multiple tool calls to answer a multi-step query +""" + +import argparse +import os + +from effectful.handlers.llm import Template, Tool +from effectful.handlers.llm.completions import LiteLLMProvider +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Tools +# --------------------------------------------------------------------------- + + +@Tool.define +def cities() -> list[str]: + """Return a list of cities that can be passed to `weather`.""" + return ["Chicago", "New York", "Barcelona"] + + +@Tool.define +def weather(city: str) -> str: + """Given a city name, return a description of the weather in that city.""" + status = {"Chicago": "cold", "New York": "wet", "Barcelona": "sunny"} + return status.get(city, "unknown") + + +# --------------------------------------------------------------------------- +# Template (cities and weather are auto-captured from lexical scope) +# --------------------------------------------------------------------------- + + +@Template.define +def vacation() -> str: + """Use the provided tools to suggest a city that has good weather. Use only the `cities` and `weather` tools provided.""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Tool calling with auto-captured lexical scope" + ) + parser.add_argument( + "--model", + type=str, + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), + help="LLM model to use", + ) + args = parser.parse_args() + + provider = LiteLLMProvider(model=args.model) + with handler(provider): + print(vacation()) From ed3d16ba31dc015ffe0e913cd7f65bd345c4b6a2 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 25 Apr 2026 15:10:36 -0400 Subject: [PATCH 08/11] lint --- docs/source/llm_examples/prompt_templates.py | 4 +--- docs/source/llm_examples/retry_validation.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/docs/source/llm_examples/prompt_templates.py b/docs/source/llm_examples/prompt_templates.py index 56369056a..74c1cf0d3 100644 --- a/docs/source/llm_examples/prompt_templates.py +++ b/docs/source/llm_examples/prompt_templates.py @@ -57,9 +57,7 @@ def haiku_no_cache(theme: str) -> str: default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) - parser.add_argument( - "--theme", type=str, default="fish", help="Theme for the poem" - ) + parser.add_argument("--theme", type=str, default="fish", help="Theme for the poem") args = parser.parse_args() provider = LiteLLMProvider(model=args.model) diff --git a/docs/source/llm_examples/retry_validation.py b/docs/source/llm_examples/retry_validation.py index 960b3f3fc..aab80df9e 100644 --- a/docs/source/llm_examples/retry_validation.py +++ b/docs/source/llm_examples/retry_validation.py @@ -78,9 +78,7 @@ def give_rating_for_movie(movie_name: str) -> Rating: default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) - parser.add_argument( - "--movie", type=str, default="Die Hard", help="Movie to rate" - ) + parser.add_argument("--movie", type=str, default="Die Hard", help="Movie to rate") parser.add_argument( "--num-retries", type=int, From 3cac33d7b3ff6b1a6bdbfe91d231ec6e1de834a2 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 25 Apr 2026 15:18:14 -0400 Subject: [PATCH 09/11] nit --- docs/source/llm_examples/supervisor.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/source/llm_examples/supervisor.py b/docs/source/llm_examples/supervisor.py index 4c07a9f83..29f258fe0 100644 --- a/docs/source/llm_examples/supervisor.py +++ b/docs/source/llm_examples/supervisor.py @@ -154,6 +154,12 @@ def supervised_research(question: str, max_retries: int = 3) -> str: default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) + parser.add_argument( + "--question", + type=str, + default="What year was the Eiffel Tower completed and how tall is it?", + help="Research question to answer", + ) parser.add_argument( "--max-retries", type=int, @@ -175,7 +181,7 @@ def supervised_research(question: str, max_retries: int = 3) -> str: handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), ): result = supervised_research( - "What year was the Eiffel Tower completed and how tall is it?", + args.question, max_retries=args.max_retries, ) print(f"\nFinal answer: {result}") From ee6fc1f27eb53d659e854b4f6e5b13a4cf2d289b Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 25 Apr 2026 15:31:18 -0400 Subject: [PATCH 10/11] retry --- docs/source/llm_examples/decode_callable.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/docs/source/llm_examples/decode_callable.py b/docs/source/llm_examples/decode_callable.py index 05a09212a..195167bce 100644 --- a/docs/source/llm_examples/decode_callable.py +++ b/docs/source/llm_examples/decode_callable.py @@ -12,8 +12,10 @@ import os from collections.abc import Callable +from tenacity import stop_after_attempt + from effectful.handlers.llm import Template -from effectful.handlers.llm.completions import LiteLLMProvider +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler from effectful.handlers.llm.evaluation import UnsafeEvalProvider from effectful.ops.semantics import handler from effectful.ops.types import NotHandled @@ -49,6 +51,12 @@ def count_char(char: str) -> Callable[[str], int]: default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), help="LLM model to use", ) + parser.add_argument( + "--num-retries", + type=int, + default=5, + help="Number of retries for malformed LLM output", + ) parser.add_argument( "--first-digit", type=int, @@ -65,7 +73,11 @@ def count_char(char: str) -> Callable[[str], int]: provider = LiteLLMProvider(model=args.model) - with handler(provider), handler(UnsafeEvalProvider()): + with ( + handler(provider), + handler(RetryLLMHandler(stop=stop_after_attempt(args.num_retries))), + handler(UnsafeEvalProvider()), + ): prime = primes(args.first_digit) assert type(prime) is int print(f"Prime starting with {args.first_digit}: {prime}") From c2f6bbcfe5af68ba967dd3f10fdd123ff198768e Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 2 Jun 2026 15:01:09 -0400 Subject: [PATCH 11/11] Add example using object handles (#597) * add example using object handles * coerce into the same format --------- Co-authored-by: Eli --- docs/source/llm_examples/image_tool.py | 90 ++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 docs/source/llm_examples/image_tool.py diff --git a/docs/source/llm_examples/image_tool.py b/docs/source/llm_examples/image_tool.py new file mode 100644 index 000000000..813058494 --- /dev/null +++ b/docs/source/llm_examples/image_tool.py @@ -0,0 +1,90 @@ +import argparse +import os + +from PIL import Image + +from effectful.handlers.llm import Agent, Template, Tool +from effectful.handlers.llm.completions import ( + LiteLLMProvider, + RetryLLMHandler, +) +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + + +class ImageTools(Agent): + """You are an image processing agent.""" + + _image_to_handle: dict[int, int] + _handle_to_image: dict[int, Image.Image] + + def __init__(self): + self._image_to_handle = {} + self._handle_to_image = {} + + def _encode(self, image: Image.Image) -> int: + image_id = id(image) + handle = self._image_to_handle.get(image_id, None) + if handle is not None: + return handle + + handle = len(self._image_to_handle) + self._image_to_handle[image_id] = handle + + assert handle not in self._handle_to_image + self._handle_to_image[handle] = image + return handle + + def _decode(self, image_handle: int) -> Image.Image: + return self._handle_to_image[image_handle] + + @Tool.define + def rotate(self, image: int, angle: float) -> int: + """Returns a rotated copy of this image. The copy is rotated by `angle` + degrees counterclockwise around the image center. + + """ + return self._encode(self._decode(image).rotate(angle)) + + @Tool.define + def concat_horiz(self, i1_h: int, i2_h: int) -> int: + """Concatenates two images horizontally. The larger image will be + cropped to the height of the smaller image. + + """ + i1 = self._decode(i1_h) + i2 = self._decode(i2_h) + i3 = Image.new("RGB", (i1.width + i2.width, min(i1.height, i2.height))) + i3.paste(i1, (0, 0)) + i3.paste(i2, (i1.width, 0)) + return self._encode(i3) + + @Template.define + def _rotate_and_concat(self, i: int) -> int: + """Create an image consisting of four copies of the image {i} + concatenated horizontally. Each copy should be rotated 90 degrees from + the previous. + + """ + raise NotHandled + + def rotate_and_concat(self, i: Image.Image) -> Image.Image: + return self._decode(self._rotate_and_concat(self._encode(i))) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + default=os.environ.get("EFFECTFUL_LLM_MODEL", ""), + help="LLM model to use (must support image inputs)", + ) + args = parser.parse_args() + + image_agent = ImageTools() + img = Image.open("../_static/img/chirho_logo_wide.png") + + provider = LiteLLMProvider(model=args.model) + with handler(provider), handler(RetryLLMHandler()): + image_agent.rotate_and_concat(img).show()