Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,26 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:

agent = agent_class(env, agent_config)

for step in range(agent.max_steps):
result = agent.select_action()
observation, _ = env.reset()
for _ in range(agent.max_steps):
result = agent.select_action(observation=observation)
if result is None:
break
step, action = result
env.test_run.step = step
logging.info(f"Running step {step} (of {agent.max_steps}) with action {action}")
observation, reward, *_ = env.step(action)
feedback = {"trial_index": step, "value": reward}
agent.update_policy(feedback)
prev_observation = observation
observation, reward, done, *_ = env.step(action)
agent.update_policy(
{
"trial_index": step,
"value": reward,
"observation": observation,
"prev_observation": prev_observation,
"action": action,
"done": done,
}
)
logging.info(f"Step {step}: Observation: {[round(obs, 4) for obs in observation]}, Reward: {reward:.4f}")

if args.mode == "run":
Expand Down
8 changes: 7 additions & 1 deletion src/cloudai/configurator/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,16 @@ def configure(self, config: dict[str, Any]) -> None:
pass

@abstractmethod
def select_action(self) -> tuple[int, dict[str, Any]]:
def select_action(self, observation: list[float] | None = None) -> tuple[int, dict[str, Any]]:
"""
Select an action from the action space.

Args:
observation: Latest observation produced by the environment (``env.reset()`` on the
first call, then the result of the prior ``env.step()``). Agents that do not condition
on the observation, such as grid search or Bayesian optimization, may ignore this;
observation-conditioned agents (RL, contextual bandits) should use it.

Returns:
Tuple[int, Dict[str, Any]]: The current step index and a dictionary mapping action keys to selected values.
"""
Expand Down
5 changes: 4 additions & 1 deletion src/cloudai/configurator/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,17 @@ def get_all_combinations(self) -> List[Dict[str, Any]]:
keys = list(self.action_space.keys())
return [dict(zip(keys, combination, strict=True)) for combination in self.action_combinations]

def select_action(self) -> Tuple[int, Dict[str, Any]]:
def select_action(self, observation: list[float] | None = None) -> Tuple[int, Dict[str, Any]]:
"""
Select the next action from the grid.

Grid search does not condition on the observation, so the argument is ignored.

Returns:
Tuple[int, Dict[str, Any]]: The current step and a dictionary mapping action keys to selected
values.
"""
del observation
action = dict(zip(self.action_space.keys(), self.action_combinations[self.index], strict=True))
self.index += 1
step = self.index
Expand Down
2 changes: 1 addition & 1 deletion tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_config_class() -> type[StubAgentConfig]:
def configure(self, config: dict[str, Any]) -> None:
raise NotImplementedError

def select_action(self) -> tuple[int, dict[str, Any]]:
def select_action(self, observation: list[float] | None = None) -> tuple[int, dict[str, Any]]:
raise NotImplementedError

def update_policy(self, _feedback: dict[str, Any]) -> None:
Expand Down
Loading