Преглед изворни кода

Add cycle reassignment tool and activate venv

Lukas Goldschmidt пре 1 месец
родитељ
комит
90ca74d34f
2 измењених фајлова са 98 додато и 5 уклоњено
  1. 5 0
      run.sh
  2. 93 5
      virtuoso_mcp.py

+ 5 - 0
run.sh

@@ -11,6 +11,11 @@ if [[ -f .env ]]; then
   set +a
   set +a
 fi
 fi
 
 
+if [[ -f .venv/bin/activate ]]; then
+  # shellcheck source=/dev/null
+  source .venv/bin/activate
+fi
+
 LOG_DIR="logs"
 LOG_DIR="logs"
 mkdir -p "$LOG_DIR"
 mkdir -p "$LOG_DIR"
 PID_FILE="server.pid"
 PID_FILE="server.pid"

+ 93 - 5
virtuoso_mcp.py

@@ -1,13 +1,15 @@
+import json
 import logging
 import logging
 import os
 import os
 import re
 import re
+from datetime import datetime, timezone
 from importlib import import_module
 from importlib import import_module
 from pathlib import Path
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional
 from typing import Any, Callable, Dict, List, Optional
 
 
 import requests
 import requests
 from requests.auth import HTTPDigestAuth
 from requests.auth import HTTPDigestAuth
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Request
 from pydantic import BaseModel
 from pydantic import BaseModel
 
 
 LOG_LEVEL = os.getenv("MCP_LOG_LEVEL", "INFO").upper()
 LOG_LEVEL = os.getenv("MCP_LOG_LEVEL", "INFO").upper()
@@ -27,12 +29,22 @@ SPARQL_UPDATE_TIMEOUT = float(os.getenv("SPARQL_UPDATE_TIMEOUT", 15.0))
 SPARQL_DEFAULT_LIMIT = int(os.getenv("SPARQL_DEFAULT_LIMIT", 100))
 SPARQL_DEFAULT_LIMIT = int(os.getenv("SPARQL_DEFAULT_LIMIT", 100))
 SPARQL_MAX_LIMIT = int(os.getenv("SPARQL_MAX_LIMIT", 500))
 SPARQL_MAX_LIMIT = int(os.getenv("SPARQL_MAX_LIMIT", 500))
 GRAPH_URI = os.getenv("GRAPH_URI", "http://world.eu.org/example1")
 GRAPH_URI = os.getenv("GRAPH_URI", "http://world.eu.org/example1")
+IN_CYCLE = "http://world.eu.org/cannabis-breeding#inCycle"
+CLONE_OF = "http://world.eu.org/cannabis-breeding#cloneOf"
 EXAMPLES_DIR = Path(__file__).resolve().parent / "examples"
 EXAMPLES_DIR = Path(__file__).resolve().parent / "examples"
 EXAMPLE_GRAPH = os.getenv(
 EXAMPLE_GRAPH = os.getenv(
     "EXAMPLE_GRAPH", "http://world.eu.org/cannabis-breeding#test"
     "EXAMPLE_GRAPH", "http://world.eu.org/cannabis-breeding#test"
 )
 )
 ALLOW_EXAMPLE_LOAD = os.getenv("MCP_ALLOW_EXAMPLE_LOAD", "false").lower() == "true"
 ALLOW_EXAMPLE_LOAD = os.getenv("MCP_ALLOW_EXAMPLE_LOAD", "false").lower() == "true"
 SESSION = requests.Session()
 SESSION = requests.Session()
+LOGS_DIR = Path(__file__).resolve().parent / "logs"
+LOGS_DIR.mkdir(parents=True, exist_ok=True)
+tool_logger = logging.getLogger("virtuoso_mcp.tools")
+tool_handler = logging.FileHandler(LOGS_DIR / "tool_usage.log")
+tool_handler.setFormatter(logging.Formatter("%(asctime)s %(message)s"))
+tool_logger.addHandler(tool_handler)
+tool_logger.setLevel(logging.INFO)
+tool_logger.propagate = False
 
 
 PREFIXES = f"""
 PREFIXES = f"""
 PREFIX : <{GRAPH_URI}>
 PREFIX : <{GRAPH_URI}>
@@ -250,6 +262,23 @@ def tool_get_entities_by_type(input_data: Dict[str, Any]) -> Dict[str, Any]:
     return run_sparql(query)
     return run_sparql(query)
 
 
 
 
