diff --git a/src/dolibarr_mcp/dolibarr_mcp_server.py b/src/dolibarr_mcp/dolibarr_mcp_server.py index ec6badd..e992d11 100644 --- a/src/dolibarr_mcp/dolibarr_mcp_server.py +++ b/src/dolibarr_mcp/dolibarr_mcp_server.py @@ -29,6 +29,11 @@ logging.basicConfig( server = Server("dolibarr-mcp") +def _escape_sqlfilter(value: str) -> str: + """Escape single quotes for SQL filters.""" + return value.replace("'", "''") + + @server.list_tools() async def handle_list_tools(): """List all available tools.""" @@ -569,26 +574,27 @@ async def handle_call_tool(name: str, arguments: dict): # Search Tools elif name == "search_products_by_ref": - ref_prefix = arguments['ref_prefix'] + ref_prefix = _escape_sqlfilter(arguments['ref_prefix']) limit = arguments.get('limit', 20) sqlfilters = f"(t.ref:like:'{ref_prefix}%')" result = await client.search_products(sqlfilters=sqlfilters, limit=limit) elif name == "search_customers": - query = arguments['query'] + query = _escape_sqlfilter(arguments['query']) limit = arguments.get('limit', 20) sqlfilters = f"((t.nom:like:'%{query}%') OR (t.name_alias:like:'%{query}%'))" result = await client.search_customers(sqlfilters=sqlfilters, limit=limit) elif name == "search_products_by_label": - label_search = arguments['label_search'] + label_search = _escape_sqlfilter(arguments['label_search']) limit = arguments.get('limit', 20) sqlfilters = f"(t.label:like:'%{label_search}%')" result = await client.search_products(sqlfilters=sqlfilters, limit=limit) elif name == "resolve_product_ref": ref = arguments['ref'] - sqlfilters = f"(t.ref:like:'{ref}')" + ref_esc = _escape_sqlfilter(ref) + sqlfilters = f"(t.ref:like:'{ref_esc}')" products = await client.search_products(sqlfilters=sqlfilters, limit=2) if not products: @@ -599,7 +605,7 @@ async def handle_call_tool(name: str, arguments: dict): # Check if one is exact match exact_matches = [p for p in products if p.get('ref') == ref] if len(exact_matches) == 1: - result = {"status": "ok", "product": exact_matches[0]} + result = {"status": "ok", "product": exact_matches[0]} else: result = {"status": "ambiguous", "message": f"Multiple products found for ref '{ref}'", "products": products} diff --git a/tests/test_search_tools.py b/tests/test_search_tools.py new file mode 100644 index 0000000..6f86f46 --- /dev/null +++ b/tests/test_search_tools.py @@ -0,0 +1,79 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from dolibarr_mcp.dolibarr_mcp_server import handle_call_tool +from dolibarr_mcp.dolibarr_client import DolibarrClient + +@pytest.mark.asyncio +async def test_search_products_by_ref(): + # Mock DolibarrClient + with patch("dolibarr_mcp.dolibarr_mcp_server.DolibarrClient") as MockClient: + mock_instance = MockClient.return_value + mock_instance.__aenter__.return_value = mock_instance + + # Mock search_products response + mock_instance.search_products = AsyncMock(return_value=[ + {"id": 1, "ref": "PRJ-123", "label": "Project 123"} + ]) + + # Call the tool + result = await handle_call_tool("search_products_by_ref", {"ref_prefix": "PRJ"}) + + # Verify the call + mock_instance.search_products.assert_called_once() + call_args = mock_instance.search_products.call_args + assert "sqlfilters" in call_args.kwargs + assert call_args.kwargs["sqlfilters"] == "(t.ref:like:'PRJ%')" + + # Verify result + assert "PRJ-123" in result[0].text + +@pytest.mark.asyncio +async def test_resolve_product_ref_exact(): + with patch("dolibarr_mcp.dolibarr_mcp_server.DolibarrClient") as MockClient: + mock_instance = MockClient.return_value + mock_instance.__aenter__.return_value = mock_instance + + # Mock search_products response (exact match) + mock_instance.search_products = AsyncMock(return_value=[ + {"id": 1, "ref": "PRJ-123", "label": "Project 123"} + ]) + + result = await handle_call_tool("resolve_product_ref", {"ref": "PRJ-123"}) + + mock_instance.search_products.assert_called_once() + assert "ok" in result[0].text + assert "PRJ-123" in result[0].text + +@pytest.mark.asyncio +async def test_resolve_product_ref_ambiguous(): + with patch("dolibarr_mcp.dolibarr_mcp_server.DolibarrClient") as MockClient: + mock_instance = MockClient.return_value + mock_instance.__aenter__.return_value = mock_instance + + # Mock search_products response (multiple matches, none exact) + mock_instance.search_products = AsyncMock(return_value=[ + {"id": 1, "ref": "PRJ-123-A", "label": "Project 123 A"}, + {"id": 2, "ref": "PRJ-123-B", "label": "Project 123 B"} + ]) + + # Search for "PRJ-123" which matches both partially (hypothetically) but neither exactly + result = await handle_call_tool("resolve_product_ref", {"ref": "PRJ-123"}) + + assert "ambiguous" in result[0].text + +@pytest.mark.asyncio +async def test_search_customers(): + with patch("dolibarr_mcp.dolibarr_mcp_server.DolibarrClient") as MockClient: + mock_instance = MockClient.return_value + mock_instance.__aenter__.return_value = mock_instance + + mock_instance.search_customers = AsyncMock(return_value=[ + {"id": 1, "nom": "Acme Corp"} + ]) + + result = await handle_call_tool("search_customers", {"query": "Acme"}) + + mock_instance.search_customers.assert_called_once() + call_args = mock_instance.search_customers.call_args + assert "sqlfilters" in call_args.kwargs + assert "Acme" in call_args.kwargs["sqlfilters"]