Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,27 @@ def browse_dataset_path():
)

# Load model action in sidebar
from perceptionmetrics.models.torch_detection import TorchImageDetectionModel
try:
from perceptionmetrics.models.torch_detection import TorchImageDetectionModel
_torch_available = True
except (ImportError, OSError):
_torch_available = False
import json, tempfile

if not _torch_available:
st.warning(
"⚠️ PyTorch غير متاح على هذا النظام. "
"تبويب Inference لن يعمل. "
"يُرجى تثبيت PyTorch بشكل صحيح."
)

load_model_btn = st.button(
"Load Model",
type="primary",
width="stretch",
help="Load and save the model for use in the Inference tab",
key="sidebar_load_model_btn",
disabled=not _torch_available,
)

if load_model_btn:
Expand Down
8 changes: 4 additions & 4 deletions perceptionmetrics/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,22 @@

REGISTRY["torch_image_segmentation"] = TorchImageSegmentationModel
REGISTRY["torch_lidar_segmentation"] = TorchLiDARSegmentationModel
except ImportError:
except (ImportError, OSError):
print("Torch not available")

try:
from perceptionmetrics.models.torch_detection import TorchImageDetectionModel

REGISTRY["torch_image_detection"] = TorchImageDetectionModel
except ImportError:
except (ImportError, OSError):
print("Torch detection not available")

try:
from perceptionmetrics.models.tf_segmentation import TensorflowImageSegmentationModel

REGISTRY["tensorflow_image_segmentation"] = TensorflowImageSegmentationModel
except ImportError:
except (ImportError, OSError):
print("Tensorflow not available")

if not REGISTRY:
raise Exception("No valid deep learning framework found")
raise ImportError("No valid deep learning framework found")
3 changes: 2 additions & 1 deletion tabs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import streamlit as st
import json
from PIL import Image
import torch


def draw_detections(image: Image, predictions: dict, label_map: Optional[dict] = None):
Expand All @@ -18,6 +17,7 @@ def draw_detections(image: Image, predictions: dict, label_map: Optional[dict] =
:return: np.ndarray with detections drawn (for st.image)
:rtype: np.ndarray
"""
import torch
from perceptionmetrics.utils import image as ui

boxes = predictions.get("boxes", torch.empty(0)).cpu().numpy()
Expand Down Expand Up @@ -103,6 +103,7 @@ def inference_tab():
st.markdown("#### Detection Details")

# Convert predictions to JSON format
import torch
detection_results = []
boxes = predictions.get("boxes", torch.empty(0)).cpu().numpy()
labels = predictions.get("labels", torch.empty(0)).cpu().numpy()
Expand Down