+def tool_cycle_plants(input_data: Dict[str, Any]) -> Dict[str, Any]:
+    cycle_uri = input_data.get("cycle_uri")
+    if not cycle_uri:
+        raise ValueError("Missing 'cycle_uri' field")
+    limit = int(input_data.get("limit", 50))
+    limit = min(max(limit, 1), SPARQL_MAX_LIMIT)
+    query = f"""
+    SELECT ?plant ?plantLabel ?parent WHERE {{
+        ?plant <{IN_CYCLE}> <{cycle_uri}> .
+        OPTIONAL {{ ?plant rdfs:label ?plantLabel }}
+        OPTIONAL {{ ?plant <{CLONE_OF}> ?parent }}
+    }}
+    LIMIT {limit}
+    """
+    return run_sparql(query)
+
+
 def tool_get_predicates_for_subject(input_data: Dict[str, Any]) -> Dict[str, Any]:
 def tool_get_predicates_for_subject(input_data: Dict[str, Any]) -> Dict[str, Any]:
     subject_uri = input_data.get("subject_uri")
     subject_uri = input_data.get("subject_uri")
     if not subject_uri:
     if not subject_uri:
@@ -565,6 +594,34 @@ def tool_batch_insert(input_data: Dict[str, Any]) -> Dict[str, Any]:
     return {**result, "query": query}
     return {**result, "query": query}
 
 
 
 
+def tool_reassign_cycle(input_data: Dict[str, Any]) -> Dict[str, Any]:
+    subject = input_data.get("subject")
+    new_cycle = input_data.get("new_cycle")
+    old_cycle = input_data.get("old_cycle")
+    graph = input_data.get("graph") or GRAPH_URI
+
+    if not subject or not new_cycle:
+        raise ValueError("Provide 'subject' and 'new_cycle' fields")
+
+    if old_cycle:
+        delete_clause = f"<{subject}> <{IN_CYCLE}> <{old_cycle}> ."
+        where_clause = delete_clause
+        update_query = f"""
+        WITH <{graph}>
+        DELETE {{ {delete_clause} }}
+        INSERT {{ <{subject}> <{IN_CYCLE}> <{new_cycle}> . }}
+        WHERE {{ {where_clause} }}
+        """
+    else:
+        update_query = f"""
+        WITH <{graph}>
+        INSERT {{ <{subject}> <{IN_CYCLE}> <{new_cycle}> . }}
+        WHERE {{ }}
+        """
+
+    return run_sparql_update(update_query)
+
+
 def tool_insert_triple(input_data: Dict[str, Any]) -> Dict[str, Any]:
 def tool_insert_triple(input_data: Dict[str, Any]) -> Dict[str, Any]:
     subject = input_data.get("subject")
     subject = input_data.get("subject")
     predicate = input_data.get("predicate")
     predicate = input_data.get("predicate")
@@ -635,6 +692,7 @@ TOOLS = {
     "list_graphs": tool_list_graphs,
     "list_graphs": tool_list_graphs,
     "search_label": tool_search_label,
     "search_label": tool_search_label,
     "get_entities_by_type": tool_get_entities_by_type,
     "get_entities_by_type": tool_get_entities_by_type,
+    "cycle_plants": tool_cycle_plants,
     "get_predicates_for_subject": tool_get_predicates_for_subject,
     "get_predicates_for_subject": tool_get_predicates_for_subject,
     "get_labels_for_subject": tool_get_labels_for_subject,
     "get_labels_for_subject": tool_get_labels_for_subject,
     "traverse_property": tool_traverse_property,
     "traverse_property": tool_traverse_property,
@@ -646,6 +704,7 @@ TOOLS = {
     "path_traverse": tool_path_traverse,
     "path_traverse": tool_path_traverse,
     "property_usage_statistics": tool_property_usage_statistics,
     "property_usage_statistics": tool_property_usage_statistics,
     "batch_insert": tool_batch_insert,
     "batch_insert": tool_batch_insert,
+    "reassign_cycle": tool_reassign_cycle,
     "insert_triple": tool_insert_triple,
     "insert_triple": tool_insert_triple,
     "load_examples": tool_load_examples,
     "load_examples": tool_load_examples,
 }
 }
@@ -657,10 +716,25 @@ def load_domain_layers(tools: Dict[str, Callable[[Dict[str, Any]], Any]]) -> Non
     if not modules:
     if not modules:
         return
         return
     for module_name in modules:
     for module_name in modules:
