diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f471fa46..9e9ca9fe 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,7 @@ jobs: # Upgrade pip/setuptools/wheel - name: Upgrade pip and build tools run: | - python -m pip install --upgrade pip setuptools wheel + python -m pip install --upgrade pip setuptools wheel scikit-build-core pybind11 # Pin NumPy to 1.x to avoid compatibility issues - name: Install compatible NumPy diff --git a/.github/workflows/test_notebooks.yml b/.github/workflows/test_notebooks.yml index 60aaf555..5702ac99 100644 --- a/.github/workflows/test_notebooks.yml +++ b/.github/workflows/test_notebooks.yml @@ -33,7 +33,7 @@ jobs: # Upgrade pip/setuptools/wheel - name: Upgrade pip and build tools run: | - python -m pip install --upgrade pip setuptools wheel + python -m pip install --upgrade pip setuptools wheel scikit-build-core pybind11 # Pin NumPy to 1.x to avoid compatibility issues - name: Install compatible NumPy diff --git a/.gitignore b/.gitignore index 880d2581..3e290cae 100644 --- a/.gitignore +++ b/.gitignore @@ -209,4 +209,5 @@ marimo/_static/ marimo/_lsp/ __marimo__/ config-local/ -data/ \ No newline at end of file +data/ +.DS_Store diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 8804e813..00000000 --- a/CLAUDE.md +++ /dev/null @@ -1,100 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Installation and Setup - -This repository requires specific CUDA dependencies and nerfstudio integration. Use the provided setup script: - -```bash -# Install the package in development mode -bash setup.sh -``` - -For Docker-based development: -```bash -# Uses docker image: tommybotch/collab-splats:latest -git clone https://github.com/BasisResearch/collab-splats/ -cd collab-splats -bash setup.sh -``` - -## Architecture Overview - -**collab-splats** is a nerfstudio extension that enables depth/normal derivation and meshing for Gaussian Splatting models. The codebase is structured around two main architectural patterns: - -### Core Components - -1. **Models** (`collab_splats/models/`): - - `rade_gs_model.py`: Baseline depth/normal-enabled Gaussian splatting built on gsplat-rade - - `rade_features_model.py`: Extended version supporting ANN feature space splatting - -2. **Wrapper Interface** (`collab_splats/wrapper/splatter.py`): - - `Splatter` class: High-level interface for preprocessing, training, and visualization - - `SplatterConfig`: Configuration system for different splatting workflows - - Supports methods: `splatfacto`, `feature-splatting`, `rade-gs`, `rade-features` - -3. **Data Management** (`collab_splats/datamanagers/`): - - `features_datamanager.py`: Handles feature-based data loading and processing - -4. **Utilities** (`collab_splats/utils/`): - - `mesh.py`: Post-processing meshing functionality - - `segmentation.py` + `grouping.py`: Gaussian grouping and segmentation tools - - `visualization.py`: PyVista-based 3D visualization - - `camera_utils.py`: COLMAP camera integration - -### NerfStudio Integration - -The package registers two method configs with nerfstudio: -- `rade-gs`: Entry point in `collab_splats.configs.rade_gs_method:rade_gs_method` -- `rade-features`: Entry point in `collab_splats.configs.rade_features_method:rade_features_method` - -### Dependencies - -Key external dependencies: -- **gsplat-rade**: Custom CUDA kernels for depth/normal rasterization -- **meshlib**: 3D mesh processing (pinned to v3.0.6.229) -- **mobile_sam**: Segmentation backend -- **nerfstudio**: Base framework integration - -## Development Commands - -### Code Formatting -```bash -# Format code with black (line length: 120) -black . - -# Sort imports -isort . -``` - -### Testing -Tests are primarily notebook-based in `tests/` directory: -- `test_rade_gs.ipynb`: Model testing -- `test_grouping.ipynb`: Gaussian grouping functionality -- `test_meshing.ipynb`: Mesh generation testing - -Run Python tests directly: -```bash -python tests/test_grouping.py -``` - -### Example Workflows -Key examples in `examples/` directory: -- `derive_splats.ipynb`: Basic splatting pipeline -- `create_mesh.ipynb`: Mesh generation from splats -- `visualization.ipynb`: 3D visualization workflows -- `run_pipeline.py`: Batch processing script - -## Key Configuration Patterns - -The `SplatterConfig` TypedDict defines the main configuration interface: -- `file_path`: Input data path (video, images, etc.) -- `method`: Processing method selection -- `output_path`: Optional output directory (defaults to input parent) -- `frame_proportion`: Video frame sampling rate -- `overwrite`: Force reprocessing flag - -## Troubleshooting - -For visualization issues (plots not showing), check VSCode port forwarding settings as noted in the README. \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index b2eaeb83..846a214b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -128,58 +128,6 @@ RUN mkdir -p /opt/data && \ rclone copy collab-data:fieldwork_processed/2024_02_06-session_0001/SplatsSD/C0043.MP4 /opt/data/ && \ rm -f /tmp/api-key.json && \ rm -rf ~/.config/rclone -# # Build everything in conda environment --> last step is to install buildtools -# RUN /bin/bash -c "source /opt/conda/etc/profile.d/conda.sh && \ -# conda env create -n nerfstudio -f /tmp/env.yml && \ -# conda activate nerfstudio && \ - -# # Hack to install our version of rade_gs atm -# export CC=/usr/bin/gcc-11 && \ -# export CXX=/usr/bin/g++-11 && \ -# export CUDA_HOME=/opt/conda/envs/nerfstudio && \ -# export PATH=\${CUDA_HOME}/bin:\${PATH} && \ -# export LD_LIBRARY_PATH=\${CUDA_HOME}/lib64:\${LD_LIBRARY_PATH} && \ - -# # Install torch and cuda toolkit -# pip install torch==2.1.2+cu118 torchvision==0.16.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 && \ -# conda install -c 'nvidia/label/cuda-11.8.0' cuda-toolkit -y && \ -# pip install 'kornia>=0.6.11' && \ - -# # Install hloc -# git clone --branch master --recursive https://github.com/cvg/Hierarchical-Localization.git /opt/hloc && \ -# cd /opt/hloc && \ -# git checkout v1.4 && \ -# git submodule update --init --recursive && \ -# pip install -e . --no-cache-dir && \ -# cd ~ && \ - -# # Bump down for hloc interface -# pip install --no-cache-dir pycolmap==0.4.0 && \ - -# # Now bump back down to numpy 1.26.4 -# conda install -c conda-forge setuptools==69.5.1 'numpy<2.0.0' && \ - -# # Install tiny-cuda-nn -# pip install -v ninja git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch && \ -# export TORCH_CUDA_ARCH_LIST=\"\$(echo \"${CUDA_ARCHITECTURES}\" | tr ';' '\n' | awk '\$0 > 70 {print substr(\$0,1,1)\".\"substr(\$0,2)}' | tr '\n' ' ' | sed 's/ \$//')\" && \ - -# # Install gsplat-rade -# pip install git+https://github.com/brian-xu/gsplat-rade.git && \ - -# # Changing to clone from github (newer features useful) -# git clone https://github.com/nerfstudio-project/nerfstudio.git /opt/nerfstudio && \ -# cd /opt/nerfstudio && \ -# pip install -e . && \ - -# # pip install nerfstudio && \ - -# # Bump the conda version back down --> nerfstudio upgrades for some reason in previous step -# conda install -c conda-forge 'numpy<2.0.0' && \ -# conda install -c conda-forge cmake>3.5 ninja gmp cgal ipykernel && \ -# pip install -r /tmp/requirements.txt" - -# # cd /opt/collab-splats && \ -# # pip install -e ." ################################################## # Get pre-built components # diff --git a/LICENSE b/LICENSE deleted file mode 100644 index bd24d4d3..00000000 --- a/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2025 Basis Research Institute - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index c4d55fc0..227de9a2 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,13 @@ Extension tools for nerfstudio enabling depth/normal derivation and meshing (among other functions) for gaussian splatting. +For more details, see the paper + + +***I. Aitsahalia et al., “Inferring cognitive strategies from groups of animals in natural environments,” presented at the NeurIPS Workshop on Data on the Brain \& Mind Findings, 2025.*** + +To reproduce the figures from the paper, see [`paper_figures.md`](./paper_figures.md) + ## Installation ### Docker @@ -11,20 +18,18 @@ We provide a docker image setup for running nerfstudio with collab-splats (along Once the docker image is loaded, please clone and install the repository as follows ```bash -## If public repository could do -- pip install git+https://github.com/BasisResearch/collab-splats -git clone https://github.com/BasisResearch/collab-splats/ -cd collab-splats - -# This performs pip install -e . -bash setup.sh +pip install git+https://github.com/BasisResearch/collab-splats ``` -For use of gcloud data interfaces, please also install collab-data - +### Standalone setup w/o CUDA ```bash -pip install git+https://github.com/BasisResearch/collab-data.git +git clone https://github.com/BasisResearch/collab-splats/ +cd collab-splats +uv venv --python=3.10 && source .venv/bin/activate && uv pip install pip +bash setup_nocuda.sh ``` + #### Building the docker image The Docker image includes an example video file (`C0043.MP4`) downloaded from Google Cloud Storage during the build process. Follow these steps to build the image: @@ -171,4 +176,4 @@ The Docker image contains an example splat video at `/opt/data/C0043.MP4`. ## Problems -Things aren't showing up in plots? Check [VSCode forwarding settings](https://github.com/pyvista/pyvista/issues/5296#issuecomment-1971079419) \ No newline at end of file +Things aren't showing up in plots? Check [VSCode forwarding settings](https://github.com/pyvista/pyvista/issues/5296#issuecomment-1971079419) diff --git a/collab_splats/datamanagers/features_datamanager.py b/collab_splats/datamanagers/features_datamanager.py index 6b3302ff..49868ea3 100644 --- a/collab_splats/datamanagers/features_datamanager.py +++ b/collab_splats/datamanagers/features_datamanager.py @@ -103,11 +103,23 @@ def setup(self) -> Dict[str, torch.Tensor]: # Try loading from cache if enabled if self.config.enable_cache and cache_path.exists(): + CONSOLE.print(f"Found cached features at {cache_path}") cache_dict = torch.load(cache_path) - if cache_dict.get("image_filenames") != image_filenames: - CONSOLE.print("Image filenames have changed, cache invalidated...") + # Normalize paths to strings for comparison (handles Path vs str mismatch) + cached_filenames = [str(p) for p in cache_dict.get("image_filenames", [])] + current_filenames = [str(p) for p in image_filenames] + + if cached_filenames != current_filenames: + CONSOLE.print( + "[yellow]Cache invalidated: image filenames have changed[/yellow]" + ) + CONSOLE.print( + f" Cached: {len(cached_filenames)} images, " + f"Current: {len(current_filenames)} images" + ) else: + CONSOLE.print("[green]✓ Loading features from cache[/green]") return cache_dict["features_dict"] else: CONSOLE.print("Cache does not exist, extracting features...") diff --git a/collab_splats/utils/__init__.py b/collab_splats/utils/__init__.py index 8d534a50..c49ad61e 100644 --- a/collab_splats/utils/__init__.py +++ b/collab_splats/utils/__init__.py @@ -1,5 +1,6 @@ from .camera_utils import ColmapCamera, convert_to_colmap_camera, depth_double_to_normal from .trainer_config import _TrainerConfig, _ExperimentConfig +from .model_loading import load_checkpoint __all__ = [ "ColmapCamera", @@ -7,4 +8,5 @@ "depth_double_to_normal", "_TrainerConfig", "_ExperimentConfig", + "load_checkpoint", ] diff --git a/collab_splats/utils/mesh.py b/collab_splats/utils/mesh.py index af022a20..29df2b3d 100644 --- a/collab_splats/utils/mesh.py +++ b/collab_splats/utils/mesh.py @@ -224,52 +224,49 @@ def features2vertex(mesh_vertices, points, features, k=5, sdf_trunc=0.03): ######################################################## -def clean_repair_mesh( - mesh_path: str, - max_hole_size: float = 3.0, - max_edge_splits: int = 10000, - use_largest: bool = False, # if True, selects only the largest -): - # Load mesh - mesh = mm.loadMesh(mesh_path) +def _filter_mesh_components(mesh, use_largest=False): + """ + Filter mesh to keep only largest component or components within bounds. - # Identify all connected components - components = mm.getAllComponents(mesh) + Args: + mesh: Input mesh + use_largest: If True, keep only the largest component - # Determine component sizes + Returns: + Filtered mesh and number of removed components + """ + components = mm.getAllComponents(mesh) sizes = [mask.count() for mask in components] - - # Always find largest cluster largest_idx = max(range(len(sizes)), key=lambda i: sizes[i]) - # Add the largest component + # Start with largest component combined = mm.Mesh() - combined.addPartByMask(mesh, components[largest_idx]) + mesh_part = mm.MeshPart(mesh, components[largest_idx]) + combined.addMeshPart(mesh_part) - # Remove the largest component from list of idxs + n_removed = 0 if not use_largest: idxs = list(range(len(sizes))) idxs.remove(largest_idx) - - # Add the remaining components if they fall within the bounds combined_bounds = combined.getBoundingBox() - n_removed = 0 - ## THIS IS REALLY HACKY AND INEFFIENCT CHANGE SOMETIME for idx in tqdm(idxs, desc="Finding components within bounds"): _temp = mm.Mesh() - _temp.addPartByMask(mesh, components[idx]) + temp_mesh_part = mm.MeshPart(mesh, components[idx]) + _temp.addMeshPart(temp_mesh_part) if combined_bounds.contains(_temp.getBoundingBox()): - # print (f"Adding component {idx} to combined mesh") - combined.addPartByMask(mesh, components[idx]) + mesh_part = mm.MeshPart(mesh, components[idx]) + combined.addMeshPart(mesh_part) else: n_removed += 1 print(f"Removed {n_removed} components") - mesh = combined + return combined, n_removed + - # Compute average edge length +def _compute_avg_edge_length(mesh): + """Compute average edge length of mesh.""" avg_edge_length = 0.0 num_edges = 0 @@ -283,28 +280,129 @@ def clean_repair_mesh( mesh.points.vec[dest.get()] - mesh.points.vec[org.get()] ).length() num_edges += 1 - avg_edge_length /= num_edges - # Fill holes - hole_ids = mesh.topology.findHoleRepresentiveEdges() + return avg_edge_length / num_edges if num_edges > 0 else 0.0 + + +def _fill_holes_advanced( + mesh, hole_ids, max_hole_size, avg_edge_length, max_edge_splits +): + """Fill holes using advanced fillHoleNicely method (MeshLib 3.0.7+).""" + print(f"Filling {len(hole_ids)} holes using advanced method...") + for he in tqdm(hole_ids, desc=f"Filling holes ({len(hole_ids)})"): + try: + perimeter = mesh.holePerimiter(he) + if perimeter < max_hole_size: + settings = mm.FillHoleNicelySettings() + settings.maxEdgeLen = avg_edge_length + settings.maxEdgeSplits = max_edge_splits + + # Set metric (required in newer versions) + try: + settings.metric = mm.getUniversalMetric(mesh) + except Exception: + pass + + mm.fillHoleNicely(mesh, he, settings) + else: + print(f"Skipping hole {he} of perimeter {perimeter}") + except Exception as e: + print(f"Warning: Failed to fill hole {he}: {e}") + try: + mm.fillHoleTrivially(mesh, he) + except Exception as e2: + print(f"Warning: Fallback fill also failed: {e2}") + + +def _fill_holes_standard( + mesh, hole_ids, max_hole_size, avg_edge_length, max_edge_splits +): + """Fill holes using standard fillHole method with subdivision.""" fill_params = mm.FillHoleParams() + # Set metric if available (required in newer versions) + try: + fill_params.metric = mm.getUniversalMetric(mesh) + except Exception: + pass + for he in tqdm(hole_ids, desc=f"Filling holes ({len(hole_ids)})"): - if mesh.holePerimiter(he) < max_hole_size: - new_faces = mm.FaceBitSet() - fill_params.outNewFaces = new_faces - mm.fillHole(mesh, he, fill_params) - - new_verts = mm.VertBitSet() - subdiv_settings = mm.SubdivideSettings() - subdiv_settings.maxEdgeLen = avg_edge_length - subdiv_settings.maxEdgeSplits = max_edge_splits - subdiv_settings.region = new_faces - subdiv_settings.newVerts = new_verts - mm.subdivideMesh(mesh, subdiv_settings) - mm.positionVertsSmoothly(mesh, new_verts) - else: - print(f"Skipping hole {he} of perimeter {mesh.holePerimiter(he)}") + try: + perimeter = mesh.holePerimiter(he) + if perimeter < max_hole_size: + new_faces = mm.FaceBitSet() + fill_params.outNewFaces = new_faces + mm.fillHole(mesh, he, fill_params) + + # Subdivide and smooth new faces + new_verts = mm.VertBitSet() + subdiv_settings = mm.SubdivideSettings() + subdiv_settings.maxEdgeLen = avg_edge_length + subdiv_settings.maxEdgeSplits = max_edge_splits + subdiv_settings.region = new_faces + subdiv_settings.newVerts = new_verts + + # Set smooth mode if available + try: + subdiv_settings.smoothMode = mm.SubdivideSettings.SmoothMode.Linear + except (AttributeError, Exception): + pass + + mm.subdivideMesh(mesh, subdiv_settings) + mm.positionVertsSmoothly(mesh, new_verts) + else: + print(f"Skipping hole {he} of perimeter {perimeter}") + except Exception as e: + print(f"Warning: Failed to fill hole {he}: {e}") + + +def clean_repair_mesh( + mesh_path: str, + max_hole_size: float = 3.0, + max_edge_splits: int = 10000, + use_largest: bool = False, + use_advanced_fill: bool = True, +): + """ + Clean and repair mesh with component filtering and hole filling. + Compatible with MeshLib 3.0.6+ and 3.0.9+ + + Args: + mesh_path: Path to the mesh file + max_hole_size: Maximum hole perimeter to fill + max_edge_splits: Maximum number of edge splits for subdivision + use_largest: If True, only keep the largest component + use_advanced_fill: If True, use fillHoleNicely (recommended for MeshLib 3.0.7+) + + Returns: + Cleaned and repaired mesh + """ + # Load mesh + mesh = mm.loadMesh(mesh_path) + + # Filter components + mesh, _ = _filter_mesh_components(mesh, use_largest) + + # Compute average edge length for subdivision + avg_edge_length = _compute_avg_edge_length(mesh) + + # Fill holes + hole_ids = mesh.topology.findHoleRepresentiveEdges() + + if use_advanced_fill: + try: + _fill_holes_advanced( + mesh, hole_ids, max_hole_size, avg_edge_length, max_edge_splits + ) + except Exception as e: + print(f"Advanced fill failed ({e}), falling back to standard method...") + _fill_holes_standard( + mesh, hole_ids, max_hole_size, avg_edge_length, max_edge_splits + ) + else: + _fill_holes_standard( + mesh, hole_ids, max_hole_size, avg_edge_length, max_edge_splits + ) return mesh @@ -1445,13 +1543,14 @@ def main(self): print("Saving splats pointcloud") means = pipeline.model.means.detach().cpu().numpy() colors = pipeline.model.features_dc.detach().cpu().numpy() + normals = pipeline.model.normals.detach().cpu().numpy() colors = SH2RGB(colors) # Pass xyz to Open3D.o3d.geometry.PointCloud and visualize pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(means) pcd.colors = o3d.utility.Vector3dVector(colors) - + pcd.normals = o3d.utility.Vector3dVector(normals) # Transform to specified coordinate system # pcd.transform(world_transform) diff --git a/collab_splats/utils/model_loading.py b/collab_splats/utils/model_loading.py new file mode 100644 index 00000000..b892b8d4 --- /dev/null +++ b/collab_splats/utils/model_loading.py @@ -0,0 +1,35 @@ +"""Utility functions for loading trained nerfstudio models.""" + +from pathlib import Path +from typing import Tuple, Union + + +def load_checkpoint( + config_path: Union[str, Path], + test_mode: str = "inference", +) -> Tuple: + """ + Load a trained nerfstudio model checkpoint. + + Args: + config_path: Path to the config.yml file from a trained model + test_mode: Evaluation mode - "test", "val", or "inference" (default) + + Returns: + Tuple of (config, pipeline, checkpoint_path, step) + + Example: + >>> from collab_splats.utils import load_checkpoint + >>> config, pipeline, ckpt_path, step = load_checkpoint("outputs/scene/rade-gs/config.yml") + >>> model = pipeline.model + >>> outputs = model.get_outputs(camera) + """ + # Import here to avoid circular dependency during plugin discovery + from nerfstudio.utils.eval_utils import eval_setup + + config_path = Path(config_path) + + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + return eval_setup(config_path, test_mode=test_mode) diff --git a/collab_splats/wrapper/__init__.py b/collab_splats/wrapper/__init__.py index 0be47464..6a9a4e9a 100644 --- a/collab_splats/wrapper/__init__.py +++ b/collab_splats/wrapper/__init__.py @@ -1,6 +1,9 @@ from .splatter import Splatter, SplatterConfig +from .config import ConfigLoader, parse_cli_overrides __all__ = [ "SplatterConfig", "Splatter", + "ConfigLoader", + "parse_cli_overrides", ] diff --git a/collab_splats/wrapper/config.py b/collab_splats/wrapper/config.py new file mode 100644 index 00000000..367225a1 --- /dev/null +++ b/collab_splats/wrapper/config.py @@ -0,0 +1,158 @@ +"""Configuration loading utilities for Splatter workflows.""" + +from pathlib import Path +from typing import Dict, Any, Optional, Union +import yaml +from mergedeep import merge + + +class ConfigLoader: + """ + Load and merge hierarchical YAML configurations. + + Supports inheritance with the following priority (highest to lowest): + 1. Runtime overrides (passed programmatically) + 2. Dataset config + 3. Base config + + Each config file should contain all sections (preprocess, training, meshing). + Dataset configs inherit from base and can override any values. + """ + + def __init__(self, config_dir: Union[str, Path]): + """ + Initialize config loader. + + Args: + config_dir: Directory containing: + - base.yaml (default configuration) + - datasets/ (dataset-specific configs) + """ + self.config_dir = Path(config_dir) + if not self.config_dir.exists(): + raise ValueError(f"Config directory not found: {config_dir}") + + base_path = self.config_dir / "base.yaml" + if not base_path.exists(): + raise ValueError(f"base.yaml not found in {config_dir}") + + self.base_config = self._load_yaml(base_path) + + @staticmethod + def _load_yaml(path: Path) -> Dict[str, Any]: + """Load a YAML file.""" + if not path.exists(): + return {} + with open(path, "r") as f: + return yaml.safe_load(f) or {} + + @staticmethod + def _deep_merge(base: Dict, override: Dict) -> Dict: + """ + Deep merge two dictionaries, with override taking precedence. + + Args: + base: Base dictionary + override: Override dictionary + + Returns: + Merged dictionary + """ + # Use mergedeep for clean deep merging + return merge({}, base, override) + + def load( + self, + dataset: str, + overrides: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Load and merge configuration. + + Args: + dataset: Dataset config name (from datasets/ subdirectory) + overrides: Optional dictionary of runtime overrides + + Returns: + Merged configuration dictionary with all sections: + - preprocess: SfM and preprocessing settings + - training: Model training parameters + - meshing: Mesh generation settings + - file_path, method, etc: Top-level splatter settings + + Raises: + ValueError: If dataset config not found + """ + # Start with base + config = self.base_config.copy() + + # Merge dataset config + dataset_path = self.config_dir / "datasets" / f"{dataset}.yaml" + if not dataset_path.exists(): + raise ValueError( + f"Dataset config not found: {dataset_path}\n" + f"Available datasets: {self.list_datasets()}" + ) + dataset_config = self._load_yaml(dataset_path) + config = self._deep_merge(config, dataset_config) + + # Apply runtime overrides + if overrides: + config = self._deep_merge(config, overrides) + + return config + + def list_datasets(self) -> list: + """ + List available dataset configs. + + Returns: + List of available dataset names + """ + datasets_dir = self.config_dir / "datasets" + if not datasets_dir.exists(): + return [] + return sorted([f.stem for f in datasets_dir.glob("*.yaml")]) + + +def parse_cli_overrides(override_strings: list) -> Dict[str, Any]: + """ + Parse command-line override strings into a config dict. + + Supports nested keys using dot notation. + + Args: + override_strings: List of 'key=value' or 'section.key=value' strings + + Returns: + Dictionary of overrides + + Example: + >>> parse_cli_overrides(['method=rade-gs', 'preprocess.sfm_tool=colmap']) + {'method': 'rade-gs', 'preprocess': {'sfm_tool': 'colmap'}} + """ + overrides: Dict[str, Any] = {} + for override in override_strings: + if "=" not in override: + raise ValueError(f"Invalid override: '{override}'. Expected 'key=value'") + + key, value = override.split("=", 1) + + # Type conversion + if value.lower() == "true": + value = True + elif value.lower() == "false": + value = False + elif value.replace(".", "", 1).replace("-", "", 1).isdigit(): + value = float(value) if "." in value else int(value) + + # Handle nested keys (e.g., 'preprocess.sfm_tool=colmap') + keys = key.split(".") + current = overrides + for k in keys[:-1]: + if k not in current: + current[k] = {} + current = current[k] + current[keys[-1]] = value + + return overrides diff --git a/collab_splats/wrapper/splatter.py b/collab_splats/wrapper/splatter.py index 3e086fb7..6c9f9425 100644 --- a/collab_splats/wrapper/splatter.py +++ b/collab_splats/wrapper/splatter.py @@ -64,6 +64,11 @@ def __init__(self, config: SplatterConfig): validated_config = self.validate_config(config) self.config: Dict[str, Any] = dict(validated_config) + # Optional pipeline configs (set by from_config_file) + self._preprocess_config: Optional[Dict[str, Any]] = None + self._training_config: Optional[Dict[str, Any]] = None + self._meshing_config: Optional[Dict[str, Any]] = None + @classmethod def validate_config(cls, config: SplatterConfig) -> SplatterConfig: """Validate the splatter configuration. @@ -129,13 +134,121 @@ def available_methods(cls) -> None: print("Available methods:") print(" ", sorted(cls.SPLATTING_METHODS)) + @classmethod + def from_config_file( + cls, + dataset: str, + config_dir: Union[str, Path], + overrides: Optional[Dict[str, Any]] = None, + ) -> "Splatter": + """ + Create Splatter instance from YAML configuration. + + Args: + dataset: Dataset config name (from datasets/ subdirectory) + config_dir: Directory containing config files (base.yaml and datasets/) + overrides: Optional runtime overrides + + Returns: + Configured Splatter instance with pipeline configs attached + + Example: + >>> splatter = Splatter.from_config_file( + ... dataset='ants_001', + ... config_dir='docs/splats/configs' + ... ) + >>> splatter.run_pipeline(overwrite=True) + """ + from collab_splats.wrapper.config import ConfigLoader + + loader = ConfigLoader(config_dir) + config = loader.load(dataset=dataset, overrides=overrides) + + # Store full config for later use + full_config = config.copy() + + # Extract SplatterConfig fields + splatter_fields: Dict[str, Any] = { + "file_path": config["file_path"], + "method": config["method"], + } + # Add optional fields if present + if "frame_proportion" in config: + splatter_fields["frame_proportion"] = config["frame_proportion"] + if "min_frames" in config: + splatter_fields["min_frames"] = config["min_frames"] + if "output_path" in config: + splatter_fields["output_path"] = config["output_path"] + + splatter_config: SplatterConfig = splatter_fields # type: ignore + instance = cls(splatter_config) + + # Attach configs for pipeline methods + instance._preprocess_config = full_config.get("preprocess", {}) + instance._training_config = full_config.get("training", {}) + instance._meshing_config = full_config.get("meshing", {}) + + return instance + + def run_pipeline(self, overwrite: bool = False) -> None: + """ + Run complete pipeline using stored configurations. + + This method runs preprocessing, training, and meshing using + configurations loaded via from_config_file(). + + Args: + overwrite: Whether to overwrite existing outputs + + Raises: + ValueError: If pipeline configs not found (must use from_config_file) + """ + if self._preprocess_config is None: + raise ValueError( + "Pipeline configs not found. Use Splatter.from_config_file() " + "to load configurations before calling run_pipeline()" + ) + + print(f"\n{'=' * 80}") + print(f"Running {self.config['method']} pipeline") + print(f"File: {Path(self.config['file_path']).name}") + print(f"{'=' * 80}\n") + + # Step 1: Preprocessing + print("[1/3] Preprocessing...") + self.preprocess(kwargs=self._preprocess_config, overwrite=overwrite) + + # Step 2: Training + print("\n[2/3] Training...") + self.extract_features(kwargs=self._training_config, overwrite=overwrite) + + # Step 3: Meshing + print("\n[3/3] Meshing...") + mesher_config = (self._meshing_config or {}).copy() + mesher_type = mesher_config.pop("mesher_type", "Open3DTSDFFusion") + self.mesh( + mesher_type=mesher_type, mesher_kwargs=mesher_config, overwrite=overwrite + ) + + print(f"\n{'=' * 80}") + print("Pipeline complete!") + print(f"{'=' * 80}\n") + def preprocess( - self, overwrite: bool = False, kwargs: Optional[Dict[str, Any]] = None + self, + overwrite: bool = False, + sfm_tool: str = "colmap", + kwargs: Optional[Dict[str, Any]] = None, ) -> None: """Preprocess the data in the splatter. This function handles any necessary data preprocessing steps based on the configured method. + + Args: + overwrite: If True, rerun preprocessing even if transforms.json exists + sfm_tool: Structure from motion tool to use ('colmap', 'hloc') + kwargs: Additional arguments to pass to ns-process-data """ file_path = self.config["file_path"] output_path = self.config["output_path"] @@ -174,6 +287,8 @@ def preprocess( # If we have less than the minimum number of frames, as many as possible n_samples = n_frames if n_samples < self.config["min_frames"] else n_samples + print("Number of frames to sample: ", n_samples) + # Create the command num_frames_target = f"--num-frames-target {n_samples}" else: @@ -185,6 +300,7 @@ def preprocess( f"{input_type} " f"--data {file_path.as_posix()} " f"--output-dir {preproc_data_path.as_posix()} " + f"--sfm-tool {sfm_tool} " f"{num_frames_target} " ) @@ -297,6 +413,52 @@ def _select_run(self) -> None: self.config["model_path"] = selected_run.as_posix() self.config["model_config_path"] = (selected_run / "config.yml").as_posix() + def load_model( + self, + config_path: Optional[Union[str, Path]] = None, + test_mode: str = "inference", + ): + """ + Load a trained nerfstudio model. + + Args: + config_path: Path to config.yml. If None, uses model_config_path from config or prompts selection + test_mode: Evaluation mode - "test", "val", or "inference" (default) + + Returns: + Tuple of (config, pipeline, model) + + Example: + >>> splatter = Splatter(config) + >>> config, pipeline, model = splatter.load_model("outputs/scene/rade-gs/config.yml") + >>> outputs = model.get_outputs(camera) + """ + from collab_splats.utils import load_checkpoint + + # Determine config path + if config_path is None: + if not self.config.get("model_config_path"): + self._select_run() + config_path = self.config["model_config_path"] + + print(f"Loading model from {config_path}") + + # Load using utility function + config, pipeline, checkpoint_path, step = load_checkpoint( + config_path, test_mode=test_mode + ) + + # Store in instance + self.model = pipeline.model + self.pipeline = pipeline + self.model_config = config + self.checkpoint_path = checkpoint_path + self.training_step = step + + print(f"✓ Model loaded: {type(self.model).__name__} (step {step})") + + return config, pipeline, self.model + def mesh( self, mesher_type: str = "Open3DTSDFFusion", @@ -309,7 +471,8 @@ def mesh( """ self._select_run() - mesh_dir = self.config["output_path"] / self.config["method"] / "mesh" + # Save mesh under the selected run directory + mesh_dir = Path(self.config["model_path"]) / "mesh" # Create the mesh if not mesh_dir.exists() or overwrite: @@ -317,6 +480,10 @@ def mesh( print(f"Initializing mesher {mesher_type}") + # Handle None mesher_kwargs + if mesher_kwargs is None: + mesher_kwargs = {} + # Initialize the mesher mesher = getattr(mesh, mesher_type)( load_config=Path(self.config["model_config_path"]), @@ -341,14 +508,8 @@ def query_mesh( ) -> None: """Query the mesh for features.""" - if not self.config.get("model_config_path"): - self._select_run() - elif getattr(self, "model", None) is None: - print(f"Loading model from {self.config['model_config_path']}") - from nerfstudio.utils.eval_utils import eval_setup - - _, pipeline, _, _ = eval_setup(Path(self.config["model_config_path"])) - self.model = pipeline.model + if getattr(self, "model", None) is None: + self.load_model() mesh_info = self.config.get("mesh_info") if mesh_info is None: diff --git a/docs/splats/configs/base.yaml b/docs/splats/configs/base.yaml new file mode 100644 index 00000000..a6951ef8 --- /dev/null +++ b/docs/splats/configs/base.yaml @@ -0,0 +1,41 @@ +# Base configuration for Splatter workflows +# All dataset configs inherit from this and can override any values + +# ============================================================================ +# Splatter Settings +# ============================================================================ +method: rade-features +frame_proportion: 0.25 +min_frames: 100 + +# ============================================================================ +# Preprocessing (Structure from Motion) +# ============================================================================ +preprocess: + sfm_tool: hloc + # refine_pixsfm: true # Optional: Enable pixel-perfect SfM refinement + +# ============================================================================ +# Training (Model & Pipeline Parameters) +# ============================================================================ +training: + pipeline.model.output-depth-during-training: true + pipeline.model.rasterize-mode: antialiased + pipeline.model.use-scale-regularization: true + # pipeline.model.random-scale: 1.0 + # pipeline.model.cull-alpha-thresh: 0.01 + # pipeline.model.collider-params: near_plane 0.1 far_plane 1.0 + +# ============================================================================ +# Meshing (3D Reconstruction) +# ============================================================================ +meshing: + mesher_type: Open3DTSDFFusion + depth_name: median_depth + depth_trunc: 1.0 + voxel_size: 0.01 + normals_name: normals + features_name: distill_features + sdf_trunc: 0.03 + clean_repair: true + align_floor: true diff --git a/docs/splats/configs/datasets/ants_date-11162025_video-GH010210.yaml b/docs/splats/configs/datasets/ants_date-11162025_video-GH010210.yaml new file mode 100644 index 00000000..15e0b990 --- /dev/null +++ b/docs/splats/configs/datasets/ants_date-11162025_video-GH010210.yaml @@ -0,0 +1,9 @@ +# Ants dataset - November 16, 2025 - GH010210 +# Inherits from base.yaml and overrides specific values + +file_path: /workspace/fieldwork-data/ants/2025-11-16/SplatsSD/GH010210.MP4 +frame_proportion: 0.08 + +# Optional: Override training parameters for this specific dataset +# training: +# pipeline.model.random-scale: 1.5 diff --git a/docs/splats/configs/datasets/birds_date-02062024_video-C0043.yaml b/docs/splats/configs/datasets/birds_date-02062024_video-C0043.yaml new file mode 100644 index 00000000..80f1275f --- /dev/null +++ b/docs/splats/configs/datasets/birds_date-02062024_video-C0043.yaml @@ -0,0 +1,4 @@ +# Birds dataset - February 6, 2024 - C0043 + +file_path: /workspace/fieldwork-data/birds/2024-02-06/SplatsSD/C0043.MP4 +frame_proportion: 0.25 diff --git a/docs/splats/configs/datasets/birds_date-05182024_video-C0065.yaml b/docs/splats/configs/datasets/birds_date-05182024_video-C0065.yaml new file mode 100644 index 00000000..d6f35a64 --- /dev/null +++ b/docs/splats/configs/datasets/birds_date-05182024_video-C0065.yaml @@ -0,0 +1,4 @@ +# Birds dataset - May 18, 2024 - C0065 + +file_path: /workspace/fieldwork-data/birds/2024-05-18/SplatsSD/C0065.MP4 +frame_proportion: 0.25 diff --git a/docs/splats/configs/datasets/birds_date-05192024_video-C0067.yaml b/docs/splats/configs/datasets/birds_date-05192024_video-C0067.yaml new file mode 100644 index 00000000..7b570295 --- /dev/null +++ b/docs/splats/configs/datasets/birds_date-05192024_video-C0067.yaml @@ -0,0 +1,4 @@ +# Birds dataset - May 19, 2024 - C0067 + +file_path: /workspace/fieldwork-data/birds/2024-05-19/SplatsSD/C0067.MP4 +frame_proportion: 0.25 diff --git a/docs/splats/configs/datasets/birds_date-05232024_video-GH010070.yaml b/docs/splats/configs/datasets/birds_date-05232024_video-GH010070.yaml new file mode 100644 index 00000000..6eb130a2 --- /dev/null +++ b/docs/splats/configs/datasets/birds_date-05232024_video-GH010070.yaml @@ -0,0 +1,4 @@ +# Birds dataset - May 23, 2024 - GH010070 + +file_path: /workspace/fieldwork-data/birds/2024-05-23/SplatsSD/GH010070.MP4 +frame_proportion: 0.125 diff --git a/docs/splats/configs/datasets/birds_date-05272024_video-GH010097.yaml b/docs/splats/configs/datasets/birds_date-05272024_video-GH010097.yaml new file mode 100644 index 00000000..8e8adf2f --- /dev/null +++ b/docs/splats/configs/datasets/birds_date-05272024_video-GH010097.yaml @@ -0,0 +1,4 @@ +# Birds dataset - May 27, 2024 - GH010097 + +file_path: /workspace/fieldwork-data/birds/2024-05-27/SplatsSD/GH010097.MP4 +frame_proportion: 0.14 diff --git a/docs/splats/configs/datasets/birds_date-05272024_video-GH010105.yaml b/docs/splats/configs/datasets/birds_date-05272024_video-GH010105.yaml new file mode 100644 index 00000000..ed17b1d7 --- /dev/null +++ b/docs/splats/configs/datasets/birds_date-05272024_video-GH010105.yaml @@ -0,0 +1,4 @@ +# Birds dataset - May 27, 2024 - GH010105 + +file_path: /workspace/fieldwork-data/birds/2024-05-27/SplatsSD/GH010105.MP4 +frame_proportion: 0.25 diff --git a/docs/splats/configs/datasets/birds_date-06012024_video-GH010164.yaml b/docs/splats/configs/datasets/birds_date-06012024_video-GH010164.yaml new file mode 100644 index 00000000..39b28d9e --- /dev/null +++ b/docs/splats/configs/datasets/birds_date-06012024_video-GH010164.yaml @@ -0,0 +1,4 @@ +# Birds dataset - June 1, 2024 - GH010164 + +file_path: /workspace/fieldwork-data/birds/2024-06-01/SplatsSD/GH010164.MP4 +frame_proportion: 0.1 diff --git a/docs/splats/configs/datasets/birds_date-11052023_video-PXL_20231105_154956078.yaml b/docs/splats/configs/datasets/birds_date-11052023_video-PXL_20231105_154956078.yaml new file mode 100644 index 00000000..90057961 --- /dev/null +++ b/docs/splats/configs/datasets/birds_date-11052023_video-PXL_20231105_154956078.yaml @@ -0,0 +1,4 @@ +# Birds dataset - November 5, 2023 - PXL_20231105_154956078 + +file_path: /workspace/fieldwork-data/birds/2023-11-05/SplatsSD/PXL_20231105_154956078.mp4 +frame_proportion: 0.25 diff --git a/docs/splats/configs/datasets/example.yaml b/docs/splats/configs/datasets/example.yaml new file mode 100644 index 00000000..4629b9a6 --- /dev/null +++ b/docs/splats/configs/datasets/example.yaml @@ -0,0 +1,20 @@ +# Example dataset configuration using Docker image video +# Inherits from base.yaml + +file_path: /opt/data/C0043.MP4 +frame_proportion: 0.25 + +# Example: Override preprocessing to use colmap instead of hloc +# preprocess: +# sfm_tool: colmap + +# Example: Use higher quality training settings +# training: +# pipeline.model.random-scale: 1.0 +# pipeline.model.cull-alpha-thresh: 0.01 +# pipeline.model.collider-params: near_plane 0.1 far_plane 1.0 + +# Example: Higher resolution meshing +# meshing: +# voxel_size: 0.005 +# sdf_trunc: 0.015 diff --git a/docs/splats/configs/datasets/rats_date-07112024_video-C0119.yaml b/docs/splats/configs/datasets/rats_date-07112024_video-C0119.yaml new file mode 100644 index 00000000..84e53de7 --- /dev/null +++ b/docs/splats/configs/datasets/rats_date-07112024_video-C0119.yaml @@ -0,0 +1,4 @@ +# Rats dataset - July 11, 2024 - C0119 + +file_path: /workspace/fieldwork-data/rats/2024-07-11/SplatsSD/C0119.MP4 +frame_proportion: 0.25 diff --git a/docs/splats/create_mesh.ipynb b/docs/splats/create_mesh.ipynb index 3678aa5f..f545b5d9 100644 --- a/docs/splats/create_mesh.ipynb +++ b/docs/splats/create_mesh.ipynb @@ -17,35 +17,12 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Jupyter environment detected. Enabling Open3D WebVisualizer.\n", - "[Open3D INFO] WebRTC GUI backend enabled.\n", - "[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/opt/conda/envs/nerfstudio/lib/python3.10/site-packages/pyvista/plotting/utilities/xvfb.py:48: PyVistaDeprecationWarning: This function is deprecated and will be removed in future version of PyVista. Use vtk-osmesa instead.\n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ - "import os\n", - "import sys\n", - "from pathlib import Path\n", - "\n", "import pyvista as pv\n", - "\n", - "from collab_splats.wrapper import Splatter, SplatterConfig\n", + "from collab_splats.wrapper import Splatter\n", "\n", "pv.start_xvfb()" ] @@ -54,132 +31,66 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Set paths to the file for running splats" + "## Load the splatter from configuration\n", + "\n", + "Load the dataset configuration from YAML and ensure preprocessing/training steps are complete:" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "transforms.json already exists at /workspace/fieldwork-data/birds/2024-02-06/environment/C0043/preproc/transforms.json\n", - "To rerun preprocessing, set overwrite=True\n", - "Output already exists for rade-features\n", - "To rerun feature extraction, set overwrite=True\n" - ] - } - ], + "outputs": [], "source": [ - "base_dir = Path(\"/workspace/fieldwork-data/\")\n", - "session_dir = base_dir / \"birds/2024-02-06/SplatsSD\"\n", - "\n", - "# Make the configuration\n", - "splatter_config = SplatterConfig(\n", - " file_path=session_dir / \"C0043.MP4\",\n", - " method=\"rade-features\",\n", - " frame_proportion=0.25, # Use 25% of the frames within the video (or default to minimum 300 frames)\n", + "# Load splatter from YAML config\n", + "splatter = Splatter.from_config_file(\n", + " dataset='birds_date-02062024_video-C0043',\n", + " config_dir='configs'\n", ")\n", "\n", - "# Initialize the Splatter class\n", - "splatter = Splatter(splatter_config)\n", - "\n", - "# Call these to populate the splatter with paths (probably a better way to do this --> maybe save out config)\n", + "# Ensure preprocessing and training are done\n", + "# (if already completed, these will skip automatically)\n", "splatter.preprocess()\n", "splatter.extract_features()" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "### Create a mesh\n", - "\n", - "We can create a mesh by calling the ```mesh()``` method. Under the hood, this runs TSDF fusion creating an integrated volume. " + "# Create the mesh (config already contains meshing parameters)\n", + "splatter.mesh(overwrite=False)" ] }, { - "cell_type": "code", - "execution_count": 3, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Available runs:\n", - "[0] 2025-07-25_040743\n" - ] - } - ], "source": [ - "mesher_kwargs = {\n", - " \"depth_name\": \"median_depth\",\n", - " \"depth_trunc\": 1.0, # Should be between 1.0 and 3.0\n", - " \"voxel_size\": 0.005,\n", - " \"normals_name\": \"normals\",\n", - " \"features_name\": \"distill_features\",\n", - " \"sdf_trunc\": 0.03,\n", - " \"clean_repair\": True,\n", - " \"align_floor\": True,\n", - "}\n", + "### Plot the mesh!\n", "\n", - "splatter.mesh(\n", - " mesher_type=\"Open3DTSDFFusion\",\n", - " mesher_kwargs=mesher_kwargs,\n", - " # overwrite=True\n", - ")" + "We can use the splatter function ```plot_mesh``` to visualize given attributes of the mesh. The inherent attributes are RGB and Normals" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Plot the mesh!\n", + "### Using semantic queries \n", "\n", - "We can use the splatter function ```plot_mesh``` to visualize given attributes of the mesh. The inherent attributes are RGB and Normals" + "The mesh contains semantic features which we can query via positive and negative prompts. The goal of this is to find points that are more similar to the positive prompts compared to the negative prompts" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Number of points: 220146\n", - "Number of cells: 427851\n", - "Bounds: BoundsTuple(x_min=-0.9965442419052124, x_max=0.6880755424499512, y_min=-0.26933255791664124, y_max=1.3843693733215332, z_min=-0.3129904866218567, z_max=0.5518640279769897)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "1dcc1225aeca4548bb0cd4639f36aa79", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Widget(value='