feat: Add intelligent auto-router and enhanced integrations
- Add intelligent-router.sh hook for automatic agent routing - Add AUTO-TRIGGER-SUMMARY.md documentation - Add FINAL-INTEGRATION-SUMMARY.md documentation - Complete Prometheus integration (6 commands + 4 tools) - Complete Dexto integration (12 commands + 5 tools) - Enhanced Ralph with access to all agents - Fix /clawd command (removed disable-model-invocation) - Update hooks.json to v5 with intelligent routing - 291 total skills now available - All 21 commands with automatic routing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
0
prometheus/tests/utils/__init__.py
Normal file
0
prometheus/tests/utils/__init__.py
Normal file
35
prometheus/tests/utils/test_file_utils.py
Normal file
35
prometheus/tests/utils/test_file_utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
|
||||
from prometheus.exceptions.file_operation_exception import FileOperationException
|
||||
from prometheus.utils.file_utils import (
|
||||
read_file_with_line_numbers,
|
||||
)
|
||||
from tests.test_utils.test_project_paths import TEST_PROJECT_PATH
|
||||
|
||||
|
||||
def test_read_file_with_line_numbers():
|
||||
"""Test reading specific line ranges from a file."""
|
||||
# Test reading specific lines
|
||||
result = read_file_with_line_numbers("foo/test.md", TEST_PROJECT_PATH, 1, 15)
|
||||
expected = """1. # A
|
||||
2.
|
||||
3. Text under header A.
|
||||
4.
|
||||
5. ## B
|
||||
6.
|
||||
7. Text under header B.
|
||||
8.
|
||||
9. ## C
|
||||
10.
|
||||
11. Text under header C.
|
||||
12.
|
||||
13. ### D
|
||||
14.
|
||||
15. Text under header D."""
|
||||
assert result == expected
|
||||
|
||||
# Test invalid range should raise exception
|
||||
with pytest.raises(FileOperationException) as exc_info:
|
||||
read_file_with_line_numbers("foo/test.md", TEST_PROJECT_PATH, 4, 2)
|
||||
|
||||
assert str(exc_info.value) == "The end line number must be greater than the start line number."
|
||||
50
prometheus/tests/utils/test_issue_util.py
Normal file
50
prometheus/tests/utils/test_issue_util.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from prometheus.utils.issue_util import (
|
||||
format_issue_comments,
|
||||
format_issue_info,
|
||||
format_test_commands,
|
||||
)
|
||||
|
||||
|
||||
def test_format_issue_comments():
|
||||
comments = [
|
||||
{"username": "alice", "comment": "This looks good!"},
|
||||
{"username": "bob", "comment": "Can we add tests?"},
|
||||
]
|
||||
|
||||
result = format_issue_comments(comments)
|
||||
expected = "alice: This looks good!\n\nbob: Can we add tests?"
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_format_issue_info():
|
||||
title = "Bug in login flow"
|
||||
body = "Users can't login on mobile devices"
|
||||
comments = [
|
||||
{"username": "alice", "comment": "I can reproduce this"},
|
||||
{"username": "bob", "comment": "Working on a fix"},
|
||||
]
|
||||
|
||||
result = format_issue_info(title, body, comments)
|
||||
expected = """\
|
||||
Issue title:
|
||||
Bug in login flow
|
||||
|
||||
Issue description:
|
||||
Users can't login on mobile devices
|
||||
|
||||
Issue comments:
|
||||
alice: I can reproduce this
|
||||
|
||||
bob: Working on a fix"""
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_format_test_commands():
|
||||
commands = ["pytest test_login.py", "pytest test_auth.py -v"]
|
||||
|
||||
result = format_test_commands(commands)
|
||||
expected = "$ pytest test_login.py\n$ pytest test_auth.py -v"
|
||||
|
||||
assert result == expected
|
||||
260
prometheus/tests/utils/test_knowledge_graph_utils.py
Normal file
260
prometheus/tests/utils/test_knowledge_graph_utils.py
Normal file
@@ -0,0 +1,260 @@
|
||||
from prometheus.models.context import Context
|
||||
from prometheus.utils.knowledge_graph_utils import (
|
||||
knowledge_graph_data_for_context_generator,
|
||||
sort_contexts,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_data():
|
||||
"""Test with None or empty data"""
|
||||
assert knowledge_graph_data_for_context_generator(None) == []
|
||||
assert knowledge_graph_data_for_context_generator([]) == []
|
||||
|
||||
|
||||
def test_skip_empty_content():
|
||||
"""Test that empty or whitespace-only content is skipped"""
|
||||
data = [
|
||||
{
|
||||
"FileNode": {"relative_path": "test.py"},
|
||||
"ASTNode": {"text": "", "start_line": 1, "end_line": 1},
|
||||
},
|
||||
{
|
||||
"FileNode": {"relative_path": "test.py"},
|
||||
"TextNode": {"text": " \n\t ", "start_line": 2, "end_line": 2},
|
||||
},
|
||||
{
|
||||
"FileNode": {"relative_path": "test.py"},
|
||||
"ASTNode": {"text": "valid content", "start_line": 3, "end_line": 3},
|
||||
},
|
||||
]
|
||||
result = knowledge_graph_data_for_context_generator(data)
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "valid content"
|
||||
|
||||
|
||||
def test_deduplication_identical_content():
|
||||
"""Test deduplication of identical content"""
|
||||
data = [
|
||||
{
|
||||
"FileNode": {"relative_path": "test.py"},
|
||||
"ASTNode": {"text": "def hello():", "start_line": 1, "end_line": 1},
|
||||
},
|
||||
{
|
||||
"FileNode": {"relative_path": "test.py"},
|
||||
"TextNode": {"text": "def hello():", "start_line": 1, "end_line": 1},
|
||||
},
|
||||
]
|
||||
result = knowledge_graph_data_for_context_generator(data)
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "def hello():"
|
||||
|
||||
|
||||
def test_deduplication_content_containment():
|
||||
"""Test deduplication when one content contains another"""
|
||||
data = [
|
||||
{
|
||||
"FileNode": {"relative_path": "test.py"},
|
||||
"ASTNode": {"text": "def hello():\n print('world')", "start_line": 1, "end_line": 2},
|
||||
},
|
||||
{
|
||||
"FileNode": {"relative_path": "test.py"},
|
||||
"TextNode": {"text": "print('world')", "start_line": 2, "end_line": 2},
|
||||
},
|
||||
]
|
||||
result = knowledge_graph_data_for_context_generator(data)
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "def hello():\n print('world')"
|
||||
|
||||
|
||||
def test_deduplication_line_containment():
|
||||
"""Test deduplication based on line number containment"""
|
||||
data = [
|
||||
{
|
||||
"FileNode": {"relative_path": "test.py"},
|
||||
"ASTNode": {"text": "function body", "start_line": 1, "end_line": 5},
|
||||
},
|
||||
{
|
||||
"FileNode": {"relative_path": "test.py"},
|
||||
"TextNode": {"text": "inner content", "start_line": 2, "end_line": 3},
|
||||
},
|
||||
]
|
||||
result = knowledge_graph_data_for_context_generator(data)
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "function body"
|
||||
assert result[0].start_line_number == 1
|
||||
assert result[0].end_line_number == 5
|
||||
|
||||
|
||||
def test_different_files_no_deduplication():
|
||||
"""Test that identical content in different files is not deduplicated"""
|
||||
data = [
|
||||
{
|
||||
"FileNode": {"relative_path": "file1.py"},
|
||||
"ASTNode": {"text": "def hello():", "start_line": 1, "end_line": 1},
|
||||
},
|
||||
{
|
||||
"FileNode": {"relative_path": "file2.py"},
|
||||
"ASTNode": {"text": "def hello():", "start_line": 1, "end_line": 1},
|
||||
},
|
||||
]
|
||||
result = knowledge_graph_data_for_context_generator(data)
|
||||
assert len(result) == 2
|
||||
assert result[0].relative_path == "file1.py"
|
||||
assert result[1].relative_path == "file2.py"
|
||||
|
||||
|
||||
def test_content_stripping():
|
||||
"""Test that content is properly stripped of whitespace"""
|
||||
data = [
|
||||
{
|
||||
"FileNode": {"relative_path": "test.py"},
|
||||
"ASTNode": {"text": " \n def hello(): \n ", "start_line": 1, "end_line": 1},
|
||||
}
|
||||
]
|
||||
result = knowledge_graph_data_for_context_generator(data)
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "def hello():"
|
||||
|
||||
|
||||
def test_complex_deduplication_scenario():
|
||||
"""Test complex scenario with multiple overlapping contexts"""
|
||||
data = [
|
||||
# Large context containing everything
|
||||
{
|
||||
"FileNode": {"relative_path": "test.py"},
|
||||
"ASTNode": {
|
||||
"text": "class MyClass:\n def method1(self):\n pass\n def method2(self):\n pass",
|
||||
"start_line": 1,
|
||||
"end_line": 5,
|
||||
},
|
||||
},
|
||||
# Smaller context contained within the large one
|
||||
{
|
||||
"FileNode": {"relative_path": "test.py"},
|
||||
"TextNode": {
|
||||
"text": "def method1(self):\n pass",
|
||||
"start_line": 2,
|
||||
"end_line": 3,
|
||||
},
|
||||
},
|
||||
# Separate context in same file
|
||||
{
|
||||
"FileNode": {"relative_path": "test.py"},
|
||||
"SelectedLines": {"text": "# Comment at end", "start_line": 10, "end_line": 10},
|
||||
},
|
||||
]
|
||||
result = knowledge_graph_data_for_context_generator(data)
|
||||
assert len(result) == 2 # Large context + separate comment
|
||||
assert "class MyClass:" in result[0].content
|
||||
assert result[1].content == "# Comment at end"
|
||||
|
||||
|
||||
def test_sort_contexts_empty_list():
|
||||
"""Test sorting an empty list"""
|
||||
result = sort_contexts([])
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_sort_contexts_by_relative_path():
|
||||
"""Test sorting by relative path"""
|
||||
contexts = [
|
||||
Context(
|
||||
relative_path="src/z.py", content="content z", start_line_number=1, end_line_number=5
|
||||
),
|
||||
Context(
|
||||
relative_path="src/a.py", content="content a", start_line_number=1, end_line_number=5
|
||||
),
|
||||
Context(
|
||||
relative_path="src/m.py", content="content m", start_line_number=1, end_line_number=5
|
||||
),
|
||||
]
|
||||
result = sort_contexts(contexts)
|
||||
assert result[0].relative_path == "src/a.py"
|
||||
assert result[1].relative_path == "src/m.py"
|
||||
assert result[2].relative_path == "src/z.py"
|
||||
|
||||
|
||||
def test_sort_contexts_by_line_numbers():
|
||||
"""Test sorting by line numbers within same file"""
|
||||
contexts = [
|
||||
Context(
|
||||
relative_path="test.py", content="content 3", start_line_number=20, end_line_number=25
|
||||
),
|
||||
Context(
|
||||
relative_path="test.py", content="content 1", start_line_number=1, end_line_number=5
|
||||
),
|
||||
Context(
|
||||
relative_path="test.py", content="content 2", start_line_number=10, end_line_number=15
|
||||
),
|
||||
]
|
||||
result = sort_contexts(contexts)
|
||||
assert result[0].start_line_number == 1
|
||||
assert result[1].start_line_number == 10
|
||||
assert result[2].start_line_number == 20
|
||||
|
||||
|
||||
def test_sort_contexts_none_line_numbers():
|
||||
"""Test sorting when line numbers are None (should appear last)"""
|
||||
contexts = [
|
||||
Context(
|
||||
relative_path="test.py",
|
||||
content="content with lines",
|
||||
start_line_number=10,
|
||||
end_line_number=15,
|
||||
),
|
||||
Context(
|
||||
relative_path="test.py",
|
||||
content="content no lines",
|
||||
start_line_number=None,
|
||||
end_line_number=None,
|
||||
),
|
||||
Context(
|
||||
relative_path="test.py", content="content first", start_line_number=1, end_line_number=5
|
||||
),
|
||||
]
|
||||
result = sort_contexts(contexts)
|
||||
assert result[0].start_line_number == 1
|
||||
assert result[1].start_line_number == 10
|
||||
assert result[2].start_line_number is None
|
||||
|
||||
|
||||
def test_sort_contexts_mixed_files_and_lines():
|
||||
"""Test sorting with multiple files and different line numbers"""
|
||||
contexts = [
|
||||
Context(
|
||||
relative_path="b.py", content="b content 2", start_line_number=20, end_line_number=25
|
||||
),
|
||||
Context(
|
||||
relative_path="a.py", content="a content 2", start_line_number=10, end_line_number=15
|
||||
),
|
||||
Context(
|
||||
relative_path="b.py", content="b content 1", start_line_number=5, end_line_number=10
|
||||
),
|
||||
Context(
|
||||
relative_path="a.py", content="a content 1", start_line_number=1, end_line_number=5
|
||||
),
|
||||
]
|
||||
result = sort_contexts(contexts)
|
||||
assert result[0].relative_path == "a.py" and result[0].start_line_number == 1
|
||||
assert result[1].relative_path == "a.py" and result[1].start_line_number == 10
|
||||
assert result[2].relative_path == "b.py" and result[2].start_line_number == 5
|
||||
assert result[3].relative_path == "b.py" and result[3].start_line_number == 20
|
||||
|
||||
|
||||
def test_sort_contexts_end_line_number_tiebreaker():
|
||||
"""Test sorting uses end_line_number as tiebreaker when start_line_number is same"""
|
||||
contexts = [
|
||||
Context(
|
||||
relative_path="test.py", content="content 3", start_line_number=10, end_line_number=30
|
||||
),
|
||||
Context(
|
||||
relative_path="test.py", content="content 1", start_line_number=10, end_line_number=15
|
||||
),
|
||||
Context(
|
||||
relative_path="test.py", content="content 2", start_line_number=10, end_line_number=20
|
||||
),
|
||||
]
|
||||
result = sort_contexts(contexts)
|
||||
assert result[0].end_line_number == 15
|
||||
assert result[1].end_line_number == 20
|
||||
assert result[2].end_line_number == 30
|
||||
108
prometheus/tests/utils/test_lang_graph_util.py
Normal file
108
prometheus/tests/utils/test_lang_graph_util.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
from prometheus.utils.lang_graph_util import (
|
||||
check_remaining_steps,
|
||||
extract_ai_responses,
|
||||
extract_human_queries,
|
||||
extract_last_tool_messages,
|
||||
format_agent_tool_message_history,
|
||||
get_last_message_content,
|
||||
)
|
||||
|
||||
|
||||
# Test check_remaining_steps
|
||||
def test_check_remaining_steps():
|
||||
def mock_router(state):
|
||||
return "next_step"
|
||||
|
||||
state_enough_steps = {"remaining_steps": 5}
|
||||
state_low_steps = {"remaining_steps": 2}
|
||||
|
||||
assert check_remaining_steps(state_enough_steps, mock_router, 3) == "next_step"
|
||||
assert check_remaining_steps(state_low_steps, mock_router, 3) == "low_remaining_steps"
|
||||
|
||||
|
||||
# Test extract_ai_responses
|
||||
def test_extract_ai_responses():
|
||||
messages = [
|
||||
HumanMessage(content="Human 1"),
|
||||
AIMessage(content="AI 1"),
|
||||
HumanMessage(content="Human 2"),
|
||||
AIMessage(content="AI 2"),
|
||||
]
|
||||
|
||||
responses = extract_ai_responses(messages)
|
||||
assert len(responses) == 2
|
||||
assert "AI 1" in responses
|
||||
assert "AI 2" in responses
|
||||
|
||||
|
||||
# Test extract_human_queries
|
||||
def test_extract_human_queries():
|
||||
messages = [
|
||||
SystemMessage(content="System"),
|
||||
HumanMessage(content="Human 1"),
|
||||
AIMessage(content="AI 1"),
|
||||
HumanMessage(content="Human 2"),
|
||||
]
|
||||
|
||||
queries = extract_human_queries(messages)
|
||||
assert len(queries) == 2
|
||||
assert "Human 1" in queries
|
||||
assert "Human 2" in queries
|
||||
|
||||
|
||||
# Test extract_last_tool_messages
|
||||
def test_extract_last_tool_messages():
|
||||
messages = [
|
||||
HumanMessage(content="Human 1"),
|
||||
ToolMessage(content="Tool 1", tool_call_id="call_1"),
|
||||
AIMessage(content="AI 1"),
|
||||
HumanMessage(content="Human 2"),
|
||||
ToolMessage(content="Tool 2", tool_call_id="call_2"),
|
||||
ToolMessage(content="Tool 3", tool_call_id="call_3"),
|
||||
]
|
||||
|
||||
tool_messages = extract_last_tool_messages(messages)
|
||||
assert len(tool_messages) == 2
|
||||
assert all(isinstance(msg, ToolMessage) for msg in tool_messages)
|
||||
assert tool_messages[-1].content == "Tool 3"
|
||||
|
||||
|
||||
# Test get_last_message_content
|
||||
def test_get_last_message_content():
|
||||
messages = [
|
||||
HumanMessage(content="Human"),
|
||||
AIMessage(content="AI"),
|
||||
ToolMessage(content="Last message", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
content = get_last_message_content(messages)
|
||||
assert content == "Last message"
|
||||
|
||||
|
||||
def test_format_agent_tool_message_history():
|
||||
messages = [
|
||||
AIMessage(content="Let me analyze this"),
|
||||
AIMessage(
|
||||
content="I'll use a tool for this",
|
||||
additional_kwargs={"tool_calls": [{"function": "analyze_data"}]},
|
||||
),
|
||||
ToolMessage(content="Analysis results: Success", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
result = format_agent_tool_message_history(messages)
|
||||
|
||||
expected = (
|
||||
"Assistant internal thought: Let me analyze this\n\n"
|
||||
"Assistant internal thought: I'll use a tool for this\n\n"
|
||||
"Assistant executed tool: analyze_data\n\n"
|
||||
"Tool output: Analysis results: Success"
|
||||
)
|
||||
|
||||
assert result == expected
|
||||
391
prometheus/tests/utils/test_memory_utils.py
Normal file
391
prometheus/tests/utils/test_memory_utils.py
Normal file
@@ -0,0 +1,391 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from prometheus.models.context import Context
|
||||
from prometheus.models.query import Query
|
||||
from prometheus.utils.memory_utils import AthenaMemoryClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def athena_client():
|
||||
"""Create an AthenaMemoryClient instance for testing."""
|
||||
return AthenaMemoryClient(base_url="http://test-athena-service.com")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_contexts():
|
||||
"""Create sample Context objects for testing."""
|
||||
return [
|
||||
Context(
|
||||
relative_path="src/main.py",
|
||||
content="def main():\n print('Hello')",
|
||||
start_line_number=1,
|
||||
end_line_number=2,
|
||||
),
|
||||
Context(
|
||||
relative_path="src/utils.py",
|
||||
content="def helper():\n return True",
|
||||
start_line_number=10,
|
||||
end_line_number=11,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_query():
|
||||
"""Create a sample Query object for testing."""
|
||||
return Query(
|
||||
essential_query="How to implement authentication?",
|
||||
extra_requirements="Using JWT tokens",
|
||||
purpose="Feature implementation",
|
||||
)
|
||||
|
||||
|
||||
class TestAthenaMemoryClient:
|
||||
"""Test suite for AthenaMemoryClient class."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test AthenaMemoryClient initialization."""
|
||||
client = AthenaMemoryClient(base_url="http://example.com/")
|
||||
assert client.base_url == "http://example.com"
|
||||
assert client.timeout == 30
|
||||
|
||||
def test_init_strips_trailing_slash(self):
|
||||
"""Test that trailing slashes are removed from base_url."""
|
||||
client = AthenaMemoryClient(base_url="http://example.com///")
|
||||
assert client.base_url == "http://example.com"
|
||||
|
||||
@patch("prometheus.utils.memory_utils.requests.post")
|
||||
def test_store_memory_success(self, mock_post, athena_client, sample_contexts):
|
||||
"""Test successful memory storage."""
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "success", "memory_id": "123"}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Call the method
|
||||
result = athena_client.store_memory(
|
||||
repository_id=42,
|
||||
essential_query="How to use async?",
|
||||
extra_requirements="Python 3.11",
|
||||
purpose="Learning",
|
||||
contexts=sample_contexts,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert result == {"status": "success", "memory_id": "123"}
|
||||
mock_post.assert_called_once()
|
||||
|
||||
# Verify the call arguments
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[0][0] == "http://test-athena-service.com/semantic-memory/store/"
|
||||
assert call_args[1]["timeout"] == 30
|
||||
|
||||
# Verify payload structure
|
||||
payload = call_args[1]["json"]
|
||||
assert payload["repository_id"] == 42
|
||||
assert payload["query"]["essential_query"] == "How to use async?"
|
||||
assert payload["query"]["extra_requirements"] == "Python 3.11"
|
||||
assert payload["query"]["purpose"] == "Learning"
|
||||
assert len(payload["contexts"]) == 2
|
||||
assert payload["contexts"][0]["relative_path"] == "src/main.py"
|
||||
|
||||
@patch("prometheus.utils.memory_utils.requests.post")
|
||||
def test_store_memory_request_exception(self, mock_post, athena_client, sample_contexts):
|
||||
"""Test store_memory raises RequestException on failure."""
|
||||
# Setup mock to raise an exception
|
||||
mock_post.side_effect = requests.RequestException("Connection error")
|
||||
|
||||
# Verify exception is raised
|
||||
with pytest.raises(requests.RequestException) as exc_info:
|
||||
athena_client.store_memory(
|
||||
repository_id=42,
|
||||
essential_query="test query",
|
||||
extra_requirements="",
|
||||
purpose="",
|
||||
contexts=sample_contexts,
|
||||
)
|
||||
|
||||
assert "Connection error" in str(exc_info.value)
|
||||
|
||||
@patch("prometheus.utils.memory_utils.requests.post")
|
||||
def test_store_memory_http_error(self, mock_post, athena_client, sample_contexts):
|
||||
"""Test store_memory handles HTTP errors."""
|
||||
# Setup mock response with error status
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Verify exception is raised
|
||||
with pytest.raises(requests.HTTPError):
|
||||
athena_client.store_memory(
|
||||
repository_id=42,
|
||||
essential_query="test query",
|
||||
extra_requirements="",
|
||||
purpose="",
|
||||
contexts=sample_contexts,
|
||||
)
|
||||
|
||||
@patch("prometheus.utils.memory_utils.requests.get")
|
||||
def test_retrieve_memory_success(self, mock_get, athena_client, sample_query):
|
||||
"""Test successful memory retrieval."""
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"data": [
|
||||
{"content": "Authentication code snippet 1"},
|
||||
{"content": "Authentication code snippet 2"},
|
||||
]
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
# Call the method
|
||||
result = athena_client.retrieve_memory(repository_id=42, query=sample_query)
|
||||
|
||||
# Assertions
|
||||
assert len(result) == 2
|
||||
assert result[0]["content"] == "Authentication code snippet 1"
|
||||
mock_get.assert_called_once()
|
||||
|
||||
# Verify the call arguments
|
||||
call_args = mock_get.call_args
|
||||
assert call_args[0][0] == "http://test-athena-service.com/semantic-memory/retrieve/42/"
|
||||
assert call_args[1]["timeout"] == 30
|
||||
|
||||
# Verify query parameters
|
||||
params = call_args[1]["params"]
|
||||
assert params["essential_query"] == "How to implement authentication?"
|
||||
assert params["extra_requirements"] == "Using JWT tokens"
|
||||
assert params["purpose"] == "Feature implementation"
|
||||
|
||||
@patch("prometheus.utils.memory_utils.requests.get")
|
||||
def test_retrieve_memory_with_optional_fields(self, mock_get, athena_client):
|
||||
"""Test memory retrieval with empty optional query fields."""
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"data": []}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
# Create query with empty optional fields
|
||||
query = Query(essential_query="test query", extra_requirements="", purpose="")
|
||||
|
||||
# Call the method
|
||||
result = athena_client.retrieve_memory(repository_id=42, query=query)
|
||||
|
||||
# Verify empty strings are passed correctly
|
||||
call_args = mock_get.call_args
|
||||
params = call_args[1]["params"]
|
||||
assert params["extra_requirements"] == ""
|
||||
assert params["purpose"] == ""
|
||||
assert result == []
|
||||
|
||||
@patch("prometheus.utils.memory_utils.requests.get")
|
||||
def test_retrieve_memory_request_exception(self, mock_get, athena_client, sample_query):
|
||||
"""Test retrieve_memory raises RequestException on failure."""
|
||||
# Setup mock to raise an exception
|
||||
mock_get.side_effect = requests.RequestException("Timeout error")
|
||||
|
||||
# Verify exception is raised
|
||||
with pytest.raises(requests.RequestException) as exc_info:
|
||||
athena_client.retrieve_memory(repository_id=42, query=sample_query)
|
||||
|
||||
assert "Timeout error" in str(exc_info.value)
|
||||
|
||||
@patch("prometheus.utils.memory_utils.requests.get")
|
||||
def test_retrieve_memory_http_error(self, mock_get, athena_client, sample_query):
|
||||
"""Test retrieve_memory handles HTTP errors."""
|
||||
# Setup mock response with error status
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
# Verify exception is raised
|
||||
with pytest.raises(requests.HTTPError):
|
||||
athena_client.retrieve_memory(repository_id=42, query=sample_query)
|
||||
|
||||
@patch("prometheus.utils.memory_utils.requests.delete")
|
||||
def test_delete_repository_memory_success(self, mock_delete, athena_client):
|
||||
"""Test successful repository memory deletion."""
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "success", "deleted_count": 15}
|
||||
mock_delete.return_value = mock_response
|
||||
|
||||
# Call the method
|
||||
result = athena_client.delete_repository_memory(repository_id=42)
|
||||
|
||||
# Assertions
|
||||
assert result == {"status": "success", "deleted_count": 15}
|
||||
mock_delete.assert_called_once()
|
||||
|
||||
# Verify the call arguments
|
||||
call_args = mock_delete.call_args
|
||||
assert call_args[0][0] == "http://test-athena-service.com/semantic-memory/42/"
|
||||
assert call_args[1]["timeout"] == 30
|
||||
|
||||
@patch("prometheus.utils.memory_utils.requests.delete")
|
||||
def test_delete_repository_memory_request_exception(self, mock_delete, athena_client):
|
||||
"""Test delete_repository_memory raises RequestException on failure."""
|
||||
# Setup mock to raise an exception
|
||||
mock_delete.side_effect = requests.RequestException("Network error")
|
||||
|
||||
# Verify exception is raised
|
||||
with pytest.raises(requests.RequestException) as exc_info:
|
||||
athena_client.delete_repository_memory(repository_id=42)
|
||||
|
||||
assert "Network error" in str(exc_info.value)
|
||||
|
||||
@patch("prometheus.utils.memory_utils.requests.delete")
|
||||
def test_delete_repository_memory_http_error(self, mock_delete, athena_client):
|
||||
"""Test delete_repository_memory handles HTTP errors."""
|
||||
# Setup mock response with error status
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("403 Forbidden")
|
||||
mock_delete.return_value = mock_response
|
||||
|
||||
# Verify exception is raised
|
||||
with pytest.raises(requests.HTTPError):
|
||||
athena_client.delete_repository_memory(repository_id=42)
|
||||
|
||||
|
||||
class TestModuleLevelFunctions:
|
||||
"""Test suite for module-level convenience functions."""
|
||||
|
||||
@patch("prometheus.utils.memory_utils.athena_client.store_memory")
|
||||
def test_store_memory_function(self, mock_store, sample_contexts):
|
||||
"""Test module-level store_memory function."""
|
||||
from prometheus.utils.memory_utils import store_memory
|
||||
|
||||
# Setup mock
|
||||
mock_store.return_value = {"status": "success"}
|
||||
|
||||
# Call the function
|
||||
result = store_memory(
|
||||
repository_id=123,
|
||||
essential_query="test query",
|
||||
extra_requirements="requirements",
|
||||
purpose="testing",
|
||||
contexts=sample_contexts,
|
||||
)
|
||||
|
||||
# Verify it delegates to the client
|
||||
assert result == {"status": "success"}
|
||||
mock_store.assert_called_once_with(
|
||||
repository_id=123,
|
||||
essential_query="test query",
|
||||
extra_requirements="requirements",
|
||||
purpose="testing",
|
||||
contexts=sample_contexts,
|
||||
)
|
||||
|
||||
@patch("prometheus.utils.memory_utils.athena_client.retrieve_memory")
|
||||
def test_retrieve_memory_function(self, mock_retrieve, sample_query):
|
||||
"""Test module-level retrieve_memory function."""
|
||||
from prometheus.utils.memory_utils import retrieve_memory
|
||||
|
||||
# Setup mock
|
||||
mock_retrieve.return_value = [{"content": "test"}]
|
||||
|
||||
# Call the function
|
||||
result = retrieve_memory(repository_id=123, query=sample_query)
|
||||
|
||||
# Verify it delegates to the client
|
||||
assert result == [{"content": "test"}]
|
||||
mock_retrieve.assert_called_once_with(repository_id=123, query=sample_query)
|
||||
|
||||
@patch("prometheus.utils.memory_utils.athena_client.delete_repository_memory")
|
||||
def test_delete_repository_memory_function(self, mock_delete):
|
||||
"""Test module-level delete_repository_memory function."""
|
||||
from prometheus.utils.memory_utils import delete_repository_memory
|
||||
|
||||
# Setup mock
|
||||
mock_delete.return_value = {"status": "deleted"}
|
||||
|
||||
# Call the function
|
||||
result = delete_repository_memory(repository_id=123)
|
||||
|
||||
# Verify it delegates to the client
|
||||
assert result == {"status": "deleted"}
|
||||
mock_delete.assert_called_once_with(repository_id=123)
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and boundary conditions."""
|
||||
|
||||
@patch("prometheus.utils.memory_utils.requests.post")
|
||||
def test_store_memory_with_empty_contexts_list(self, mock_post, athena_client):
|
||||
"""Test storing memory with empty contexts list."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "success"}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = athena_client.store_memory(
|
||||
repository_id=1,
|
||||
essential_query="test",
|
||||
extra_requirements="",
|
||||
purpose="",
|
||||
contexts=[],
|
||||
)
|
||||
|
||||
# Verify empty list is handled
|
||||
payload = mock_post.call_args[1]["json"]
|
||||
assert payload["contexts"] == []
|
||||
assert result == {"status": "success"}
|
||||
|
||||
@patch("prometheus.utils.memory_utils.requests.get")
|
||||
def test_retrieve_memory_empty_results(self, mock_get, athena_client):
|
||||
"""Test retrieving memory when no results are found."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"data": []}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
query = Query(
|
||||
essential_query="nonexistent query",
|
||||
extra_requirements="",
|
||||
purpose="",
|
||||
)
|
||||
|
||||
result = athena_client.retrieve_memory(repository_id=999, query=query)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_context_serialization(self, sample_contexts):
|
||||
"""Test that Context objects are properly serialized."""
|
||||
context = sample_contexts[0]
|
||||
serialized = context.model_dump()
|
||||
|
||||
assert serialized["relative_path"] == "src/main.py"
|
||||
assert serialized["content"] == "def main():\n print('Hello')"
|
||||
assert serialized["start_line_number"] == 1
|
||||
assert serialized["end_line_number"] == 2
|
||||
|
||||
@patch("prometheus.utils.memory_utils.requests.post")
|
||||
def test_store_memory_with_special_characters(self, mock_post, athena_client):
|
||||
"""Test storing memory with special characters in content."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "success"}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
special_context = Context(
|
||||
relative_path="test/file.py",
|
||||
content="def test():\n # Special chars: @#$%^&*(){}[]|\\<>?~`",
|
||||
start_line_number=1,
|
||||
end_line_number=2,
|
||||
)
|
||||
|
||||
result = athena_client.store_memory(
|
||||
repository_id=1,
|
||||
essential_query="test with special chars: @#$%",
|
||||
extra_requirements="requirements with 中文",
|
||||
purpose="purpose with émojis 🚀",
|
||||
contexts=[special_context],
|
||||
)
|
||||
|
||||
assert result == {"status": "success"}
|
||||
# Verify special characters are preserved in the payload
|
||||
payload = mock_post.call_args[1]["json"]
|
||||
assert "@#$%" in payload["query"]["essential_query"]
|
||||
assert "中文" in payload["query"]["extra_requirements"]
|
||||
assert "🚀" in payload["query"]["purpose"]
|
||||
86
prometheus/tests/utils/test_neo4j_util.py
Normal file
86
prometheus/tests/utils/test_neo4j_util.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.utils.knowledge_graph_utils import (
|
||||
EMPTY_DATA_MESSAGE,
|
||||
format_knowledge_graph_data,
|
||||
)
|
||||
|
||||
|
||||
class MockResult:
|
||||
def __init__(self, data_list):
|
||||
self._data = data_list
|
||||
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
|
||||
class MockSession:
|
||||
def __init__(self):
|
||||
self.execute_read = MagicMock()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
|
||||
class MockDriver:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def session(self):
|
||||
return self._session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_neo4j_driver():
|
||||
session = MockSession()
|
||||
driver = MockDriver(session)
|
||||
return driver, session
|
||||
|
||||
|
||||
def test_format_neo4j_data_single_row():
|
||||
data = [{"name": "John", "age": 30}]
|
||||
|
||||
formatted = format_knowledge_graph_data(data)
|
||||
expected = "Result 1:\nage: 30\nname: John"
|
||||
|
||||
assert formatted == expected
|
||||
|
||||
|
||||
def test_format_neo4j_result_multiple_rows():
|
||||
data = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}]
|
||||
|
||||
formatted = format_knowledge_graph_data(data)
|
||||
expected = "Result 1:\nage: 30\nname: John\n\n\nResult 2:\nage: 25\nname: Jane"
|
||||
|
||||
assert formatted == expected
|
||||
|
||||
|
||||
def test_format_neo4j_result_empty():
|
||||
data = []
|
||||
formatted = format_knowledge_graph_data(data)
|
||||
assert formatted == EMPTY_DATA_MESSAGE
|
||||
|
||||
|
||||
def test_format_neo4j_result_different_keys():
|
||||
data = [{"name": "John", "age": 30}, {"city": "New York", "country": "USA"}]
|
||||
|
||||
formatted = format_knowledge_graph_data(data)
|
||||
expected = "Result 1:\nage: 30\nname: John\n\n\nResult 2:\ncity: New York\ncountry: USA"
|
||||
|
||||
assert formatted == expected
|
||||
|
||||
|
||||
def test_format_neo4j_result_complex_values():
|
||||
data = [
|
||||
{"numbers": [1, 2, 3], "metadata": {"type": "user", "active": True}, "date": "2024-01-01"}
|
||||
]
|
||||
|
||||
formatted = format_knowledge_graph_data(data)
|
||||
expected = "Result 1:\ndate: 2024-01-01\nmetadata: {'type': 'user', 'active': True}\nnumbers: [1, 2, 3]"
|
||||
|
||||
assert formatted == expected
|
||||
120
prometheus/tests/utils/test_patch_util.py
Normal file
120
prometheus/tests/utils/test_patch_util.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from pathlib import Path
|
||||
|
||||
from prometheus.utils.patch_util import get_updated_files
|
||||
|
||||
|
||||
def test_get_updated_files_empty_diff():
|
||||
diff = ""
|
||||
added, modified, removed = get_updated_files(diff)
|
||||
assert len(added) == 0
|
||||
assert len(modified) == 0
|
||||
assert len(removed) == 0
|
||||
|
||||
|
||||
def test_get_updated_files_added_only():
|
||||
diff = """
|
||||
diff --git a/new_file.txt b/new_file.txt
|
||||
new file mode 100644
|
||||
index 0000000..1234567
|
||||
--- /dev/null
|
||||
+++ b/new_file.txt
|
||||
@@ -0,0 +1 @@
|
||||
+New content
|
||||
"""
|
||||
added, modified, removed = get_updated_files(diff)
|
||||
assert len(added) == 1
|
||||
assert len(modified) == 0
|
||||
assert len(removed) == 0
|
||||
assert added[0] == Path("new_file.txt")
|
||||
|
||||
|
||||
def test_get_updated_files_modified_only():
|
||||
diff = """
|
||||
diff --git a/modified_file.txt b/modified_file.txt
|
||||
index 1234567..89abcdef
|
||||
--- a/modified_file.txt
|
||||
+++ b/modified_file.txt
|
||||
@@ -1 +1 @@
|
||||
-Old content
|
||||
+Modified content
|
||||
"""
|
||||
added, modified, removed = get_updated_files(diff)
|
||||
assert len(added) == 0
|
||||
assert len(modified) == 1
|
||||
assert len(removed) == 0
|
||||
assert modified[0] == Path("modified_file.txt")
|
||||
|
||||
|
||||
def test_get_updated_files_removed_only():
|
||||
diff = """
|
||||
diff --git a/removed_file.txt b/removed_file.txt
|
||||
deleted file mode 100644
|
||||
index 1234567..0000000
|
||||
--- a/removed_file.txt
|
||||
+++ /dev/null
|
||||
@@ -1 +0,0 @@
|
||||
-Content to be removed
|
||||
"""
|
||||
added, modified, removed = get_updated_files(diff)
|
||||
assert len(added) == 0
|
||||
assert len(modified) == 0
|
||||
assert len(removed) == 1
|
||||
assert removed[0] == Path("removed_file.txt")
|
||||
|
||||
|
||||
def test_get_updated_files_multiple_changes():
|
||||
diff = """
|
||||
diff --git a/new_file.txt b/new_file.txt
|
||||
new file mode 100644
|
||||
index 0000000..1234567
|
||||
--- /dev/null
|
||||
+++ b/new_file.txt
|
||||
@@ -0,0 +1 @@
|
||||
+New content
|
||||
diff --git a/modified_file.txt b/modified_file.txt
|
||||
index 1234567..89abcdef
|
||||
--- a/modified_file.txt
|
||||
+++ b/modified_file.txt
|
||||
@@ -1 +1 @@
|
||||
-Old content
|
||||
+Modified content
|
||||
diff --git a/removed_file.txt b/removed_file.txt
|
||||
deleted file mode 100644
|
||||
index 1234567..0000000
|
||||
--- a/removed_file.txt
|
||||
+++ /dev/null
|
||||
@@ -1 +0,0 @@
|
||||
-Content to be removed
|
||||
"""
|
||||
added, modified, removed = get_updated_files(diff)
|
||||
assert len(added) == 1
|
||||
assert len(modified) == 1
|
||||
assert len(removed) == 1
|
||||
assert added[0] == Path("new_file.txt")
|
||||
assert modified[0] == Path("modified_file.txt")
|
||||
assert removed[0] == Path("removed_file.txt")
|
||||
|
||||
|
||||
def test_get_updated_files_with_subfolders():
|
||||
diff = """
|
||||
diff --git a/folder1/new_file.txt b/folder1/new_file.txt
|
||||
new file mode 100644
|
||||
index 0000000..1234567
|
||||
--- /dev/null
|
||||
+++ b/folder1/new_file.txt
|
||||
@@ -0,0 +1 @@
|
||||
+New content
|
||||
diff --git a/folder2/subfolder/modified_file.txt b/folder2/subfolder/modified_file.txt
|
||||
index 1234567..89abcdef
|
||||
--- a/folder2/subfolder/modified_file.txt
|
||||
+++ b/folder2/subfolder/modified_file.txt
|
||||
@@ -1 +1 @@
|
||||
-Old content
|
||||
+Modified content
|
||||
"""
|
||||
added, modified, removed = get_updated_files(diff)
|
||||
assert len(added) == 1
|
||||
assert len(modified) == 1
|
||||
assert len(removed) == 0
|
||||
assert added[0] == Path("folder1/new_file.txt")
|
||||
assert modified[0] == Path("folder2/subfolder/modified_file.txt")
|
||||
21
prometheus/tests/utils/test_str_util.py
Normal file
21
prometheus/tests/utils/test_str_util.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from prometheus.utils.str_util import (
|
||||
pre_append_line_numbers,
|
||||
)
|
||||
|
||||
|
||||
def test_single_line():
|
||||
text = "Hello world"
|
||||
result = pre_append_line_numbers(text, start_line=1)
|
||||
assert result == "1. Hello world"
|
||||
|
||||
|
||||
def test_multiple_lines():
|
||||
text = "First line\nSecond line\nThird line"
|
||||
result = pre_append_line_numbers(text, start_line=1)
|
||||
assert result == "1. First line\n2. Second line\n3. Third line"
|
||||
|
||||
|
||||
def test_empty_string():
|
||||
text = ""
|
||||
result = pre_append_line_numbers(text, start_line=1)
|
||||
assert result == ""
|
||||
Reference in New Issue
Block a user