# diff --git a/application/tests/test_cheatsheet_categorizer.py b/application/tests/test_cheatsheet_categorizer.py
# new file mode 100644
# index 000000000..533c7fbb2
# --- /dev/null
# +++ b/application/tests/test_cheatsheet_categorizer.py
# @@ -0,0 +1,527 @@
"""
Tests for Workstream C: cheatsheet_categorizer
===============================================

Covers:
  - TAXONOMY integrity
  - categorize_cheatsheet -- deterministic path
  - categorize_cheatsheet -- LLM path (success, bad labels, exception)
  - categorize_cheatsheet -- UNCATEGORIZED fallback
  - group_cheatsheets -- grouping, stable IDs, ordering
  - CheatsheetGroup.make_group_id -- determinism
  - _validate_labels helper
"""

import unittest

from application.utils.external_project_parsers.parsers.cheatsheet_categorizer import (  # noqa: E501
    TAXONOMY,
    UNCATEGORIZED,
    CheatsheetGroup,
    CheatsheetRecord,
    categorize_cheatsheet,
    group_cheatsheets,
    _validate_labels,
    _deterministic_categorize,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_record(
    source_id: str,
    title: str,
    headings=None,
    category_hints=None,
) -> CheatsheetRecord:
    """Build a small, valid CheatsheetRecord fixture.

    The hyperlink and markdown path are derived from *source_id*; a falsy
    *headings* or *category_hints* argument becomes an empty list.
    """
    site_root = "https://cheatsheetseries.owasp.org/cheatsheets"
    record_fields = {
        "source": "owasp_cheatsheets",
        "source_id": source_id,
        "title": title,
        "hyperlink": f"{site_root}/{source_id}.html",
        "summary": "Test summary",
        "headings": headings if headings else [],
        "raw_markdown_path": f"cheatsheets/{source_id}.md",
        "category_hints": category_hints if category_hints else [],
    }
    return CheatsheetRecord(**record_fields)


# ---------------------------------------------------------------------------
# 1.
# TAXONOMY integrity
# ---------------------------------------------------------------------------


class TestTaxonomy(unittest.TestCase):
    """Structural checks on the controlled TAXONOMY list."""

    def test_uncategorized_in_taxonomy(self):
        """The UNCATEGORIZED sentinel is itself a TAXONOMY entry."""
        self.assertIn(UNCATEGORIZED, TAXONOMY)

    def test_no_duplicates(self):
        """No label is listed more than once."""
        unique_labels = set(TAXONOMY)
        self.assertEqual(len(TAXONOMY), len(unique_labels))

    def test_all_lowercase(self):
        """Each label equals its own lowercased form."""
        for label in TAXONOMY:
            lowered = label.lower()
            self.assertEqual(label, lowered, f"Label not lowercase: {label!r}")

    def test_minimum_size(self):
        """TAXONOMY holds more than 10 entries (10 real labels + sentinel)."""
        entry_count = len(TAXONOMY)
        self.assertGreater(entry_count, 10)


# ---------------------------------------------------------------------------
# 2. categorize_cheatsheet -- deterministic, known categories
# ---------------------------------------------------------------------------


class TestCategorizeDeterministic(unittest.TestCase):
    """Keyword-driven categorisation with the default (non-LLM) settings."""

    def _labels(self, record):
        """Categorise *record* with default arguments."""
        return categorize_cheatsheet(record)

    # --- authentication ---
    def test_authentication_by_title(self):
        """An 'Authentication' title is expected to yield authentication."""
        record = _make_record(
            "Authentication_Cheat_Sheet",
            "Authentication Cheat Sheet",
        )
        self.assertIn("authentication", self._labels(record))

    def test_password_implies_authentication(self):
        """A 'Password' title is expected to yield authentication."""
        record = _make_record(
            "Password_Storage_Cheat_Sheet",
            "Password Storage Cheat Sheet",
        )
        self.assertIn("authentication", self._labels(record))

    def test_oauth_implies_authentication(self):
        """An 'OAuth' title is expected to yield authentication."""
        record = _make_record(
            "OAuth_Cheat_Sheet",
            "OAuth 2.0 Cheat Sheet",
            headings=["Authorization Code Flow"],
        )
        self.assertIn("authentication", self._labels(record))

    # --- secrets management ---
    def test_secrets_management(self):
        """A secrets-management sheet is labelled secrets-management."""
        section_titles = ["Introduction", "Secret Rotation", "Operational Practices"]
        record = _make_record(
            "Secrets_Management_Cheat_Sheet",
            "Secrets Management Cheat Sheet",
            headings=section_titles,
        )
        self.assertIn("secrets-management", self._labels(record))

    def test_secrets_and_operations_both_match(self):
        """One record may carry both secrets-management and operations."""
        record = _make_record(
            "Secrets_Management_Cheat_Sheet",
            "Secrets Management Cheat Sheet",
            headings=["Operational Practices", "Secret Rotation"],
        )
        labels = self._labels(record)
        self.assertIn("secrets-management", labels)
        self.assertIn("operations", labels)

    # --- cryptography ---
    def test_cryptography_by_title(self):
        """A 'Cryptographic' title is expected to yield cryptography."""
        record = _make_record(
            "Cryptographic_Storage_Cheat_Sheet",
            "Cryptographic Storage Cheat Sheet",
        )
        self.assertIn("cryptography", self._labels(record))

    def test_tls_implies_cryptography(self):
        """A 'TLS' title is expected to yield cryptography."""
        record = _make_record("TLS_Cheat_Sheet", "TLS Cheat Sheet")
        self.assertIn("cryptography", self._labels(record))

    # --- injection ---
    def test_sql_injection(self):
        """A 'SQL Injection' title is expected to yield injection."""
        record = _make_record(
            "SQL_Injection_Prevention",
            "SQL Injection Prevention",
        )
        self.assertIn("injection", self._labels(record))

    # --- logging ---
    def test_logging(self):
        """A 'Logging' title is expected to yield logging-and-monitoring."""
        record = _make_record("Logging_Cheat_Sheet", "Logging Cheat Sheet")
        self.assertIn("logging-and-monitoring", self._labels(record))

    # --- api security ---
    def test_api_security(self):
        """A REST security sheet is expected to yield api-security."""
        record = _make_record(
            "REST_Security_Cheat_Sheet",
            "REST Security Cheat Sheet",
            headings=["API Security Overview"],
        )
        result = self._labels(record)
        self.assertIn(
            "api-security",
            result,
            f"Expected api-security in {result}",
        )

    # --- output encoding / xss ---
    def test_xss_implies_output_encoding(self):
        """An 'XSS' title is expected to yield output-encoding."""
        record = _make_record(
            "XSS_Prevention_Cheat_Sheet",
            "XSS Prevention Cheat Sheet",
        )
        self.assertIn("output-encoding", self._labels(record))

    # --- container security ---
    def test_docker_implies_container(self):
        """A 'Docker' title is expected to yield container-security."""
        record = _make_record(
            "Docker_Security_Cheat_Sheet",
            "Docker Security Cheat Sheet",
        )
        self.assertIn("container-security", self._labels(record))

    # --- category hints contribute ---
    def test_category_hints_used(self):
        """category_hints alone can produce a label (here cloud-security)."""
        record = _make_record(
            "Misc_Cheat_Sheet",
            "Miscellaneous Cheat Sheet",
            category_hints=["cloud"],
        )
        self.assertIn("cloud-security", self._labels(record))

    # --- output properties ---
    def test_output_is_sorted(self):
        """The returned labels are already in sorted order."""
        record = _make_record(
            "Auth_Session",
            "Authentication Session Management",
            headings=["Session Tokens", "Password Policy"],
        )
        labels = self._labels(record)
        self.assertEqual(labels, sorted(labels))

    def test_output_no_duplicates(self):
        """A keyword repeated in the title does not repeat its label."""
        record = _make_record("Auth_Auth", "Authentication Authentication")
        labels = self._labels(record)
        self.assertEqual(len(labels), len(set(labels)))

    def test_labels_all_in_taxonomy(self):
        """Every returned label is a TAXONOMY member."""
        record = _make_record(
            "Secrets_Management_Cheat_Sheet",
            "Secrets Management Cheat Sheet",
            headings=["Secret Rotation", "Logging Practices", "Encryption"],
        )
        for label in self._labels(record):
            self.assertIn(label, TAXONOMY, f"Label {label!r} not in TAXONOMY")

    def test_determinism_same_input_same_output(self):
        """Two calls on the same record return equal results."""
        record = _make_record(
            "Secrets_Management_Cheat_Sheet",
            "Secrets Management Cheat Sheet",
            headings=["Secret Rotation", "Operational Practices"],
        )
        first_run = self._labels(record)
        second_run = self._labels(record)
        self.assertEqual(first_run, second_run)


# ---------------------------------------------------------------------------
# 3.
# categorize_cheatsheet -- UNCATEGORIZED fallback
# ---------------------------------------------------------------------------


class TestUncategorizedFallback(unittest.TestCase):
    """Records with no keyword hit map to [UNCATEGORIZED]."""

    def test_empty_record_returns_uncategorized(self):
        """A title with no known keyword yields exactly [UNCATEGORIZED]."""
        record = _make_record("Unknown_Cheat_Sheet", "Unknown Topic")
        labels = categorize_cheatsheet(record)
        self.assertEqual(labels, [UNCATEGORIZED])

    def test_uncategorized_not_mixed_with_real_labels(self):
        """UNCATEGORIZED is absent when the record matches a real label."""
        record = _make_record("Auth_Cheat_Sheet", "Authentication Cheat Sheet")
        labels = categorize_cheatsheet(record)
        self.assertNotIn(UNCATEGORIZED, labels)


# ---------------------------------------------------------------------------
# 4. categorize_cheatsheet -- LLM path
# ---------------------------------------------------------------------------


class TestCategorizeLLMPath(unittest.TestCase):
    """Behaviour of categorize_cheatsheet with an injected LLM callable."""

    def _secrets_record(self):
        """Build the secrets-management record shared by these tests."""
        return _make_record(
            "Secrets_Management_Cheat_Sheet",
            "Secrets Management Cheat Sheet",
        )

    def _run_with_llm(self, llm_fn):
        """Categorise the shared record with use_llm=True and *llm_fn*."""
        return categorize_cheatsheet(
            self._secrets_record(),
            use_llm=True,
            llm_categorize_fn=llm_fn,
        )

    def test_llm_success_returns_llm_labels(self):
        """Valid labels supplied by the callable appear in the result."""
        labels = self._run_with_llm(
            lambda record: ["secrets-management", "operations"]
        )
        self.assertIn("secrets-management", labels)
        self.assertIn("operations", labels)

    def test_llm_bad_labels_falls_back_to_deterministic(self):
        """Labels outside TAXONOMY are rejected and keyword matching is used."""
        labels = self._run_with_llm(
            lambda record: ["not-a-real-label", "also-fake"]
        )
        for label in labels:
            self.assertIn(label, TAXONOMY)
        self.assertIn("secrets-management", labels)

    def test_llm_exception_falls_back_to_deterministic(self):
        """An exception from the callable does not escape the call."""

        def crashing_llm(record):
            raise RuntimeError("API timeout")

        labels = self._run_with_llm(crashing_llm)
        self.assertIn("secrets-management", labels)

    def test_llm_returns_empty_list_falls_back(self):
        """An empty list from the callable triggers keyword matching."""
        labels = self._run_with_llm(lambda record: [])
        self.assertIn("secrets-management", labels)

    def test_llm_returns_non_list_falls_back(self):
        """A non-list return value still yields only TAXONOMY labels."""
        # A bare string, not a list of strings.
        labels = self._run_with_llm(lambda record: "secrets-management")
        for label in labels:
            self.assertIn(label, TAXONOMY)

    def test_use_llm_false_ignores_llm_fn(self):
        """With use_llm=False the callable is never invoked."""
        seen_records = []

        def tracking_llm(record):
            seen_records.append(record)
            return ["authentication"]

        categorize_cheatsheet(
            self._secrets_record(),
            use_llm=False,
            llm_categorize_fn=tracking_llm,
        )
        self.assertEqual(len(seen_records), 0)


# ---------------------------------------------------------------------------
# 5.
# group_cheatsheets
# ---------------------------------------------------------------------------


class TestGroupCheatsheets(unittest.TestCase):
    """Grouping, group-ID stability, ordering and membership checks."""

    def setUp(self):
        """Create the four fixture records shared by the tests below."""
        self.auth_record = _make_record(
            "Authentication_Cheat_Sheet",
            "Authentication Cheat Sheet",
        )
        self.password_record = _make_record(
            "Password_Storage_Cheat_Sheet",
            "Password Storage Cheat Sheet",
        )
        self.secrets_record = _make_record(
            "Secrets_Management_Cheat_Sheet",
            "Secrets Management Cheat Sheet",
            headings=["Secret Rotation", "Operational Practices"],
        )
        self.unknown_record = _make_record(
            "Unknown_Topic_Cheat_Sheet",
            "Unknown Topic",
        )

    def _mixed_records(self):
        """Return the auth, secrets and unknown fixtures in that order."""
        return [self.auth_record, self.secrets_record, self.unknown_record]

    def test_same_category_same_group(self):
        """Two records with equal label lists share a single group."""
        pair = [self.auth_record, self.password_record]
        auth_labels, pwd_labels = (categorize_cheatsheet(rec) for rec in pair)
        self.assertEqual(
            auth_labels,
            pwd_labels,
            "auth and password records must share the same labels",
        )
        groups = group_cheatsheets(pair)
        self.assertEqual(len(groups), 1)
        only_group = groups[0]
        self.assertEqual(len(only_group.members), 2)

    def test_different_categories_different_groups(self):
        """Records with unequal label lists end up in separate groups."""
        pair = [self.auth_record, self.secrets_record]
        auth_labels, secrets_labels = (categorize_cheatsheet(rec) for rec in pair)
        self.assertNotEqual(
            auth_labels,
            secrets_labels,
            "auth and secrets records must have different labels",
        )
        groups = group_cheatsheets(pair)
        self.assertGreater(len(groups), 1)

    def test_unknown_record_in_uncategorized_group(self):
        """An unmatched record forms one group labelled [UNCATEGORIZED]."""
        groups = group_cheatsheets([self.unknown_record])
        self.assertEqual(len(groups), 1)
        self.assertEqual(groups[0].labels, [UNCATEGORIZED])

    def test_group_ids_are_stable(self):
        """Grouping the same records twice yields the same set of IDs."""
        records = self._mixed_records()
        first_ids = {group.group_id for group in group_cheatsheets(records)}
        second_ids = {group.group_id for group in group_cheatsheets(records)}
        self.assertEqual(first_ids, second_ids)

    def test_output_is_sorted_by_group_id(self):
        """Groups come back in ascending group_id order."""
        records = self._mixed_records() + [self.password_record]
        group_ids = [group.group_id for group in group_cheatsheets(records)]
        self.assertEqual(group_ids, sorted(group_ids))

    def test_empty_input_returns_empty_list(self):
        """No records in, no groups out."""
        self.assertEqual(group_cheatsheets([]), [])

    def test_all_members_present(self):
        """The members of all groups together equal the input records."""
        records = self._mixed_records()
        collected = []
        for group in group_cheatsheets(records):
            collected.extend(group.members)
        self.assertCountEqual(collected, records)

    def test_group_labels_all_in_taxonomy(self):
        """Every label attached to a group is a TAXONOMY member."""
        for group in group_cheatsheets(self._mixed_records()):
            for label in group.labels:
                self.assertIn(label, TAXONOMY)

    def test_single_record(self):
        """One record yields one group whose first member is that record."""
        groups = group_cheatsheets([self.auth_record])
        self.assertEqual(len(groups), 1)
        first_member = groups[0].members[0]
        self.assertEqual(first_member, self.auth_record)


# ---------------------------------------------------------------------------
# 6.
# CheatsheetGroup.make_group_id
# ---------------------------------------------------------------------------


class TestMakeGroupId(unittest.TestCase):
    """Checks on the ID derived from a list of labels."""

    def test_same_labels_same_id(self):
        """Equal label lists hash to the same ID."""
        labels = ["authentication", "session-management"]
        first_id = CheatsheetGroup.make_group_id(list(labels))
        second_id = CheatsheetGroup.make_group_id(list(labels))
        self.assertEqual(first_id, second_id)

    def test_order_independent(self):
        """Reversing the label order does not change the ID."""
        labels = ["authentication", "session-management"]
        forward_id = CheatsheetGroup.make_group_id(labels)
        reversed_id = CheatsheetGroup.make_group_id(labels[::-1])
        self.assertEqual(forward_id, reversed_id)

    def test_different_labels_different_id(self):
        """Two different single-label lists produce different IDs."""
        auth_id = CheatsheetGroup.make_group_id(["authentication"])
        crypto_id = CheatsheetGroup.make_group_id(["cryptography"])
        self.assertNotEqual(auth_id, crypto_id)

    def test_id_is_12_hex_chars(self):
        """The ID is 12 characters long, all lowercase hex digits."""
        group_id = CheatsheetGroup.make_group_id(["authentication"])
        self.assertEqual(len(group_id), 12)
        self.assertTrue(set(group_id) <= set("0123456789abcdef"))


# ---------------------------------------------------------------------------
# 7.
# _validate_labels
# ---------------------------------------------------------------------------


class TestValidateLabels(unittest.TestCase):
    """Filtering and de-duplication performed by _validate_labels."""

    def test_valid_labels_pass_through(self):
        """A list of approved labels is returned unchanged."""
        approved = ["authentication", "cryptography"]
        self.assertEqual(_validate_labels(list(approved)), approved)

    def test_invalid_labels_filtered(self):
        """Labels outside TAXONOMY are dropped; approved ones stay."""
        cleaned = _validate_labels(["authentication", "not-a-real-label"])
        self.assertEqual(cleaned, ["authentication"])

    def test_all_invalid_returns_empty(self):
        """A list with no approved label yields an empty list."""
        cleaned = _validate_labels(["fake1", "fake2"])
        self.assertEqual(cleaned, [])

    def test_non_list_returns_empty(self):
        """A str, None or int argument yields an empty list."""
        for not_a_list in ("authentication", None, 42):
            self.assertEqual(_validate_labels(not_a_list), [])

    def test_duplicates_removed(self):
        """Repeated labels are collapsed, keeping first-seen order."""
        cleaned = _validate_labels(
            ["authentication", "authentication", "cryptography"]
        )
        self.assertEqual(len(cleaned), 2)
        self.assertEqual(cleaned, ["authentication", "cryptography"])

    def test_empty_list_returns_empty(self):
        """An empty list stays empty."""
        self.assertEqual(_validate_labels([]), [])


# ---------------------------------------------------------------------------
# 8.
_deterministic_categorize (direct unit tests) +# --------------------------------------------------------------------------- + + +class TestDeterministicCategorize(unittest.TestCase): + """Direct tests of the internal keyword matcher.""" + + def test_returns_list(self): + """Function must always return a list.""" + r = _make_record("X", "Xa") + self.assertIsInstance(_deterministic_categorize(r), list) + + def test_known_categories_three_plus(self): + """Spot-check three categories to guard against regression.""" + cases = [ + ( + "Logging_Cheat_Sheet", + "Logging Cheat Sheet", + "logging-and-monitoring", + ), + ( + "XSS_Prevention", + "XSS Prevention Cheat Sheet", + "output-encoding", + ), + ( + "Docker_Security", + "Docker Security Cheat Sheet", + "container-security", + ), + ] + for source_id, title, expected_label in cases: + with self.subTest(source_id=source_id): + r = _make_record(source_id, title) + self.assertIn(expected_label, _deterministic_categorize(r)) + + +if __name__ == "__main__": + unittest.main() diff --git a/application/utils/external_project_parsers/parsers/cheatsheet_categorizer.py b/application/utils/external_project_parsers/parsers/cheatsheet_categorizer.py new file mode 100644 index 000000000..1f8374f46 --- /dev/null +++ b/application/utils/external_project_parsers/parsers/cheatsheet_categorizer.py @@ -0,0 +1,392 @@ +""" +Workstream C: Categorization and Optional Grouping +=================================================== +Provides: + - categorize_cheatsheet(record) -> list[str] + - group_cheatsheets(records) -> list[CheatsheetGroup] + +Design rules +------------ +* All labels come ONLY from TAXONOMY (a controlled vocabulary). +* Deterministic mode (default): pure keyword/rule matching, no LLM, + no randomness. +* Same input always returns the same output. +* Unknown/ambiguous inputs map to [UNCATEGORIZED] -- never raise. +* LLM path is opt-in and always has a safe deterministic fallback. 
+* Group IDs are stable: sha256(sorted category labels) so they survive + re-ordering of the input list. +""" + +from __future__ import annotations + +import hashlib +import logging +from dataclasses import dataclass, field +from typing import List + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# 1. Controlled taxonomy +# --------------------------------------------------------------------------- + +#: Sentinel returned when no category matches. +UNCATEGORIZED = "uncategorized" + +#: Complete approved label set. Add new labels HERE and nowhere else. +TAXONOMY: List[str] = [ + "authentication", + "authorization", + "session-management", + "cryptography", + "secrets-management", + "input-validation", + "output-encoding", + "injection", + "api-security", + "logging-and-monitoring", + "error-handling", + "file-upload", + "xml-security", + "deserialization", + "supply-chain", + "infrastructure-security", + "network-security", + "container-security", + "cloud-security", + "microservices-security", + "access-control", + "privacy", + "threat-modeling", + "incident-response", + "vulnerability-disclosure", + "secure-coding", + "operations", + "mobile-security", + "browser-security", + UNCATEGORIZED, +] + +# Keyword -> taxonomy label map. +# Keys are lowercase substrings matched against title + headings + +# category_hints. +# Evaluated in order; first match wins per label (multiple labels allowed). 
_KEYWORD_RULES: List[tuple] = [
    # secrets / key management
    ("secret", "secrets-management"),
    ("key management", "secrets-management"),
    ("credential", "secrets-management"),
    # authentication
    ("authentication", "authentication"),
    ("password", "authentication"),
    ("multi-factor", "authentication"),
    ("mfa", "authentication"),
    ("saml", "authentication"),
    ("oauth", "authentication"),
    ("oidc", "authentication"),
    ("jwt", "authentication"),
    ("forgot password", "authentication"),
    # authorization / access control
    ("authorization", "authorization"),
    ("access control", "access-control"),
    ("privilege", "access-control"),
    ("rbac", "access-control"),
    # session
    ("session", "session-management"),
    # cryptography
    ("cryptograph", "cryptography"),
    ("encrypt", "cryptography"),
    ("tls", "cryptography"),
    ("hashing", "cryptography"),
    ("cipher", "cryptography"),
    # input validation / output encoding
    ("input validation", "input-validation"),
    ("sanitiz", "input-validation"),
    ("output encoding", "output-encoding"),
    ("xss", "output-encoding"),
    ("cross-site scrip", "output-encoding"),
    # injection
    ("sql injection", "injection"),
    ("injection", "injection"),
    ("ldap injection", "injection"),
    ("xxe", "xml-security"),
    # api
    ("api security", "api-security"),
    ("graphql", "api-security"),
    ("rest security", "api-security"),
    # logging / monitoring
    ("logging", "logging-and-monitoring"),
    ("monitoring", "logging-and-monitoring"),
    ("audit", "logging-and-monitoring"),
    # error handling
    ("error handling", "error-handling"),
    ("exception", "error-handling"),
    # file upload
    ("file upload", "file-upload"),
    # xml
    ("xml", "xml-security"),
    ("xpath", "xml-security"),
    # deserialization
    ("deserializ", "deserialization"),
    # supply chain
    ("dependency", "supply-chain"),
    ("third-party", "supply-chain"),
    ("software composition", "supply-chain"),
    # infrastructure / network
    ("infrastructure", "infrastructure-security"),
    ("network security", "network-security"),
    ("firewall", "network-security"),
    # container / cloud / microservices
    ("container", "container-security"),
    ("docker", "container-security"),
    ("kubernetes", "container-security"),
    ("cloud", "cloud-security"),
    ("microservice", "microservices-security"),
    ("serverless", "cloud-security"),
    # privacy / threat modeling / incident response
    ("privacy", "privacy"),
    ("gdpr", "privacy"),
    ("threat model", "threat-modeling"),
    ("incident response", "incident-response"),
    ("disclosure", "vulnerability-disclosure"),
    # mobile / browser
    ("mobile", "mobile-security"),
    ("android", "mobile-security"),
    ("ios", "mobile-security"),
    ("browser", "browser-security"),
    ("cors", "browser-security"),
    ("content security policy", "browser-security"),
    ("csp", "browser-security"),
    # operations / secure coding (broad catch-alls -- keep near the bottom)
    ("operational", "operations"),
    ("rotation", "operations"),
    ("secure coding", "secure-coding"),
    ("secure development", "secure-coding"),
]


# ---------------------------------------------------------------------------
# 2. CheatsheetRecord (minimal interface expected by this module)
# ---------------------------------------------------------------------------


@dataclass
class CheatsheetRecord:
    """Typed representation of a parsed cheat sheet.

    Workstream B owns the full implementation; this definition covers
    exactly the fields Workstream C needs so C can be developed and
    tested independently.

    Construction validates the fields: the six string fields must be
    non-blank strings, and ``headings`` / ``category_hints`` must be lists
    of strings (an empty list is accepted).  ``metadata`` is not checked.
    """

    source: str  # always "owasp_cheatsheets"
    source_id: str  # e.g. "Secrets_Management_Cheat_Sheet"
    title: str  # human-readable title
    hyperlink: str  # canonical cheatsheetseries URL
    summary: str  # bounded summary text
    headings: List[str]  # ordered headings from markdown
    raw_markdown_path: str  # path in the source repo
    category_hints: List[str] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate field types right after dataclass construction.

        Raises
        ------
        ValueError
            If a string field is not a ``str`` or is blank after
            stripping, or if a list field is not a list of ``str``.
        """
        text_fields = (
            "source",
            "source_id",
            "title",
            "hyperlink",
            "summary",
            "raw_markdown_path",
        )
        for name in text_fields:
            current = getattr(self, name)
            if isinstance(current, str) and current.strip():
                continue
            raise ValueError(
                f"CheatsheetRecord.{name} must be a non-empty string, "
                f"got {current!r}"
            )

        for name in ("headings", "category_hints"):
            current = getattr(self, name)
            is_str_list = isinstance(current, list) and all(
                isinstance(entry, str) for entry in current
            )
            if not is_str_list:
                raise ValueError(
                    f"CheatsheetRecord.{name} must be a list of strings, "
                    f"got {current!r}"
                )


# ---------------------------------------------------------------------------
# 3. CheatsheetGroup
# ---------------------------------------------------------------------------


@dataclass
class CheatsheetGroup:
    """A set of cheat sheet records that share the same category labels.

    ``group_id`` is expected to come from :meth:`make_group_id`, which
    depends only on the label set, so it is the same on every run for the
    same labels.
    """

    group_id: str
    labels: List[str]
    members: List[CheatsheetRecord] = field(default_factory=list)

    @staticmethod
    def make_group_id(labels: List[str]) -> str:
        """Return a 12-char hex ID derived from the set of *labels*.

        Duplicates are dropped and the labels sorted before hashing, so
        the result does not depend on input order.  The ID is the first 12
        hex characters of the SHA-256 of the pipe-joined labels.
        """
        canonical = "|".join(sorted(set(labels)))
        digest = hashlib.sha256(canonical.encode()).hexdigest()
        return digest[:12]


# ---------------------------------------------------------------------------
# 4.
# Core public functions
# ---------------------------------------------------------------------------


def categorize_cheatsheet(
    record: CheatsheetRecord,
    *,
    use_llm: bool = False,
    llm_categorize_fn=None,
) -> List[str]:
    """Return a sorted list of taxonomy labels for *record*.

    Labels are drawn exclusively from TAXONOMY.
    If no label matches, returns [UNCATEGORIZED].

    Parameters
    ----------
    record:
        A CheatsheetRecord (from Workstream B or the local stub).
    use_llm:
        When True (and *llm_categorize_fn* is given), call
        *llm_categorize_fn* first.  Falls back to deterministic
        categorisation if it raises or returns no valid label.
    llm_categorize_fn:
        Optional callable(record) -> list[str].  Injected for testability.
        Its output is filtered through ``_validate_labels``.

    Returns
    -------
    list[str]
        Sorted, deduplicated taxonomy labels.  Never empty: the result is
        [UNCATEGORIZED] when nothing matches.
    """
    if use_llm and llm_categorize_fn is not None:
        try:
            validated = _validate_labels(llm_categorize_fn(record))
        except Exception as exc:  # noqa: BLE001
            # Deliberate catch-all: any failure of the injected callable
            # must degrade to keyword matching instead of propagating.
            logger.warning(
                "LLM categorization failed for %s (%s), "
                "falling back to deterministic",
                record.source_id,
                exc,
            )
        else:
            if validated:
                logger.debug("LLM categorization used for %s", record.source_id)
                # Sort so both paths honour the same "sorted list" contract;
                # _validate_labels keeps the order the callable produced.
                return sorted(validated)
            logger.warning(
                "LLM returned no valid labels for %s, "
                "falling back to deterministic",
                record.source_id,
            )

    return _deterministic_categorize(record)


def group_cheatsheets(
    records: List[CheatsheetRecord],
    *,
    use_llm: bool = False,
    llm_categorize_fn=None,
) -> List[CheatsheetGroup]:
    """Assign every record to a CheatsheetGroup based on its category labels.

    Group IDs are stable: same set of labels produces the same group_id
    regardless of the order records appear in *records*.

    Parameters
    ----------
    records:
        List of CheatsheetRecord objects to group.
    use_llm:
        Forwarded to categorize_cheatsheet.
    llm_categorize_fn:
        Forwarded to categorize_cheatsheet.

    Returns
    -------
    list[CheatsheetGroup]
        Groups sorted by group_id for deterministic output order.  Within
        a group, members keep the order they had in *records*.
    """
    bucket: dict[str, CheatsheetGroup] = {}

    for record in records:
        labels = categorize_cheatsheet(
            record,
            use_llm=use_llm,
            llm_categorize_fn=llm_categorize_fn,
        )
        gid = CheatsheetGroup.make_group_id(labels)
        if gid not in bucket:
            bucket[gid] = CheatsheetGroup(group_id=gid, labels=sorted(labels))
        bucket[gid].members.append(record)

    return sorted(bucket.values(), key=lambda g: g.group_id)


# ---------------------------------------------------------------------------
# 5. Internal helpers
# ---------------------------------------------------------------------------


def _build_searchable_text(record: CheatsheetRecord) -> str:
    """Combine title, headings, and category_hints into lowercase text."""
    parts = [record.title, *record.headings, *record.category_hints]
    return " ".join(parts).lower()


def _deterministic_categorize(record: CheatsheetRecord) -> List[str]:
    """Run pure keyword matching against a record.

    No external calls are made.  Each keyword is tested as a plain
    substring of the lowercased searchable text.  Returns sorted,
    deduplicated labels, or [UNCATEGORIZED] when nothing matches.
    """
    text = _build_searchable_text(record)
    # A set gives de-duplication for free; rule order cannot affect the
    # result because the labels are sorted before returning.
    matched = {label for keyword, label in _KEYWORD_RULES if keyword in text}
    if not matched:
        return [UNCATEGORIZED]
    return sorted(matched)


def _validate_labels(labels) -> List[str]:
    """Filter an LLM-returned label list to only approved TAXONOMY entries.

    Returns an empty list if *labels* is not a list or nothing valid
    remains; the caller should fall back to deterministic categorisation
    in that case.  Duplicates are removed, keeping first-seen order.
    UNCATEGORIZED is stripped if other real labels are present.
    """
    if not isinstance(labels, list):
        return []
    # dict.fromkeys de-duplicates while preserving insertion order.
    deduped = list(
        dict.fromkeys(
            lbl for lbl in labels if isinstance(lbl, str) and lbl in TAXONOMY
        )
    )
    real = [lbl for lbl in deduped if lbl != UNCATEGORIZED]
    return real if real else deduped