fix(server): improve search tools with escaping and fix indentation (Review Feedback)

This commit is contained in:
Benju1
2025-12-05 03:18:10 +01:00
parent 987424f6e3
commit 112fb772bd
2 changed files with 90 additions and 5 deletions

View File

@@ -29,6 +29,11 @@ logging.basicConfig(
server = Server("dolibarr-mcp")
def _escape_sqlfilter(value: str) -> str:
"""Escape single quotes for SQL filters."""
return value.replace("'", "''")
@server.list_tools()
async def handle_list_tools():
"""List all available tools."""
@@ -569,26 +574,27 @@ async def handle_call_tool(name: str, arguments: dict):
# Search Tools
elif name == "search_products_by_ref":
ref_prefix = arguments['ref_prefix']
ref_prefix = _escape_sqlfilter(arguments['ref_prefix'])
limit = arguments.get('limit', 20)
sqlfilters = f"(t.ref:like:'{ref_prefix}%')"
result = await client.search_products(sqlfilters=sqlfilters, limit=limit)
elif name == "search_customers":
query = arguments['query']
query = _escape_sqlfilter(arguments['query'])
limit = arguments.get('limit', 20)
sqlfilters = f"((t.nom:like:'%{query}%') OR (t.name_alias:like:'%{query}%'))"
result = await client.search_customers(sqlfilters=sqlfilters, limit=limit)
elif name == "search_products_by_label":
label_search = arguments['label_search']
label_search = _escape_sqlfilter(arguments['label_search'])
limit = arguments.get('limit', 20)
sqlfilters = f"(t.label:like:'%{label_search}%')"
result = await client.search_products(sqlfilters=sqlfilters, limit=limit)
elif name == "resolve_product_ref":
ref = arguments['ref']
sqlfilters = f"(t.ref:like:'{ref}')"
ref_esc = _escape_sqlfilter(ref)
sqlfilters = f"(t.ref:like:'{ref_esc}')"
products = await client.search_products(sqlfilters=sqlfilters, limit=2)
if not products:
@@ -599,7 +605,7 @@ async def handle_call_tool(name: str, arguments: dict):
# Check if one is exact match
exact_matches = [p for p in products if p.get('ref') == ref]
if len(exact_matches) == 1:
result = {"status": "ok", "product": exact_matches[0]}
result = {"status": "ok", "product": exact_matches[0]}
else:
result = {"status": "ambiguous", "message": f"Multiple products found for ref '{ref}'", "products": products}

View File

@@ -0,0 +1,79 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dolibarr_mcp.dolibarr_mcp_server import handle_call_tool
from dolibarr_mcp.dolibarr_client import DolibarrClient
@pytest.mark.asyncio
async def test_search_products_by_ref():
# Mock DolibarrClient
with patch("dolibarr_mcp.dolibarr_mcp_server.DolibarrClient") as MockClient:
mock_instance = MockClient.return_value
mock_instance.__aenter__.return_value = mock_instance
# Mock search_products response
mock_instance.search_products = AsyncMock(return_value=[
{"id": 1, "ref": "PRJ-123", "label": "Project 123"}
])
# Call the tool
result = await handle_call_tool("search_products_by_ref", {"ref_prefix": "PRJ"})
# Verify the call
mock_instance.search_products.assert_called_once()
call_args = mock_instance.search_products.call_args
assert "sqlfilters" in call_args.kwargs
assert call_args.kwargs["sqlfilters"] == "(t.ref:like:'PRJ%')"
# Verify result
assert "PRJ-123" in result[0].text
@pytest.mark.asyncio
async def test_resolve_product_ref_exact():
with patch("dolibarr_mcp.dolibarr_mcp_server.DolibarrClient") as MockClient:
mock_instance = MockClient.return_value
mock_instance.__aenter__.return_value = mock_instance
# Mock search_products response (exact match)
mock_instance.search_products = AsyncMock(return_value=[
{"id": 1, "ref": "PRJ-123", "label": "Project 123"}
])
result = await handle_call_tool("resolve_product_ref", {"ref": "PRJ-123"})
mock_instance.search_products.assert_called_once()
assert "ok" in result[0].text
assert "PRJ-123" in result[0].text
@pytest.mark.asyncio
async def test_resolve_product_ref_ambiguous():
with patch("dolibarr_mcp.dolibarr_mcp_server.DolibarrClient") as MockClient:
mock_instance = MockClient.return_value
mock_instance.__aenter__.return_value = mock_instance
# Mock search_products response (multiple matches, none exact)
mock_instance.search_products = AsyncMock(return_value=[
{"id": 1, "ref": "PRJ-123-A", "label": "Project 123 A"},
{"id": 2, "ref": "PRJ-123-B", "label": "Project 123 B"}
])
# Search for "PRJ-123" which matches both partially (hypothetically) but neither exactly
result = await handle_call_tool("resolve_product_ref", {"ref": "PRJ-123"})
assert "ambiguous" in result[0].text
@pytest.mark.asyncio
async def test_search_customers():
with patch("dolibarr_mcp.dolibarr_mcp_server.DolibarrClient") as MockClient:
mock_instance = MockClient.return_value
mock_instance.__aenter__.return_value = mock_instance
mock_instance.search_customers = AsyncMock(return_value=[
{"id": 1, "nom": "Acme Corp"}
])
result = await handle_call_tool("search_customers", {"query": "Acme"})
mock_instance.search_customers.assert_called_once()
call_args = mock_instance.search_customers.call_args
assert "sqlfilters" in call_args.kwargs
assert "Acme" in call_args.kwargs["sqlfilters"]