+        module = None
         try:
         try:
             module = import_module(module_name)
             module = import_module(module_name)
         except ImportError as exc:
         except ImportError as exc:
-            logger.warning("Domain layer '%s' could not be imported: %s", module_name, exc)
+            base = module_name.split(".", 1)[0]
+            if base != module_name:
+                try:
+                    module = import_module(base)
+                    logger.info("Falling back to base module '%s' for domain layer '%s'", base, module_name)
+                except ImportError:
+                    logger.warning(
+                        "Domain layer '%s' could not be imported and base module '%s' is missing: %s",
+                        module_name,
+                        base,
+                        exc,
+                    )
+            else:
+                logger.warning("Domain layer '%s' could not be imported: %s", module_name, exc)
+        if module is None:
             continue
             continue
         register = getattr(module, "register_layer", None)
         register = getattr(module, "register_layer", None)
         if not callable(register):
         if not callable(register):
@@ -679,6 +753,7 @@ TOOL_DOCS = {
     "list_graphs": "List up to 50 active graph URIs.",
     "list_graphs": "List up to 50 active graph URIs.",
     "search_label": "Search rdfs:label values that contain a term (case-insensitive).",
     "search_label": "Search rdfs:label values that contain a term (case-insensitive).",
     "get_entities_by_type": "List subjects of a given rdf:type.",
     "get_entities_by_type": "List subjects of a given rdf:type.",
+    "cycle_plants": "List plants (with labels/clone parents) that belong to a specific cycle.",
     "get_predicates_for_subject": "List distinct predicates used by a subject.",
     "get_predicates_for_subject": "List distinct predicates used by a subject.",
     "get_labels_for_subject": "Fetch rdfs:label values for a subject.",
     "get_labels_for_subject": "Fetch rdfs:label values for a subject.",
     "traverse_property": "Traverse a property (incoming or outgoing) for a subject and return labels/descriptions.",
     "traverse_property": "Traverse a property (incoming or outgoing) for a subject and return labels/descriptions.",
@@ -690,6 +765,7 @@ TOOL_DOCS = {
     "path_traverse": "Follow a property path (list of predicates) from a subject, returning each step's nodes.",
     "path_traverse": "Follow a property path (list of predicates) from a subject, returning each step's nodes.",
     "property_usage_statistics": "Count how often a property is used and sample subjects/objects.",
     "property_usage_statistics": "Count how often a property is used and sample subjects/objects.",
     "batch_insert": "Insert multiple triples or TTL at once with a single guarded update.",
     "batch_insert": "Insert multiple triples or TTL at once with a single guarded update.",
+    "reassign_cycle": "Move a subject to another production cycle by updating its inCycle link.",
     "insert_triple": "Insert a single triple (useful for debugging updates).",
     "insert_triple": "Insert a single triple (useful for debugging updates).",
     "load_examples": "Load Turtle examples from the local examples/ directory into a graph.",
     "load_examples": "Load Turtle examples from the local examples/ directory into a graph.",
 }
 }
@@ -698,9 +774,21 @@ TOOL_DOCS = {
 # --- MCP ENDPOINT ---
 # --- MCP ENDPOINT ---
 
 
 @app.post("/mcp")
 @app.post("/mcp")
-def handle_mcp(request: ToolRequest):
-    tool_name = request.tool
-    input_data = request.input or {}
+def handle_mcp(tool_request: ToolRequest, http_request: Request):
+    tool_name = tool_request.tool
+    input_data = tool_request.input or {}
+    client_host = http_request.client.host if http_request.client else "unknown"
+    trimmed_input = json.dumps(input_data, ensure_ascii=False, default=str)
+    if len(trimmed_input) > 1024:
+        trimmed_input = f"{trimmed_input[:1024]}…"
+    timestamp = datetime.now(timezone.utc).isoformat()
+    tool_logger.info(
+        "tool=%s client=%s time=%s input=%s",
+        tool_name,
+        client_host,
+        timestamp,
+        trimmed_input,
+    )
 
 
     if tool_name not in TOOLS:
     if tool_name not in TOOLS:
         raise HTTPException(status_code=400, detail=f"Unknown tool: {tool_name}")
         raise HTTPException(status_code=400, detail=f"Unknown tool: {tool_name}")