"""
Titan Memory - Semantic Highlight Sidecar Service
Runs the Zilliz semantic-highlight-bilingual-v1 model as a local HTTP API.
Usage:
python highlight-service.py [--port 8079] [--model-path ./models/semantic-highlight-bilingual-v1]
The MCP server calls this service for sentence-level semantic highlighting.
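
Example request (endpoint and fields as defined by HighlightRequest below;
the question/context values are placeholders):

    curl -X POST http://127.0.0.1:8079/highlight \
         -H "Content-Type: application/json" \
         -d '{"question": "What is X?", "context": "X is ...", "threshold": 0.5}'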
"""

import argparse
import logging
import os
import sys
from contextlib import asynccontextmanager

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("highlight-service")

# Global model reference, populated by the lifespan handler at startup.
_model = None


class HighlightRequest(BaseModel):
    question: str = Field(..., description="Query text")
    context: str = Field(..., description="Document text to highlight")
    threshold: float = Field(0.5, description="Relevance threshold in [0, 1]")
    return_sentence_metrics: bool = Field(True, description="Return per-sentence scores")


class HighlightResponse(BaseModel):
    highlighted_sentences: list[str]
    compression_rate: float
    sentence_probabilities: list[float] = []


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup; release it on shutdown."""
    global _model
    model_path = app.state.model_path
    logger.info(f"Loading Zilliz semantic-highlight model from: {model_path}")
    _model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    # Move to GPU if available
    if torch.cuda.is_available():
        _model = _model.cuda()
        logger.info("Model loaded on GPU")
    else:
        logger.info("Model loaded on CPU")
    yield
    # Cleanup on shutdown
    _model = None
    logger.info("Model unloaded")


app = FastAPI(
    title="Titan Semantic Highlight Service",
    version="1.0.0",
    lifespan=lifespan,
)


@app.post("/highlight", response_model=HighlightResponse)
async def highlight(req: HighlightRequest):
    """Score sentences in `context` against `question` and return the relevant ones."""
    if _model is None:
        # Model not loaded yet (startup still in progress): return an empty result.
        return HighlightResponse(
            highlighted_sentences=[],
            compression_rate=0.0,
            sentence_probabilities=[],
        )
    result = _model.process(
        question=req.question,
        context=req.context,
        threshold=req.threshold,
        return_sentence_metrics=req.return_sentence_metrics,
    )
    return HighlightResponse(
        highlighted_sentences=result.get("highlighted_sentences", []),
        compression_rate=result.get("compression_rate", 0.0),
        sentence_probabilities=result.get("sentence_probabilities", []),
    )
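
# Example /highlight response (shape defined by HighlightResponse above;
# the values here are illustrative, not real model output):
#   {
#     "highlighted_sentences": ["X is ..."],
#     "compression_rate": 0.42,
#     "sentence_probabilities": [0.91, 0.13]
#   }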


@app.get("/health")
async def health():
    """Health check endpoint."""
    return {
        "status": "ok",
        "model_loaded": _model is not None,
        "gpu_available": torch.cuda.is_available(),
    }
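
# Example health check (keys as returned above; the boolean values depend
# on the running host):
#   curl http://127.0.0.1:8079/health
#   {"status": "ok", "model_loaded": true, "gpu_available": false}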


def main():
    parser = argparse.ArgumentParser(description="Titan Semantic Highlight Service")
    parser.add_argument("--port", type=int, default=8079, help="Port to listen on")
    parser.add_argument(
        "--model-path",
        type=str,
        default=os.path.join(os.path.dirname(__file__), "models", "semantic-highlight-bilingual-v1"),
        help="Path to the Zilliz model directory",
    )
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to bind to")
    args = parser.parse_args()

    # Fail fast if the model directory is missing.
    if not os.path.exists(args.model_path):
        logger.error(f"Model not found at: {args.model_path}")
        logger.error("Download it first: hf download zilliz/semantic-highlight-bilingual-v1")
        sys.exit(1)

    app.state.model_path = args.model_path
    logger.info(f"Starting Titan Highlight Service on {args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()
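
# Minimal client sketch (illustrative; assumes the third-party `requests`
# package, which is not a dependency of this service):
#
#   import requests
#
#   resp = requests.post(
#       "http://127.0.0.1:8079/highlight",
#       json={
#           "question": "What is Titan Memory?",
#           "context": "Titan Memory is ... It also ...",
#           "threshold": 0.5,
#       },
#   )
#   resp.raise_for_status()
#   print(resp.json()["highlighted_sentences"])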