From 5fcec32bdd07be721142dc112bbd8a336084ac71 Mon Sep 17 00:00:00 2001 From: Rutayan Patro Date: Fri, 15 May 2026 16:54:36 -0400 Subject: [PATCH 1/2] feat(configurator): make select_action observation-aware - BaseAgent.select_action(observation: Optional[list[float]] = None) so observation-conditioned agents (RL, contextual bandits) can use the latest env observation; stateless agents (grid search, BO) keep the same behavior. - handle_dse_job seeds the loop with env.reset(), threads observation into select_action, and forwards the gymnasium-style transition (observation, prev_observation, action, done) alongside the existing trial_index / value feedback to update_policy. - GridSearchAgent.select_action updated to the new signature. --- src/cloudai/cli/handlers.py | 48 +++++++++++++++++++++---- src/cloudai/configurator/base_agent.py | 12 +++++-- src/cloudai/configurator/grid_search.py | 9 +++-- tests/test_handlers.py | 4 +-- 4 files changed, 60 insertions(+), 13 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 0284fcd9e..e3ad1d437 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -20,7 +20,7 @@ import signal from contextlib import contextmanager from pathlib import Path -from typing import Callable, List, Optional +from typing import Any, Callable, List, Optional from unittest.mock import Mock import toml @@ -118,6 +118,32 @@ def prepare_installation( return installables, installer +def _build_step_feedback( + *, + step: int, + action: dict[str, Any], + reward: float, + observation: list[float], + prev_observation: list[float], + done: bool, +) -> dict[str, Any]: + """ + Assemble the per-step feedback dict consumed by ``BaseAgent.update_policy``. + + ``trial_index`` / ``value`` preserve the historical contract used by Bayesian + optimisation and grid search; the remaining keys expose the gymnasium-style + transition for observation-conditioned agents. + """ + return { + "trial_index": step, + "value": reward, + "observation": observation, + "prev_observation": prev_observation, + "action": action, + "done": done, + } + + def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: registry = Registry() @@ -157,16 +183,26 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: agent = agent_class(env, agent_config) - for step in range(agent.max_steps): - result = agent.select_action() + observation, _ = env.reset() + for _ in range(agent.max_steps): + result = agent.select_action(observation=observation) if result is None: break step, action = result env.test_run.step = step logging.info(f"Running step {step} (of {agent.max_steps}) with action {action}") - observation, reward, *_ = env.step(action) - feedback = {"trial_index": step, "value": reward} - agent.update_policy(feedback) + prev_observation = observation + observation, reward, done, *_ = env.step(action) + agent.update_policy( + _build_step_feedback( + step=step, + action=action, + reward=reward, + observation=observation, + prev_observation=prev_observation, + done=done, + ) + ) logging.info(f"Step {step}: Observation: {[round(obs, 4) for obs in observation]}, Reward: {reward:.4f}") if args.mode == "run": diff --git a/src/cloudai/configurator/base_agent.py b/src/cloudai/configurator/base_agent.py index f7fafbd99..630096823 100644 --- a/src/cloudai/configurator/base_agent.py +++ b/src/cloudai/configurator/base_agent.py @@ -15,12 +15,14 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Literal +from typing import Any, Dict, Literal, Optional from pydantic import BaseModel, ConfigDict, Field from .base_gym import BaseGym +Observation = list[float] + class RewardOverrides(BaseModel): """Optional reward and observation overrides for the agent.""" @@ -87,10 +89,16 @@ def configure(self, config: dict[str, Any]) -> None: pass @abstractmethod - def select_action(self) -> tuple[int, dict[str, Any]]: + def select_action(self, observation: Optional[Observation] = None) -> tuple[int, dict[str, Any]]: """ Select an action from the action space. + Args: + observation: Latest observation produced by the environment (``env.reset()`` on the + first call, then the result of the prior ``env.step()``). Stateless agents such + as grid search or Bayesian optimization may ignore this; observation-conditioned + agents (RL, contextual bandits) should use it. + Returns: Tuple[int, Dict[str, Any]]: The current step index and a dictionary mapping action keys to selected values. """ diff --git a/src/cloudai/configurator/grid_search.py b/src/cloudai/configurator/grid_search.py index 631660ca4..05c2275e1 100644 --- a/src/cloudai/configurator/grid_search.py +++ b/src/cloudai/configurator/grid_search.py @@ -15,9 +15,9 @@ # limitations under the License. import itertools -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple -from .base_agent import BaseAgent, BaseAgentConfig +from .base_agent import BaseAgent, BaseAgentConfig, Observation from .cloudai_gym import CloudAIGymEnv @@ -71,14 +71,17 @@ def get_all_combinations(self) -> List[Dict[str, Any]]: keys = list(self.action_space.keys()) return [dict(zip(keys, combination, strict=True)) for combination in self.action_combinations] - def select_action(self) -> Tuple[int, Dict[str, Any]]: + def select_action(self, observation: Optional[Observation] = None) -> Tuple[int, Dict[str, Any]]: """ Select the next action from the grid. + Grid search is stateless and does not consume the observation. + Returns: Tuple[int, Dict[str, Any]]: The current step and a dictionary mapping action keys to selected values. """ + del observation action = dict(zip(self.action_space.keys(), self.action_combinations[self.index], strict=True)) self.index += 1 step = self.index diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 5124186c0..a3cd2a62d 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -16,7 +16,7 @@ import argparse from pathlib import Path -from typing import Any, ClassVar, Iterator +from typing import Any, ClassVar, Iterator, Optional from unittest.mock import MagicMock import pandas as pd @@ -60,7 +60,7 @@ def get_config_class() -> type[StubAgentConfig]: def configure(self, config: dict[str, Any]) -> None: raise NotImplementedError - def select_action(self) -> tuple[int, dict[str, Any]]: + def select_action(self, observation: Optional[list[float]] = None) -> tuple[int, dict[str, Any]]: raise NotImplementedError def update_policy(self, _feedback: dict[str, Any]) -> None: From 678738180d4c86f5ab7f6b8d6791ff0e5ff09112 Mon Sep 17 00:00:00 2001 From: Rutayan Patro Date: Mon, 18 May 2026 12:09:15 -0400 Subject: [PATCH 2/2] review: address @podkidyshev feedback - Inline _build_step_feedback as a dict literal at the call site in handle_dse_job; the named function added indirection without value. - Drop the Observation type alias from base_agent.py; use list[float] | None directly so the type is self-evident at every call site. - Use PEP 604 X | None notation instead of Optional[X] on touched lines. --- src/cloudai/cli/handlers.py | 44 +++++-------------------- src/cloudai/configurator/base_agent.py | 6 ++-- src/cloudai/configurator/grid_search.py | 6 ++-- tests/test_handlers.py | 4 +-- 4 files changed, 16 insertions(+), 44 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index e3ad1d437..e500b4dac 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -20,7 +20,7 @@ import signal from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, List, Optional +from typing import Callable, List, Optional from unittest.mock import Mock import toml @@ -118,32 +118,6 @@ def prepare_installation( return installables, installer -def _build_step_feedback( - *, - step: int, - action: dict[str, Any], - reward: float, - observation: list[float], - prev_observation: list[float], - done: bool, -) -> dict[str, Any]: - """ - Assemble the per-step feedback dict consumed by ``BaseAgent.update_policy``. - - ``trial_index`` / ``value`` preserve the historical contract used by Bayesian - optimisation and grid search; the remaining keys expose the gymnasium-style - transition for observation-conditioned agents. - """ - return { - "trial_index": step, - "value": reward, - "observation": observation, - "prev_observation": prev_observation, - "action": action, - "done": done, - } - - def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: registry = Registry() @@ -194,14 +168,14 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: prev_observation = observation observation, reward, done, *_ = env.step(action) agent.update_policy( - _build_step_feedback( - step=step, - action=action, - reward=reward, - observation=observation, - prev_observation=prev_observation, - done=done, - ) + { + "trial_index": step, + "value": reward, + "observation": observation, + "prev_observation": prev_observation, + "action": action, + "done": done, + } ) logging.info(f"Step {step}: Observation: {[round(obs, 4) for obs in observation]}, Reward: {reward:.4f}") diff --git a/src/cloudai/configurator/base_agent.py b/src/cloudai/configurator/base_agent.py index 630096823..51b3f01eb 100644 --- a/src/cloudai/configurator/base_agent.py +++ b/src/cloudai/configurator/base_agent.py @@ -15,14 +15,12 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Literal, Optional +from typing import Any, Dict, Literal from pydantic import BaseModel, ConfigDict, Field from .base_gym import BaseGym -Observation = list[float] - class RewardOverrides(BaseModel): """Optional reward and observation overrides for the agent.""" @@ -89,7 +87,7 @@ def configure(self, config: dict[str, Any]) -> None: pass @abstractmethod - def select_action(self, observation: Optional[Observation] = None) -> tuple[int, dict[str, Any]]: + def select_action(self, observation: list[float] | None = None) -> tuple[int, dict[str, Any]]: """ Select an action from the action space. diff --git a/src/cloudai/configurator/grid_search.py b/src/cloudai/configurator/grid_search.py index 05c2275e1..dda4ec308 100644 --- a/src/cloudai/configurator/grid_search.py +++ b/src/cloudai/configurator/grid_search.py @@ -15,9 +15,9 @@ # limitations under the License. import itertools -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple -from .base_agent import BaseAgent, BaseAgentConfig, Observation +from .base_agent import BaseAgent, BaseAgentConfig from .cloudai_gym import CloudAIGymEnv @@ -71,7 +71,7 @@ def get_all_combinations(self) -> List[Dict[str, Any]]: keys = list(self.action_space.keys()) return [dict(zip(keys, combination, strict=True)) for combination in self.action_combinations] - def select_action(self, observation: Optional[Observation] = None) -> Tuple[int, Dict[str, Any]]: + def select_action(self, observation: list[float] | None = None) -> Tuple[int, Dict[str, Any]]: """ Select the next action from the grid. diff --git a/tests/test_handlers.py b/tests/test_handlers.py index a3cd2a62d..dc1eaa57b 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -16,7 +16,7 @@ import argparse from pathlib import Path -from typing import Any, ClassVar, Iterator, Optional +from typing import Any, ClassVar, Iterator from unittest.mock import MagicMock import pandas as pd @@ -60,7 +60,7 @@ def get_config_class() -> type[StubAgentConfig]: def configure(self, config: dict[str, Any]) -> None: raise NotImplementedError - def select_action(self, observation: Optional[list[float]] = None) -> tuple[int, dict[str, Any]]: + def select_action(self, observation: list[float] | None = None) -> tuple[int, dict[str, Any]]: raise NotImplementedError def update_policy(self, _feedback: dict[str, Any]) -> None: