Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,6 @@ CLAUDE.MD
htmlcov/
.coverage
.coverage.*

# Created from simulation
MUJOCO_LOG.TXT
Comment thread
Dreamsorcerer marked this conversation as resolved.
98 changes: 98 additions & 0 deletions dimos/core/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from dimos.core.module_coordinator import ModuleCoordinator
from dimos.core.stream import In, Out
from dimos.core.transport import LCMTransport, PubSubTransport, pLCMTransport
from dimos.protocol.pubsub.impl.lcmpubsub import LCM
from dimos.record.record_replay import RecordReplay
from dimos.spec.utils import Spec, is_spec, spec_annotation_compliance, spec_structural_compliance
from dimos.utils.generic import short_id
from dimos.utils.logging_config import setup_logger
Expand Down Expand Up @@ -442,6 +444,79 @@ def _connect_module_refs(self, module_coordinator: ModuleCoordinator) -> None:
setattr(base_module_proxy, module_ref_name, proxy)
base_module_proxy.set_module_ref(module_ref_name, cast("Any", proxy))

def replay(
self,
recording: RecordReplay | str,
*,
speed: float = 1.0,
cli_config_overrides: Mapping[str, Any] | None = None,
) -> ModuleCoordinator:
"""Build the blueprint with a recording providing some module outputs.

Modules whose OUT streams are fully covered by the recording are
disabled — their data comes from the recording instead. All other
modules run normally.

Args:
recording: A :class:`RecordReplay` instance, or a str
to a ``.db`` recording file.
speed: Playback speed multiplier (1.0 = realtime).
cli_config_overrides: Extra global config overrides.

Returns:
The running :class:`ModuleCoordinator`.
"""
if isinstance(recording, str):
recording = RecordReplay(recording)

recorded_streams = set(recording.store.list_streams())
if not recorded_streams:
raise ValueError("Recording is empty — no streams to replay")

# Find modules whose OUTs overlap with the recording.
# If ANY OUTs are covered, disable the module — the recording
# replaces it. Uncovered OUTs (e.g. on SHM, or never published)
# are simply absent during replay; downstream modules that need
# them won't receive data, which is the expected degraded mode.
modules_to_disable: list[type[ModuleBase]] = []
for bp in self.blueprints:
out_names = {conn.name for conn in bp.streams if conn.direction == "out"}
if not out_names:
continue
covered = out_names & recorded_streams
if covered:
modules_to_disable.append(bp.module)
uncovered = out_names - covered
if uncovered:
logger.warning(
"Replay: disabling %s (partial coverage: replaying %s, missing %s)",
bp.module.__name__,
covered,
uncovered,
)
else:
logger.info(
"Replay: disabling %s (all OUTs covered)",
bp.module.__name__,
)

if not modules_to_disable:
logger.warning(
"Replay: no modules disabled — recording streams %s "
"don't match any module OUT names",
recorded_streams,
)

patched = self.disabled_modules(*modules_to_disable)
coordinator = patched.build(cli_config_overrides)

# Start playback in background — publishes to LCM so other modules receive data
lcm = LCM()
lcm.start()
recording.play(pubsub=lcm, speed=speed)

return coordinator

def build(
self,
cli_config_overrides: Mapping[str, Any] | None = None,
Expand All @@ -451,6 +526,29 @@ def build(
if cli_config_overrides:
global_config.update(**dict(cli_config_overrides))

# Auto-replay if --replay-file is set in global config
replay_file = global_config.replay_file
if replay_file:
logger.info("Auto-replay from %s", replay_file)
# Strip replay_file from all override sources so the nested
# build() inside replay() does not re-enter this branch.
global_config.replay_file = None
clean_cli = (
{k: v for k, v in cli_config_overrides.items() if k != "replay_file"}
if cli_config_overrides
else None
)
clean_bp = replace(
self,
global_config_overrides=MappingProxyType(
{k: v for k, v in self.global_config_overrides.items() if k != "replay_file"}
),
)
return clean_bp.replay(
replay_file,
cli_config_overrides=clean_cli,
)

self._run_configurators()
self._check_requirements()
self._verify_no_name_conflicts()
Expand Down
1 change: 1 addition & 0 deletions dimos/core/global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class GlobalConfig(BaseSettings):
simulation: bool = False
replay: bool = False
replay_dir: str = "go2_sf_office"
replay_file: str | None = None
new_memory: bool = False
viewer: ViewerBackend = "rerun"
n_workers: int = 2
Expand Down
8 changes: 4 additions & 4 deletions dimos/core/module_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

import threading
import asyncio
from typing import TYPE_CHECKING, Any, TypeAlias

from dimos.core.global_config import GlobalConfig, global_config
Expand Down Expand Up @@ -162,10 +162,10 @@ def start_all_modules(self) -> None:
def get_instance(self, module: type[ModuleBase]) -> ModuleProxy:
return self._deployed_modules.get(module) # type: ignore[return-value, no-any-return]

def loop(self) -> None:
stop = threading.Event()
async def loop(self) -> None:
Comment thread
leshy marked this conversation as resolved.
Outdated
stop = asyncio.Event()
try:
stop.wait()
await stop.wait()
except KeyboardInterrupt:
return
finally:
Expand Down
12 changes: 12 additions & 0 deletions dimos/memory2/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

from contextlib import suppress
from dataclasses import replace
from typing import TYPE_CHECKING, Any, Generic, TypeVar

Expand Down Expand Up @@ -220,6 +221,17 @@ def _iterate_live(
finally:
sub.dispose()

def delete_range(self, t1: float, t2: float) -> int:
"""Delete observations in [t1, t2] from all stores. Returns count deleted."""
ids = self.metadata_store.delete_range(t1, t2)
for obs_id in ids:
if self.blob_store is not None:
with suppress(KeyError):
self.blob_store.delete(self.name, obs_id)
if self.vector_store is not None:
self.vector_store.delete(self.name, obs_id)
return len(ids)

def count(self, query: StreamQuery) -> int:
if query.search_vec:
return sum(1 for _ in self.iterate(query))
Expand Down
4 changes: 2 additions & 2 deletions dimos/memory2/codecs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import importlib
from typing import Any, Protocol, TypeVar, runtime_checkable

from dimos.msgs.sensor_msgs.Image import Image

T = TypeVar("T")


Expand All @@ -33,8 +35,6 @@ def codec_for(payload_type: type[Any] | None = None) -> Codec[Any]:
from dimos.memory2.codecs.pickle import PickleCodec

if payload_type is not None:
from dimos.msgs.sensor_msgs.Image import Image

if issubclass(payload_type, Image):
from dimos.memory2.codecs.jpeg import JpegCodec

Expand Down
4 changes: 4 additions & 0 deletions dimos/memory2/observationstore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,9 @@ def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]:
"""Batch fetch by id (for vector search results)."""
...

def delete_range(self, t1: float, t2: float) -> list[int]:
"""Delete observations with ts in [t1, t2]. Returns deleted IDs."""
raise NotImplementedError
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not make it an @abstractmethod?


def serialize(self) -> dict[str, Any]:
return {"class": qual(type(self)), "config": self.config.model_dump()}
8 changes: 8 additions & 0 deletions dimos/memory2/observationstore/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,11 @@ def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]:
id_set = set(ids)
with self._lock:
return [obs for obs in self._observations if obs.id in id_set]

def delete_range(self, t1: float, t2: float) -> list[int]:
"""Delete observations with ts in [t1, t2]. Returns deleted IDs."""
with self._lock:
to_delete = [obs for obs in self._observations if t1 <= obs.ts <= t2]
ids = [obs.id for obs in to_delete]
self._observations = [obs for obs in self._observations if not (t1 <= obs.ts <= t2)]
return ids
16 changes: 16 additions & 0 deletions dimos/memory2/observationstore/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,5 +440,21 @@ def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]:
rows = self._conn.execute(sql, ids).fetchall()
return [self._row_to_obs(r, has_blob=join) for r in rows]

def delete_range(self, t1: float, t2: float) -> list[int]:
"""Delete observations with ts in [t1, t2]. Returns deleted IDs."""
with self._lock:
rows = self._conn.execute(
f'SELECT id FROM "{self._name}" WHERE ts >= ? AND ts <= ?', (t1, t2)
).fetchall()
ids = [r[0] for r in rows]
if ids:
placeholders = ",".join("?" * len(ids))
self._conn.execute(f'DELETE FROM "{self._name}" WHERE id IN ({placeholders})', ids)
self._conn.execute(
f'DELETE FROM "{self._name}_rtree" WHERE id IN ({placeholders})', ids
)
self._conn.commit()
Comment on lines +445 to +455
Copy link
Copy Markdown
Contributor

@paul-nechifor paul-nechifor Apr 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use DELETE-RETURNING and use 2 executes instead of 3 (i.e. merge the SELECT and DELETE into a single statement).

rows = self._conn.execute(
    f'DELETE FROM "{self._name}" WHERE ts >= ? AND ts <= ? RETURNING id',
    (t1, t2),
).fetchall()

return ids

def stop(self) -> None:
super().stop()
6 changes: 6 additions & 0 deletions dimos/memory2/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,12 @@ def subscribe(
on_completed=on_completed,
)

def delete_range(self, t1: float, t2: float) -> int:
"""Delete all observations with timestamps in [t1, t2]. Returns count deleted."""
if isinstance(self._source, Stream):
raise TypeError("Cannot delete from a transform stream.")
return self._source.delete_range(t1, t2)

def append(
self,
payload: T,
Expand Down
5 changes: 2 additions & 3 deletions dimos/protocol/pubsub/impl/lcmpubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from dataclasses import dataclass
import re
import threading
from typing import Any

from dimos.msgs.protocol import DimosMsg
from dimos.protocol.pubsub.encoders import (
Expand Down Expand Up @@ -73,7 +72,7 @@ def from_channel_str(channel: str, default_lcm_type: type[DimosMsg] | None = Non
return Topic(topic=topic_str, lcm_type=lcm_type or default_lcm_type)


class LCMPubSubBase(LCMService, AllPubSub[Topic, Any]):
class LCMPubSubBase(LCMService, AllPubSub[Topic, bytes]):
"""LCM-based PubSub with native regex subscription support.

LCM natively supports regex patterns in subscribe(), so we implement
Expand All @@ -92,7 +91,7 @@ def publish(self, topic: Topic | str, message: bytes) -> None:
topic_str = str(topic) if isinstance(topic, Topic) else topic
self.l.publish(topic_str, message)

def subscribe_all(self, callback: Callable[[bytes, Topic], Any]) -> Callable[[], None]:
def subscribe_all(self, callback: Callable[[bytes, Topic], None]) -> Callable[[], None]:
Comment thread
leshy marked this conversation as resolved.
return self.subscribe(Topic(re.compile(".*")), callback) # type: ignore[arg-type]

def subscribe(
Expand Down
2 changes: 1 addition & 1 deletion dimos/protocol/pubsub/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,6 @@ class SubscribeAllCapable(Protocol[MsgT_co, TopicT_co]):
Both AllPubSub (native) and DiscoveryPubSub (synthesized) satisfy this.
"""

def subscribe_all(self, callback: Callable[[Any, Any], Any]) -> Callable[[], None]:
def subscribe_all(self, callback: Callable[[MsgT_co, TopicT_co], None]) -> Callable[[], None]:
"""Subscribe to all messages on all topics."""
...
17 changes: 17 additions & 0 deletions dimos/record/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dimos.record.record_replay import RecordReplay

__all__ = ("RecordReplay",)
Loading
Loading