feat: Add intelligent auto-router and enhanced integrations
- Add intelligent-router.sh hook for automatic agent routing - Add AUTO-TRIGGER-SUMMARY.md documentation - Add FINAL-INTEGRATION-SUMMARY.md documentation - Complete Prometheus integration (6 commands + 4 tools) - Complete Dexto integration (12 commands + 5 tools) - Enhanced Ralph with access to all agents - Fix /clawd command (removed disable-model-invocation) - Update hooks.json to v5 with intelligent routing - 291 total skills now available - All 21 commands with automatic routing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
0
prometheus/tests/__init__.py
Normal file
0
prometheus/tests/__init__.py
Normal file
0
prometheus/tests/app/__init__.py
Normal file
0
prometheus/tests/app/__init__.py
Normal file
0
prometheus/tests/app/api/__init__.py
Normal file
0
prometheus/tests/app/api/__init__.py
Normal file
57
prometheus/tests/app/api/test_auth.py
Normal file
57
prometheus/tests/app/api/test_auth.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from unittest import mock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from prometheus.app.api.routes import auth
|
||||
from prometheus.app.exception_handler import register_exception_handlers
|
||||
|
||||
app = FastAPI()
|
||||
register_exception_handlers(app)
|
||||
app.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_service():
|
||||
service = mock.MagicMock()
|
||||
app.state.service = service
|
||||
yield service
|
||||
|
||||
|
||||
def test_login(mock_service):
|
||||
mock_service["user_service"].login = AsyncMock(return_value="your_access_token")
|
||||
response = client.post(
|
||||
"/auth/login",
|
||||
json={
|
||||
"username": "testuser",
|
||||
"email": "test@gmail.com",
|
||||
"password": "passwordpassword",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": {"access_token": "your_access_token"},
|
||||
}
|
||||
|
||||
|
||||
def test_register(mock_service):
|
||||
mock_service["invitation_code_service"].check_invitation_code = AsyncMock(return_value=True)
|
||||
mock_service["user_service"].create_user = AsyncMock(return_value=None)
|
||||
mock_service["invitation_code_service"].mark_code_as_used = AsyncMock(return_value=None)
|
||||
|
||||
response = client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"username": "testuser",
|
||||
"email": "test@gmail.com",
|
||||
"password": "passwordpassword",
|
||||
"invitation_code": "f23ee204-ff33-401d-8291-1f128d0db08a",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"code": 200, "message": "User registered successfully", "data": None}
|
||||
65
prometheus/tests/app/api/test_github_token.py
Normal file
65
prometheus/tests/app/api/test_github_token.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from prometheus.app.api.routes.github import router
|
||||
|
||||
# Create test app
|
||||
app = FastAPI()
|
||||
app.include_router(router, prefix="/github")
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_issue_data():
|
||||
"""Fixture for mock issue data."""
|
||||
return {
|
||||
"number": 123,
|
||||
"title": "Test Issue",
|
||||
"body": "This is a test issue body",
|
||||
"state": "open",
|
||||
"html_url": "https://github.com/owner/repo/issues/123",
|
||||
"comments": [
|
||||
{"username": "user1", "comment": "First comment"},
|
||||
{"username": "user2", "comment": "Second comment"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_get_github_issue_success(mock_issue_data):
|
||||
"""Test successful retrieval of GitHub issue through the API endpoint."""
|
||||
|
||||
with patch("prometheus.app.api.routes.github.get_github_issue") as mock_get_issue:
|
||||
# Configure the mock
|
||||
mock_get_issue.return_value = mock_issue_data
|
||||
|
||||
# Make the request
|
||||
response = client.get(
|
||||
"/github/issue/",
|
||||
params={"repo": "owner/repo", "issue_number": 123, "github_token": "test_token"},
|
||||
)
|
||||
|
||||
# Assert response status
|
||||
assert response.status_code == 200
|
||||
|
||||
# Parse response
|
||||
response_data = response.json()
|
||||
|
||||
# Assert response structure
|
||||
assert "data" in response_data
|
||||
assert "message" in response_data
|
||||
assert "code" in response_data
|
||||
|
||||
# Assert data content
|
||||
data = response_data["data"]
|
||||
assert data["number"] == 123
|
||||
assert data["title"] == "Test Issue"
|
||||
assert data["body"] == "This is a test issue body"
|
||||
assert data["state"] == "open"
|
||||
assert len(data["comments"]) == 2
|
||||
assert data["comments"][0]["username"] == "user1"
|
||||
|
||||
# Verify the function was called with correct parameters
|
||||
mock_get_issue.assert_called_once_with("owner/repo", 123, "test_token")
|
||||
94
prometheus/tests/app/api/test_invitation_code.py
Normal file
94
prometheus/tests/app/api/test_invitation_code.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import datetime
|
||||
from unittest import mock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from prometheus.app.api.routes import invitation_code
|
||||
from prometheus.app.entity.invitation_code import InvitationCode
|
||||
from prometheus.app.exception_handler import register_exception_handlers
|
||||
|
||||
app = FastAPI()
|
||||
register_exception_handlers(app)
|
||||
app.include_router(invitation_code.router, prefix="/invitation-code", tags=["invitation_code"])
|
||||
|
||||
|
||||
@app.middleware("mock_jwt_middleware")
|
||||
async def add_user_id(request: Request, call_next):
|
||||
request.state.user_id = 1 # Set user_id to 1 for testing purposes
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_service():
|
||||
service = mock.MagicMock()
|
||||
app.state.service = service
|
||||
yield service
|
||||
|
||||
|
||||
def test_create_invitation_code(mock_service):
|
||||
# Mock the return value of create_invitation_code
|
||||
mock_service["invitation_code_service"].create_invitation_code = AsyncMock(
|
||||
return_value=InvitationCode(
|
||||
id=1,
|
||||
code="testcode",
|
||||
is_used=False,
|
||||
expiration_time=datetime.datetime(
|
||||
year=2025, month=1, day=1, hour=0, minute=0, second=0
|
||||
),
|
||||
)
|
||||
)
|
||||
mock_service["user_service"].is_admin = AsyncMock(return_value=True)
|
||||
|
||||
# Test the creation endpoint
|
||||
response = client.post("invitation-code/create/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"id": 1,
|
||||
"code": "testcode",
|
||||
"is_used": False,
|
||||
"expiration_time": "2025-01-01T00:00:00",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_list(mock_service):
|
||||
# Mock user as admin and return a list of invitation codes
|
||||
mock_service["invitation_code_service"].list_invitation_codes = AsyncMock(
|
||||
return_value=[
|
||||
InvitationCode(
|
||||
id=1,
|
||||
code="testcode",
|
||||
is_used=False,
|
||||
expiration_time=datetime.datetime(
|
||||
year=2025, month=1, day=1, hour=0, minute=0, second=0
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
mock_service["user_service"].is_admin = AsyncMock(return_value=True)
|
||||
|
||||
# Test the list endpoint
|
||||
response = client.get("invitation-code/list/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": [
|
||||
{
|
||||
"id": 1,
|
||||
"code": "testcode",
|
||||
"is_used": False,
|
||||
"expiration_time": "2025-01-01T00:00:00",
|
||||
}
|
||||
],
|
||||
}
|
||||
174
prometheus/tests/app/api/test_issue.py
Normal file
174
prometheus/tests/app/api/test_issue.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from unittest import mock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from prometheus.app.api.routes import issue
|
||||
from prometheus.app.entity.repository import Repository
|
||||
from prometheus.app.exception_handler import register_exception_handlers
|
||||
from prometheus.lang_graph.graphs.issue_state import IssueType
|
||||
|
||||
app = FastAPI()
|
||||
register_exception_handlers(app)
|
||||
app.include_router(issue.router, prefix="/issue", tags=["issue"])
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_service():
|
||||
service = mock.MagicMock()
|
||||
app.state.service = service
|
||||
yield service
|
||||
|
||||
|
||||
def test_answer_issue(mock_service):
|
||||
mock_service["repository_service"].get_repository_by_id = AsyncMock(
|
||||
return_value=Repository(
|
||||
id=1,
|
||||
url="https://github.com/fake/repo.git",
|
||||
commit_id=None,
|
||||
playground_path="/path/to/playground",
|
||||
kg_root_node_id=0,
|
||||
user_id=None,
|
||||
kg_max_ast_depth=100,
|
||||
kg_chunk_size=1000,
|
||||
kg_chunk_overlap=100,
|
||||
)
|
||||
)
|
||||
mock_service["knowledge_graph_service"].get_knowledge_graph = AsyncMock(
|
||||
return_value=mock.MagicMock()
|
||||
)
|
||||
mock_service["repository_service"].update_repository_status = AsyncMock(return_value=None)
|
||||
mock_service["issue_service"].answer_issue.return_value = (
|
||||
"test patch", # patch
|
||||
True, # passed_reproducing_test
|
||||
True, # passed_regression_test
|
||||
True, # passed_existing_test
|
||||
"Issue fixed", # issue_response
|
||||
IssueType.BUG, # issue_type
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/issue/answer/",
|
||||
json={
|
||||
"repository_id": 1,
|
||||
"issue_title": "Test Issue",
|
||||
"issue_body": "Test description",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"patch": "test patch",
|
||||
"passed_reproducing_test": True,
|
||||
"passed_regression_test": True,
|
||||
"passed_existing_test": True,
|
||||
"issue_response": "Issue fixed",
|
||||
"issue_type": "bug",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_answer_issue_no_repository(mock_service):
|
||||
mock_service["repository_service"].get_repository_by_id = AsyncMock(return_value=None)
|
||||
|
||||
response = client.post(
|
||||
"/issue/answer/",
|
||||
json={
|
||||
"repository_id": 1,
|
||||
"issue_title": "Test Issue",
|
||||
"issue_body": "Test description",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_answer_issue_invalid_container_config(mock_service):
|
||||
mock_service["repository_service"].get_repository_by_id = AsyncMock(
|
||||
return_value=Repository(
|
||||
id=1,
|
||||
url="https://github.com/fake/repo.git",
|
||||
commit_id=None,
|
||||
playground_path="/path/to/playground",
|
||||
kg_root_node_id=0,
|
||||
user_id=None,
|
||||
kg_max_ast_depth=100,
|
||||
kg_chunk_size=1000,
|
||||
kg_chunk_overlap=100,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/issue/answer/",
|
||||
json={
|
||||
"repository_id": 1,
|
||||
"issue_title": "Test Issue",
|
||||
"issue_body": "Test description",
|
||||
"dockerfile_content": "FROM python:3.11",
|
||||
"workdir": None,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_answer_issue_with_container(mock_service):
|
||||
mock_service["repository_service"].get_repository_by_id = AsyncMock(
|
||||
return_value=Repository(
|
||||
id=1,
|
||||
url="https://github.com/fake/repo.git",
|
||||
commit_id=None,
|
||||
playground_path="/path/to/playground",
|
||||
kg_root_node_id=0,
|
||||
user_id=None,
|
||||
kg_max_ast_depth=100,
|
||||
kg_chunk_size=1000,
|
||||
kg_chunk_overlap=100,
|
||||
)
|
||||
)
|
||||
|
||||
mock_service["issue_service"].answer_issue.return_value = (
|
||||
"test patch",
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
"Issue fixed",
|
||||
IssueType.BUG,
|
||||
)
|
||||
mock_service["knowledge_graph_service"].get_knowledge_graph = AsyncMock(
|
||||
return_value=mock.MagicMock()
|
||||
)
|
||||
mock_service["repository_service"].update_repository_status = AsyncMock(return_value=None)
|
||||
|
||||
test_payload = {
|
||||
"repository_id": 1,
|
||||
"issue_title": "Test Issue",
|
||||
"issue_body": "Test description",
|
||||
"dockerfile_content": "FROM python:3.11",
|
||||
"run_reproduce_test": True,
|
||||
"workdir": "/app",
|
||||
"build_commands": ["pip install -r requirements.txt"],
|
||||
"test_commands": ["pytest ."],
|
||||
}
|
||||
|
||||
response = client.post("/issue/answer/", json=test_payload)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"patch": "test patch",
|
||||
"passed_reproducing_test": True,
|
||||
"passed_regression_test": True,
|
||||
"passed_existing_test": True,
|
||||
"issue_response": "Issue fixed",
|
||||
"issue_type": "bug",
|
||||
},
|
||||
}
|
||||
168
prometheus/tests/app/api/test_repository.py
Normal file
168
prometheus/tests/app/api/test_repository.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from unittest import mock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from prometheus.app.api.routes import repository
|
||||
from prometheus.app.entity.repository import Repository
|
||||
from prometheus.app.exception_handler import register_exception_handlers
|
||||
|
||||
app = FastAPI()
|
||||
register_exception_handlers(app)
|
||||
app.include_router(repository.router, prefix="/repository", tags=["repository"])
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_service():
|
||||
service = mock.MagicMock()
|
||||
app.state.service = service
|
||||
yield service
|
||||
|
||||
|
||||
def test_upload_repository(mock_service):
|
||||
mock_service["repository_service"].clone_github_repo = AsyncMock(return_value="/mock/path")
|
||||
mock_service["repository_service"].get_repository_by_url_and_commit_id = AsyncMock(
|
||||
return_value=None
|
||||
)
|
||||
mock_service["repository_service"].create_new_repository = AsyncMock(return_value=1)
|
||||
mock_service["knowledge_graph_service"].build_and_save_knowledge_graph = AsyncMock(
|
||||
return_value=0
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/repository/upload",
|
||||
json={
|
||||
"github_token": "mock_token",
|
||||
"https_url": "https://github.com/Pantheon-temple/Prometheus",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": {"repository_id": 1},
|
||||
}
|
||||
|
||||
|
||||
def test_upload_repository_at_commit(mock_service):
|
||||
mock_service["repository_service"].clone_github_repo = AsyncMock(return_value="/mock/path")
|
||||
mock_service["repository_service"].get_repository_by_url_and_commit_id = AsyncMock(
|
||||
return_value=None
|
||||
)
|
||||
mock_service["repository_service"].create_new_repository = AsyncMock(return_value=1)
|
||||
mock_service["knowledge_graph_service"].build_and_save_knowledge_graph = AsyncMock(
|
||||
return_value=0
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/repository/upload/",
|
||||
json={
|
||||
"github_token": "mock_token",
|
||||
"https_url": "https://github.com/Pantheon-temple/Prometheus",
|
||||
"commit_id": "0c554293648a8705769fa53ec896ae24da75f4fc",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_create_branch_and_push(mock_service):
|
||||
# Mock git_repo
|
||||
git_repo_mock = MagicMock()
|
||||
git_repo_mock.create_and_push_branch = AsyncMock(return_value=None)
|
||||
|
||||
# Let repository_service.get_repository return the mocked git_repo
|
||||
mock_service["repository_service"].get_repository.return_value = git_repo_mock
|
||||
mock_service["repository_service"].get_repository_by_id = AsyncMock(
|
||||
return_value=Repository(
|
||||
id=1,
|
||||
url="https://github.com/fake/repo.git",
|
||||
commit_id=None,
|
||||
playground_path="/path/to/playground",
|
||||
kg_root_node_id=0,
|
||||
user_id=None,
|
||||
kg_max_ast_depth=100,
|
||||
kg_chunk_size=1000,
|
||||
kg_chunk_overlap=100,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/repository/create-branch-and-push/",
|
||||
json={
|
||||
"repository_id": 1,
|
||||
"branch_name": "new_branch",
|
||||
"commit_message": "Initial commit on new branch",
|
||||
"patch": "mock_patch_content",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@mock.patch("prometheus.app.api.routes.repository.delete_repository_memory")
|
||||
def test_delete(mock_delete_memory, mock_service):
|
||||
# Mock the delete_repository_memory to return success
|
||||
mock_delete_memory.return_value = {"code": 200, "message": "success", "data": None}
|
||||
|
||||
mock_service["repository_service"].get_repository_by_id = AsyncMock(
|
||||
return_value=Repository(
|
||||
id=1,
|
||||
url="https://github.com/fake/repo.git",
|
||||
commit_id=None,
|
||||
playground_path="/path/to/playground",
|
||||
kg_root_node_id=0,
|
||||
user_id=None,
|
||||
kg_max_ast_depth=100,
|
||||
kg_chunk_size=1000,
|
||||
kg_chunk_overlap=100,
|
||||
)
|
||||
)
|
||||
mock_service["knowledge_graph_service"].clear_kg = AsyncMock(return_value=None)
|
||||
mock_service["repository_service"].clean_repository.return_value = None
|
||||
mock_service["repository_service"].delete_repository = AsyncMock(return_value=None)
|
||||
response = client.delete(
|
||||
"repository/delete",
|
||||
params={
|
||||
"repository_id": 1,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_list(mock_service):
|
||||
mock_service["repository_service"].get_all_repositories = AsyncMock(
|
||||
return_value=[
|
||||
Repository(
|
||||
id=1,
|
||||
url="https://github.com/fake/repo.git",
|
||||
commit_id=None,
|
||||
playground_path="/path/to/playground",
|
||||
kg_root_node_id=0,
|
||||
user_id=None,
|
||||
kg_max_ast_depth=100,
|
||||
kg_chunk_size=1000,
|
||||
kg_chunk_overlap=100,
|
||||
)
|
||||
]
|
||||
)
|
||||
response = client.get("repository/list/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": [
|
||||
{
|
||||
"id": 1,
|
||||
"url": "https://github.com/fake/repo.git",
|
||||
"commit_id": None,
|
||||
"is_working": False,
|
||||
"user_id": None,
|
||||
"kg_max_ast_depth": 100,
|
||||
"kg_chunk_size": 1000,
|
||||
"kg_chunk_overlap": 100,
|
||||
}
|
||||
],
|
||||
}
|
||||
82
prometheus/tests/app/api/test_user.py
Normal file
82
prometheus/tests/app/api/test_user.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from unittest import mock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from prometheus.app.api.routes import user
|
||||
from prometheus.app.entity.user import User
|
||||
from prometheus.app.exception_handler import register_exception_handlers
|
||||
|
||||
app = FastAPI()
|
||||
register_exception_handlers(app)
|
||||
app.include_router(user.router, prefix="/user", tags=["user"])
|
||||
|
||||
|
||||
@app.middleware("mock_jwt_middleware")
|
||||
async def add_user_id(request: Request, call_next):
|
||||
request.state.user_id = 1 # Set user_id to 1 for testing purposes
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_service():
|
||||
service = mock.MagicMock()
|
||||
app.state.service = service
|
||||
yield service
|
||||
|
||||
|
||||
def test_list(mock_service):
|
||||
# Mock user as admin and return a list of users
|
||||
mock_service["user_service"].list_users = AsyncMock(
|
||||
return_value=[
|
||||
User(
|
||||
id=1,
|
||||
username="testuser",
|
||||
email="test@gmail.com",
|
||||
password_hash="hashedpassword",
|
||||
github_token="ghp_1234567890abcdef1234567890abcdef1234",
|
||||
issue_credit=10,
|
||||
is_superuser=False,
|
||||
)
|
||||
]
|
||||
)
|
||||
mock_service["user_service"].is_admin = AsyncMock(return_value=True)
|
||||
|
||||
# Test the list endpoint
|
||||
response = client.get("user/list/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": [
|
||||
{
|
||||
"id": 1,
|
||||
"username": "testuser",
|
||||
"email": "test@gmail.com",
|
||||
"issue_credit": 10,
|
||||
"is_superuser": False,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_set_github_token(mock_service):
|
||||
# Mock user as admin and return a list of users
|
||||
mock_service["user_service"].set_github_token = AsyncMock(return_value=None)
|
||||
|
||||
# Test the list endpoint
|
||||
response = client.put(
|
||||
"user/set-github-token/", json={"github_token": "ghp_1234567890abcdef1234567890abcdef1234"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": None,
|
||||
}
|
||||
0
prometheus/tests/app/middlewares/__init__.py
Normal file
0
prometheus/tests/app/middlewares/__init__.py
Normal file
138
prometheus/tests/app/middlewares/test_jwt_middleware.py
Normal file
138
prometheus/tests/app/middlewares/test_jwt_middleware.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from prometheus.app.middlewares.jwt_middleware import JWTMiddleware
|
||||
from prometheus.exceptions.jwt_exception import JWTException
|
||||
from prometheus.utils.jwt_utils import JWTUtils
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""
|
||||
Create a FastAPI app with JWTMiddleware installed.
|
||||
We mark certain (method, path) pairs as login-required.
|
||||
"""
|
||||
app = FastAPI()
|
||||
|
||||
# Only these routes require login:
|
||||
login_required_routes = {
|
||||
("GET", "/protected"),
|
||||
("OPTIONS", "/protected"), # include options here if you want middleware to check it
|
||||
("GET", "/me"),
|
||||
}
|
||||
|
||||
app.add_middleware(JWTMiddleware, login_required_routes=login_required_routes)
|
||||
|
||||
@app.get("/public")
|
||||
def public():
|
||||
return {"ok": True, "route": "public"}
|
||||
|
||||
@app.get("/protected")
|
||||
def protected(request: Request):
|
||||
# Return back the user_id the middleware stores on request.state
|
||||
return {
|
||||
"ok": True,
|
||||
"route": "protected",
|
||||
"user_id": getattr(request.state, "user_id", None),
|
||||
}
|
||||
|
||||
@app.get("/me")
|
||||
def me(request: Request):
|
||||
return {"user_id": getattr(request.state, "user_id", None)}
|
||||
|
||||
# Explicit OPTIONS route to ensure 200/204 so we can assert behavior
|
||||
@app.options("/protected")
|
||||
def options_protected():
|
||||
return Response(status_code=204)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_non_protected_route_bypasses_auth(client):
|
||||
"""
|
||||
Requests to routes not listed in login_required_routes must bypass JWT check.
|
||||
"""
|
||||
resp = client.get("/public")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["route"] == "public"
|
||||
|
||||
|
||||
def test_missing_authorization_returns_401_on_protected(client):
|
||||
"""
|
||||
Missing Authorization header on a protected endpoint should return 401.
|
||||
"""
|
||||
resp = client.get("/protected")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["code"] == 401
|
||||
assert "Valid JWT Token is missing" in body["message"]
|
||||
|
||||
|
||||
def test_wrong_scheme_returns_401_on_protected(client):
|
||||
"""
|
||||
Wrong Authorization scheme (not Bearer) should return 401 on protected endpoint.
|
||||
"""
|
||||
resp = client.get("/protected", headers={"Authorization": "Token abc.def.ghi"})
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["code"] == 401
|
||||
assert "Valid JWT Token is missing" in body["message"]
|
||||
|
||||
|
||||
def test_invalid_token_raises_and_returns_error(client, monkeypatch):
|
||||
"""
|
||||
If JWTUtils.decode_token raises JWTException, middleware should map it to the response.
|
||||
"""
|
||||
|
||||
def fake_decode(_self, _: str):
|
||||
raise JWTException(code=403, message="Invalid or expired token")
|
||||
|
||||
# Patch the method on the class; middleware instantiates JWTUtils() internally
|
||||
monkeypatch.setattr(JWTUtils, "decode_token", fake_decode, raising=True)
|
||||
|
||||
resp = client.get("/protected", headers={"Authorization": "Bearer bad.token"})
|
||||
assert resp.status_code == 403
|
||||
body = resp.json()
|
||||
assert body["code"] == 403
|
||||
assert body["message"] == "Invalid or expired token"
|
||||
|
||||
|
||||
def test_valid_token_sets_user_id_and_passes(client, monkeypatch):
|
||||
"""
|
||||
With a valid token, request should pass and user_id should be present on request.state.
|
||||
"""
|
||||
|
||||
def fake_decode(_self, _: str):
|
||||
# Return payload with user_id as middleware expects
|
||||
return {"user_id": 123}
|
||||
|
||||
monkeypatch.setattr(JWTUtils, "decode_token", fake_decode, raising=True)
|
||||
|
||||
resp = client.get("/protected", headers={"Authorization": "Bearer good.token"})
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["ok"] is True
|
||||
assert body["user_id"] == 123
|
||||
|
||||
|
||||
def test_options_request_passes_through(client, monkeypatch):
|
||||
"""
|
||||
OPTIONS preflight should be allowed through without requiring a valid token.
|
||||
The middleware explicitly bypasses OPTIONS before checking Authorization.
|
||||
"""
|
||||
|
||||
# Even if decode_token would fail, OPTIONS should not trigger it.
|
||||
def boom(_self, _: str):
|
||||
raise AssertionError("decode_token should not be called for OPTIONS")
|
||||
|
||||
monkeypatch.setattr(JWTUtils, "decode_token", boom, raising=True)
|
||||
|
||||
resp = client.options("/protected")
|
||||
# Our route returns 204; any 2xx is acceptable depending on your route
|
||||
assert resp.status_code in (200, 204)
|
||||
0
prometheus/tests/app/services/__init__.py
Normal file
0
prometheus/tests/app/services/__init__.py
Normal file
17
prometheus/tests/app/services/test_database_service.py
Normal file
17
prometheus/tests/app/services/test_database_service.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import pytest
|
||||
|
||||
from prometheus.app.services.database_service import DatabaseService
|
||||
from tests.test_utils.fixtures import postgres_container_fixture # noqa: F401
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
async def test_database_service(postgres_container_fixture): # noqa: F811
|
||||
url = postgres_container_fixture.get_connection_url()
|
||||
database_service = DatabaseService(url)
|
||||
assert database_service.engine is not None
|
||||
|
||||
try:
|
||||
await database_service.start()
|
||||
await database_service.close()
|
||||
except Exception as e:
|
||||
pytest.fail(f"Connection verification failed: {e}")
|
||||
122
prometheus/tests/app/services/test_invitation_code.py
Normal file
122
prometheus/tests/app/services/test_invitation_code.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from prometheus.app.entity.invitation_code import InvitationCode
|
||||
from prometheus.app.services.database_service import DatabaseService
|
||||
from prometheus.app.services.invitation_code_service import InvitationCodeService
|
||||
from tests.test_utils.fixtures import postgres_container_fixture # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_database_service(postgres_container_fixture): # noqa: F811
|
||||
"""Fixture: provide a clean DatabaseService using the Postgres test container."""
|
||||
service = DatabaseService(postgres_container_fixture.get_connection_url())
|
||||
await service.start()
|
||||
yield service
|
||||
await service.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(mock_database_service):
|
||||
"""Fixture: construct an InvitationCodeService with the database service."""
|
||||
return InvitationCodeService(database_service=mock_database_service)
|
||||
|
||||
|
||||
async def _insert_code(
|
||||
session: AsyncSession, code: str, is_used: bool = False, expires_in_seconds: int = 3600
|
||||
) -> InvitationCode:
|
||||
"""Helper: insert a single InvitationCode with given state and expiration."""
|
||||
obj = InvitationCode(
|
||||
code=code,
|
||||
is_used=is_used,
|
||||
expiration_time=datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds),
|
||||
)
|
||||
session.add(obj)
|
||||
await session.commit()
|
||||
await session.refresh(obj)
|
||||
return obj
|
||||
|
||||
|
||||
async def test_create_invitation_code(service):
|
||||
"""Test that create_invitation_code correctly generates and returns an InvitationCode."""
|
||||
invitation_code = await service.create_invitation_code()
|
||||
|
||||
# Verify the returned object is an InvitationCode instance
|
||||
assert isinstance(invitation_code, InvitationCode)
|
||||
assert isinstance(invitation_code.code, str)
|
||||
assert len(invitation_code.code) == 36 # uuid4 string length
|
||||
assert invitation_code.id is not None
|
||||
|
||||
# Verify the object is persisted in the database
|
||||
async with AsyncSession(service.engine) as session:
|
||||
db_obj = await session.get(InvitationCode, invitation_code.id)
|
||||
assert db_obj is not None
|
||||
assert db_obj.code == invitation_code.code
|
||||
|
||||
|
||||
async def test_list_invitation_codes(service):
|
||||
"""Test that list_invitation_codes returns all stored invitation codes."""
|
||||
# Insert two invitation codes first
|
||||
code1 = await service.create_invitation_code()
|
||||
code2 = await service.create_invitation_code()
|
||||
|
||||
codes = await service.list_invitation_codes()
|
||||
|
||||
# Verify length
|
||||
assert len(codes) >= 2
|
||||
# Verify both created codes are included
|
||||
all_codes = [c.code for c in codes]
|
||||
assert code1.code in all_codes
|
||||
assert code2.code in all_codes
|
||||
|
||||
|
||||
async def test_check_invitation_code_returns_false_when_not_exists(service):
|
||||
"""check_invitation_code should return False if the code does not exist."""
|
||||
ok = await service.check_invitation_code("non-existent-code")
|
||||
assert ok is False
|
||||
|
||||
|
||||
async def test_check_invitation_code_returns_false_when_used(service):
|
||||
"""check_invitation_code should return False if the code is already used."""
|
||||
async with AsyncSession(service.engine) as session:
|
||||
await _insert_code(session, "used-code", is_used=True, expires_in_seconds=3600)
|
||||
|
||||
ok = await service.check_invitation_code("used-code")
|
||||
assert ok is False
|
||||
|
||||
|
||||
async def test_check_invitation_code_returns_false_when_expired(service):
    """check_invitation_code should reject a code whose expiry is in the past."""
    async with AsyncSession(service.engine) as session:
        # A negative TTL places the expiry before "now".
        await _insert_code(session, "expired-code", is_used=False, expires_in_seconds=-60)

    assert await service.check_invitation_code("expired-code") is False
|
||||
|
||||
|
||||
async def test_check_invitation_code_returns_true_when_valid(service):
    """check_invitation_code should accept an unused, unexpired code."""
    async with AsyncSession(service.engine) as session:
        await _insert_code(session, "valid-code", is_used=False, expires_in_seconds=3600)

    assert await service.check_invitation_code("valid-code") is True
|
||||
|
||||
|
||||
async def test_mark_code_as_used_persists_state(service):
    """mark_code_as_used should flip 'is_used' to True and persist the change."""
    async with AsyncSession(service.engine) as session:
        inserted = await _insert_code(session, "to-use", is_used=False, expires_in_seconds=3600)
        inserted_id = inserted.id

    await service.mark_code_as_used("to-use")

    # Re-read through a fresh session to prove the change reached the database.
    async with AsyncSession(service.engine) as session:
        stored = await session.get(InvitationCode, inserted_id)
        assert stored is not None
        assert stored.is_used is True
|
||||
157
prometheus/tests/app/services/test_issue_service.py
Normal file
157
prometheus/tests/app/services/test_issue_service.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from unittest.mock import Mock, create_autospec
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.app.services.issue_service import IssueService
|
||||
from prometheus.app.services.llm_service import LLMService
|
||||
from prometheus.app.services.repository_service import RepositoryService
|
||||
from prometheus.git.git_repository import GitRepository
|
||||
from prometheus.graph.knowledge_graph import KnowledgeGraph
|
||||
from prometheus.lang_graph.graphs.issue_state import IssueType
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm_service():
    """Autospecced LLMService carrying stand-in model names."""
    svc = create_autospec(LLMService, instance=True)
    svc.advanced_model = "gpt-4"
    svc.base_model = "gpt-3.5-turbo"
    return svc
|
||||
|
||||
|
||||
@pytest.fixture
def mock_repository_service():
    """Autospecced RepositoryService (no behavior is configured)."""
    return create_autospec(RepositoryService, instance=True)
|
||||
|
||||
|
||||
@pytest.fixture
def issue_service(mock_llm_service, mock_repository_service):
    """IssueService wired to the mocked LLM service.

    NOTE(review): mock_repository_service is requested but never passed to
    IssueService — confirm whether the dependency can be dropped.
    """
    return IssueService(
        llm_service=mock_llm_service,
        working_directory="/tmp/working_dir/",
        logging_level="DEBUG",
    )
|
||||
|
||||
|
||||
async def test_answer_issue_with_general_container(issue_service, monkeypatch):
    """answer_issue should build a GeneralContainer when no Dockerfile is supplied."""
    # Patch the graph and container classes so no real work happens.
    fake_graph = Mock()
    fake_graph_cls = Mock(return_value=fake_graph)
    monkeypatch.setattr("prometheus.app.services.issue_service.IssueGraph", fake_graph_cls)

    fake_container = Mock()
    fake_container_cls = Mock(return_value=fake_container)
    monkeypatch.setattr(
        "prometheus.app.services.issue_service.GeneralContainer", fake_container_cls
    )

    repository = Mock(spec=GitRepository)
    repository.get_working_directory.return_value = "mock/working/directory"
    knowledge_graph = Mock(spec=KnowledgeGraph)

    # The graph reports a successfully patched bug.
    fake_graph.invoke.return_value = {
        "issue_type": IssueType.BUG,
        "edit_patch": "test_patch",
        "passed_reproducing_test": True,
        "passed_regression_test": True,
        "passed_existing_test": True,
        "issue_response": "test_response",
    }

    result = issue_service.answer_issue(
        repository=repository,
        knowledge_graph=knowledge_graph,
        repository_id=1,
        issue_title="Test Issue",
        issue_body="Test Body",
        issue_comments=[],
        issue_type=IssueType.BUG,
        run_build=True,
        run_regression_test=True,
        run_existing_test=True,
        run_reproduce_test=True,
        number_of_candidate_patch=1,
        build_commands=None,
        test_commands=None,
    )

    # A general-purpose container is created against the repo's working dir.
    fake_container_cls.assert_called_once_with(
        project_path=repository.get_working_directory(), build_commands=None, test_commands=None
    )
    fake_graph_cls.assert_called_once_with(
        advanced_model=issue_service.llm_service.advanced_model,
        base_model=issue_service.llm_service.base_model,
        kg=knowledge_graph,
        git_repo=repository,
        container=fake_container,
        repository_id=1,
        test_commands=None,
    )
    assert result == ("test_patch", True, True, True, "test_response", IssueType.BUG)
|
||||
|
||||
|
||||
async def test_answer_issue_with_user_defined_container(issue_service, monkeypatch):
    """answer_issue should build a UserDefinedContainer when Dockerfile details are given."""
    fake_graph = Mock()
    fake_graph_cls = Mock(return_value=fake_graph)
    monkeypatch.setattr("prometheus.app.services.issue_service.IssueGraph", fake_graph_cls)

    fake_container = Mock()
    fake_container_cls = Mock(return_value=fake_container)
    monkeypatch.setattr(
        "prometheus.app.services.issue_service.UserDefinedContainer", fake_container_cls
    )

    repository = Mock(spec=GitRepository)
    repository.get_working_directory.return_value = "mock/working/directory"
    knowledge_graph = Mock(spec=KnowledgeGraph)

    # The graph classifies the issue as a question: no patch, no test runs.
    fake_graph.invoke.return_value = {
        "issue_type": IssueType.QUESTION,
        "edit_patch": None,
        "passed_reproducing_test": False,
        "passed_regression_test": False,
        "passed_existing_test": False,
        "issue_response": "test_response",
    }

    result = issue_service.answer_issue(
        repository=repository,
        knowledge_graph=knowledge_graph,
        repository_id=1,
        issue_title="Test Issue",
        issue_body="Test Body",
        issue_comments=[],
        issue_type=IssueType.QUESTION,
        run_build=True,
        run_regression_test=True,
        run_existing_test=True,
        run_reproduce_test=True,
        number_of_candidate_patch=1,
        dockerfile_content="FROM python:3.8",
        image_name="test-image",
        workdir="/app",
        build_commands=["pip install -r requirements.txt"],
        test_commands=["pytest"],
    )

    # The user-supplied Dockerfile parameters must be forwarded verbatim.
    fake_container_cls.assert_called_once_with(
        project_path=repository.get_working_directory(),
        workdir="/app",
        build_commands=["pip install -r requirements.txt"],
        test_commands=["pytest"],
        dockerfile_content="FROM python:3.8",
        image_name="test-image",
    )
    assert result == (None, False, False, False, "test_response", IssueType.QUESTION)
|
||||
104
prometheus/tests/app/services/test_knowledge_graph_service.py
Normal file
104
prometheus/tests/app/services/test_knowledge_graph_service.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.app.services.knowledge_graph_service import KnowledgeGraphService
|
||||
from prometheus.app.services.neo4j_service import Neo4jService
|
||||
from prometheus.graph.knowledge_graph import KnowledgeGraph
|
||||
from prometheus.neo4j.knowledge_graph_handler import KnowledgeGraphHandler
|
||||
|
||||
|
||||
@pytest.fixture
def mock_neo4j_service():
    """MagicMock standing in for Neo4jService."""
    return MagicMock(Neo4jService)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_kg_handler():
    """MagicMock standing in for KnowledgeGraphHandler."""
    return MagicMock(KnowledgeGraphHandler)
|
||||
|
||||
|
||||
@pytest.fixture
def knowledge_graph_service(mock_neo4j_service, mock_kg_handler):
    """KnowledgeGraphService backed by a mocked driver and handler."""
    mock_neo4j_service.neo4j_driver = MagicMock()  # stand-in Neo4j driver
    mock_kg_handler.get_new_knowledge_graph_root_node_id = AsyncMock(return_value=123)
    mock_kg_handler.write_knowledge_graph = AsyncMock()

    svc = KnowledgeGraphService(
        neo4j_service=mock_neo4j_service,
        neo4j_batch_size=1000,
        max_ast_depth=5,
        chunk_size=1000,
        chunk_overlap=100,
    )
    # Swap in the mocked handler after construction.
    svc.kg_handler = mock_kg_handler
    return svc
|
||||
|
||||
|
||||
async def test_build_and_save_knowledge_graph(knowledge_graph_service, mock_kg_handler):
    """build_and_save_knowledge_graph is expected to raise for this input.

    NOTE(review): the assertions inside the pytest.raises block are dead code —
    once the call raises, nothing after it executes. They document the intended
    happy path but never run; consider reworking this test.
    """
    source_code_path = Path("/mock/path/to/source/code")

    # mock_kg is never injected into the service, so its assertions below
    # would be meaningless even if reached — TODO confirm and clean up.
    mock_kg = MagicMock(KnowledgeGraph)
    mock_kg.build_graph = AsyncMock(return_value=None)
    mock_kg_handler.get_new_knowledge_graph_root_node_id = AsyncMock(return_value=123)
    mock_kg_handler.write_knowledge_graph = AsyncMock(return_value=None)

    with pytest.raises(Exception):
        knowledge_graph_service.kg_handler = mock_kg_handler
        result = await knowledge_graph_service.build_and_save_knowledge_graph(source_code_path)

        # Unreachable when the call above raises (see NOTE in docstring).
        assert result == 123
        mock_kg.build_graph.assert_awaited_once_with(source_code_path)
        mock_kg_handler.write_knowledge_graph.assert_called_once()
        mock_kg_handler.get_new_knowledge_graph_root_node_id.assert_called_once()
|
||||
|
||||
|
||||
async def test_clear_kg(knowledge_graph_service, mock_kg_handler):
    """clear_kg should delegate to the handler with the given root node id."""
    root_node_id = 123

    await knowledge_graph_service.clear_kg(root_node_id)

    mock_kg_handler.clear_knowledge_graph.assert_called_once_with(root_node_id)
|
||||
|
||||
|
||||
async def test_get_knowledge_graph(knowledge_graph_service, mock_kg_handler):
    """get_knowledge_graph should forward its arguments and return the handler's graph."""
    root_node_id = 123
    max_ast_depth = 5
    chunk_size = 1000
    chunk_overlap = 100

    expected_kg = MagicMock(KnowledgeGraph)
    mock_kg_handler.read_knowledge_graph = AsyncMock(return_value=expected_kg)

    result = await knowledge_graph_service.get_knowledge_graph(
        root_node_id, max_ast_depth, chunk_size, chunk_overlap
    )

    # All four parameters must reach the handler unchanged.
    mock_kg_handler.read_knowledge_graph.assert_called_once_with(
        root_node_id, max_ast_depth, chunk_size, chunk_overlap
    )
    assert result == expected_kg
|
||||
125
prometheus/tests/app/services/test_llm_service.py
Normal file
125
prometheus/tests/app/services/test_llm_service.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.app.services.llm_service import CustomChatOpenAI, LLMService, get_model
|
||||
|
||||
|
||||
@pytest.fixture
def mock_custom_chat_openai():
    """Patch CustomChatOpenAI where llm_service looks it up."""
    with patch("prometheus.app.services.llm_service.CustomChatOpenAI") as patched:
        yield patched
|
||||
|
||||
|
||||
@pytest.fixture
def mock_chat_anthropic():
    """Patch ChatAnthropic where llm_service looks it up."""
    with patch("prometheus.app.services.llm_service.ChatAnthropic") as patched:
        yield patched
|
||||
|
||||
|
||||
@pytest.fixture
def mock_chat_google():
    """Patch ChatGoogleGenerativeAI where llm_service looks it up."""
    with patch("prometheus.app.services.llm_service.ChatGoogleGenerativeAI") as patched:
        yield patched
|
||||
|
||||
|
||||
def test_llm_service_init(mock_custom_chat_openai, mock_chat_anthropic):
    """LLMService should wire an OpenAI-format advanced model and an Anthropic base model."""
    gpt_instance = Mock()
    claude_instance = Mock()
    mock_custom_chat_openai.return_value = gpt_instance
    mock_chat_anthropic.return_value = claude_instance

    service = LLMService(
        advanced_model_name="gpt-4",
        base_model_name="claude-2.1",
        advanced_model_temperature=0.0,
        base_model_temperature=0.0,
        openai_format_api_key="openai-key",
        openai_format_base_url="https://api.openai.com/v1",
        anthropic_api_key="anthropic-key",
    )

    # The constructed model objects are stored as-is.
    assert service.advanced_model == gpt_instance
    assert service.base_model == claude_instance
    # Each provider class receives its own credentials and settings.
    mock_custom_chat_openai.assert_called_once_with(
        model="gpt-4",
        api_key="openai-key",
        base_url="https://api.openai.com/v1",
        temperature=0.0,
        max_retries=3,
    )
    mock_chat_anthropic.assert_called_once_with(
        model_name="claude-2.1",
        api_key="anthropic-key",
        temperature=0.0,
        max_retries=3,
    )
|
||||
|
||||
def test_get_openai_format_model(mock_custom_chat_openai):
    """get_model should route OpenAI-format names to CustomChatOpenAI with the given base URL."""
    get_model(
        model_name="openrouter/model",
        openai_format_api_key="openrouter-key",
        openai_format_base_url="https://openrouter.ai/api/v1",
        temperature=0.0,
    )

    mock_custom_chat_openai.assert_called_once_with(
        model="openrouter/model",
        api_key="openrouter-key",
        base_url="https://openrouter.ai/api/v1",
        temperature=0.0,
        max_retries=3,
    )
|
||||
|
||||
|
||||
def test_get_model_claude(mock_chat_anthropic):
    """get_model should construct ChatAnthropic for claude-* model names."""
    get_model(
        model_name="claude-2.1",
        anthropic_api_key="anthropic-key",
        temperature=0.0,
    )

    mock_chat_anthropic.assert_called_once_with(
        model_name="claude-2.1",
        api_key="anthropic-key",
        temperature=0.0,
        max_retries=3,
    )
|
||||
|
||||
|
||||
def test_get_model_gemini(mock_chat_google):
    """get_model should construct ChatGoogleGenerativeAI for gemini-* model names."""
    get_model(
        model_name="gemini-pro",
        gemini_api_key="gemini-key",
        temperature=0.0,
    )

    mock_chat_google.assert_called_once_with(
        model="gemini-pro",
        api_key="gemini-key",
        temperature=0.0,
        max_retries=3,
    )
|
||||
|
||||
|
||||
def test_custom_chat_openai_bind_tools():
    """bind_tools should force sequential (non-parallel) tool calls."""
    model = CustomChatOpenAI(api_key="test-key", max_input_tokens=64000)
    tools = [Mock()]

    with patch("prometheus.chat_models.custom_chat_openai.ChatOpenAI.bind_tools") as bind_spy:
        model.bind_tools(tools)

        # The parent implementation is invoked with parallel calls disabled.
        bind_spy.assert_called_once_with(tools, tool_choice=None, parallel_tool_calls=False)
|
||||
19
prometheus/tests/app/services/test_neo4j_service.py
Normal file
19
prometheus/tests/app/services/test_neo4j_service.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import pytest
|
||||
|
||||
from prometheus.app.services.neo4j_service import Neo4jService
|
||||
from tests.test_utils.fixtures import neo4j_container_with_kg_fixture # noqa: F401
|
||||
|
||||
|
||||
@pytest.mark.slow
async def test_neo4j_service(neo4j_container_with_kg_fixture):  # noqa: F811
    """Neo4jService should connect and verify connectivity against a live container."""
    neo4j_container, kg = neo4j_container_with_kg_fixture
    neo4j_service = Neo4jService(
        neo4j_container.get_connection_url(), neo4j_container.username, neo4j_container.password
    )
    assert neo4j_service.neo4j_driver is not None

    neo4j_service.start()
    try:
        await neo4j_service.neo4j_driver.verify_connectivity()
    except Exception as e:
        # Surface connection problems as a test failure, not an error.
        pytest.fail(f"Connection verification failed: {e}")
    await neo4j_service.close()
|
||||
197
prometheus/tests/app/services/test_repository_service.py
Normal file
197
prometheus/tests/app/services/test_repository_service.py
Normal file
@@ -0,0 +1,197 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from prometheus.app.entity.repository import Repository
|
||||
from prometheus.app.services.database_service import DatabaseService
|
||||
from prometheus.app.services.knowledge_graph_service import KnowledgeGraphService
|
||||
from prometheus.app.services.repository_service import RepositoryService
|
||||
from prometheus.git.git_repository import GitRepository
|
||||
from tests.test_utils.fixtures import postgres_container_fixture # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture
def mock_kg_service():
    """Autospecced KnowledgeGraphService exposing only the attributes RepositoryService reads."""
    kg_service = create_autospec(KnowledgeGraphService, instance=True)
    kg_service.max_ast_depth = 3
    kg_service.chunk_size = 1000
    kg_service.chunk_overlap = 100
    return kg_service
|
||||
|
||||
|
||||
@pytest.fixture
async def mock_database_service(postgres_container_fixture):  # noqa: F811
    """Real DatabaseService backed by the throwaway Postgres container."""
    svc = DatabaseService(postgres_container_fixture.get_connection_url())
    await svc.start()
    yield svc
    await svc.close()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_git_repository():
    """Autospecced GitRepository reporting a fixed working directory."""
    repo = create_autospec(GitRepository, instance=True)
    repo.get_working_directory.return_value = Path("/test/working/dir/repositories/repo")
    return repo
|
||||
|
||||
|
||||
@pytest.fixture
def service(mock_kg_service, mock_database_service, monkeypatch):
    """RepositoryService under test; Path.mkdir is stubbed so no directories are created."""
    working_dir = "/test/working/dir"
    monkeypatch.setattr(Path, "mkdir", lambda *args, **kwargs: None)
    return RepositoryService(
        kg_service=mock_kg_service,
        database_service=mock_database_service,
        working_dir=working_dir,
    )
|
||||
|
||||
|
||||
async def test_clone_new_github_repo(service, mock_git_repository, monkeypatch):
    """clone_github_repo should clone into a fresh playground path and check out the commit."""
    test_url = "https://github.com/test/repo"
    test_commit = "abc123"
    test_github_token = "test_token"
    expected_path = Path("/test/working/dir/repositories/repo")

    # Pin the generated playground path so assertions are deterministic.
    monkeypatch.setattr(service, "get_new_playground_path", lambda: expected_path)

    # GitRepository() must hand back our mock instance.
    with patch(
        "prometheus.app.services.repository_service.GitRepository",
        return_value=mock_git_repository,
    ) as mock_git_class:
        result_path = await service.clone_github_repo(test_github_token, test_url, test_commit)

        # Constructed with no args, then driven via from_clone_repository.
        mock_git_class.assert_called_once_with()
        mock_git_repository.from_clone_repository.assert_called_once_with(
            test_url, test_github_token, expected_path
        )
        # The requested commit must be checked out after cloning.
        mock_git_repository.checkout_commit.assert_called_once_with(test_commit)
        assert result_path == expected_path
|
||||
|
||||
|
||||
def test_get_new_playground_path(service):
    """get_new_playground_path should join a uuid4 hex onto the target directory."""
    fixed_hex = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
    expected_path = service.target_directory / fixed_hex
    with patch("uuid.uuid4") as fake_uuid:
        fake_uuid.return_value.hex = fixed_hex
        result = service.get_new_playground_path()

    assert result == expected_path
|
||||
|
||||
|
||||
def test_clean_repository_removes_dir_and_parent(service, monkeypatch):
    """clean_repository should rmtree the playground path and rmdir its parent when it exists."""
    repo_path = Path("/tmp/repositories/abc123")
    repository = Repository(playground_path=str(repo_path))

    # Only the repository path itself reports as existing.
    monkeypatch.setattr(Path, "exists", lambda self: self == repo_path)

    # Record what gets removed instead of touching the filesystem.
    removed = {"rmtree": None, "rmdir": []}
    monkeypatch.setattr(
        "shutil.rmtree",
        lambda target: removed.update(rmtree=target),
    )
    monkeypatch.setattr(
        Path,
        "rmdir",
        lambda self: removed["rmdir"].append(self),
    )

    service.clean_repository(repository)

    # The tree removal targets the playground path string.
    assert removed["rmtree"] == str(repo_path)
    # The (now empty) parent directory is removed too.
    assert repo_path.parent in removed["rmdir"]
|
||||
|
||||
|
||||
def test_clean_repository_skips_when_not_exists(service, monkeypatch):
    """clean_repository must be a no-op when the playground path is already gone."""
    repo_path = Path("/tmp/repositories/abc123")
    repository = Repository(playground_path=str(repo_path))

    monkeypatch.setattr(Path, "exists", lambda self: False)

    # Either removal primitive firing means the guard failed.
    monkeypatch.setattr("shutil.rmtree", lambda target: pytest.fail("rmtree should not be called"))
    monkeypatch.setattr(Path, "rmdir", lambda self: pytest.fail("rmdir should not be called"))

    # Completing without tripping either stub is the pass condition.
    service.clean_repository(repository)
|
||||
|
||||
|
||||
def test_get_repository_returns_git_repo_instance(service):
    """get_repository should wrap the local path in a GitRepository and return it."""
    test_path = "/some/local/path"
    git_repo = create_autospec(GitRepository, instance=True)

    # GitRepository() hands back our autospecced instance.
    with patch(
        "prometheus.app.services.repository_service.GitRepository",
        return_value=git_repo,
    ) as mock_git_class:
        result = service.get_repository(test_path)

        # Constructed with no args, then loaded from the local path as a Path.
        mock_git_class.assert_called_once_with()
        git_repo.from_local_repository.assert_called_once_with(Path(test_path))
        assert result == git_repo
|
||||
|
||||
|
||||
async def test_create_new_repository(service):
    """create_new_repository should persist a Repository row with the given fields."""
    await service.create_new_repository(
        url="https://github.com/test/repo",
        commit_id="abc123",
        playground_path="/tmp/repositories/repo",
        user_id=None,
        kg_root_node_id=0,
    )

    # Read the row back through a fresh session to confirm persistence.
    async with AsyncSession(service.engine) as session:
        row = await session.get(Repository, 1)
        assert row is not None
        assert row.url == "https://github.com/test/repo"
        assert row.commit_id == "abc123"
        assert row.playground_path == "/tmp/repositories/repo"
        assert row.user_id is None
        assert row.kg_root_node_id == 0
|
||||
|
||||
|
||||
async def test_get_all_repositories(service):
    """get_all_repositories should list the rows present in the shared database.

    NOTE(review): the expected count of 1 relies on test_create_new_repository
    having run first against the same container — confirm test ordering.
    """
    repos = await service.get_all_repositories()
    assert len(repos) == 1
|
||||
50
prometheus/tests/app/services/test_user_service.py
Normal file
50
prometheus/tests/app/services/test_user_service.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from prometheus.app.entity.user import User
|
||||
from prometheus.app.services.database_service import DatabaseService
|
||||
from prometheus.app.services.user_service import UserService
|
||||
from tests.test_utils.fixtures import postgres_container_fixture # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture
async def mock_database_service(postgres_container_fixture):  # noqa: F811
    """Real DatabaseService backed by the throwaway Postgres container."""
    svc = DatabaseService(postgres_container_fixture.get_connection_url())
    await svc.start()
    yield svc
    await svc.close()
|
||||
|
||||
|
||||
async def test_create_superuser(mock_database_service):
    """create_superuser should persist a user with the supplied credentials and token."""
    service = UserService(mock_database_service)
    await service.create_superuser(
        "testuser", "test@gmail.com", "password123", github_token="gh_token"
    )

    # Read the user back through a fresh session.
    async with AsyncSession(service.engine) as session:
        user = await session.get(User, 1)
        assert user is not None
        assert user.username == "testuser"
        assert user.email == "test@gmail.com"
        assert user.github_token == "gh_token"
|
||||
|
||||
|
||||
async def test_login(mock_database_service):
    """login should return an access token for valid credentials.

    NOTE(review): appears to rely on the user created in test_create_superuser
    surviving in the shared container — confirm fixture/container scoping.
    """
    service = UserService(mock_database_service)
    access_token = await service.login("testuser", "test@gmail.com", "password123")
    assert access_token is not None
|
||||
|
||||
|
||||
async def test_set_github_token(mock_database_service):
    """set_github_token should overwrite the stored token for the given user id."""
    service = UserService(mock_database_service)
    await service.set_github_token(1, "new_gh_token")

    # Confirm the new token was persisted.
    async with AsyncSession(service.engine) as session:
        user = await session.get(User, 1)
        assert user is not None
        assert user.github_token == "new_gh_token"
|
||||
28
prometheus/tests/app/test_main.py
Normal file
28
prometheus/tests/app/test_main.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture
def mock_dependencies():
    """Replace initialize_services with a MagicMock for app startup."""
    fake_service = MagicMock()
    with patch("prometheus.app.dependencies.initialize_services", return_value=fake_service):
        yield fake_service
|
||||
|
||||
|
||||
@pytest.fixture
def test_client(mock_dependencies):
    """TestClient over the real app, imported only after dependencies are patched."""
    # Deferred import so the patch installed by mock_dependencies is in effect.
    from prometheus.app.main import app

    with TestClient(app) as client:
        yield client
|
||||
|
||||
|
||||
def test_app_initialization(test_client, mock_dependencies):
    """Startup should stash the mocked service container on app.state."""
    assert test_client.app.state.service is not None
    assert test_client.app.state.service == mock_dependencies
|
||||
0
prometheus/tests/docker/__init__.py
Normal file
0
prometheus/tests/docker/__init__.py
Normal file
309
prometheus/tests/docker/test_base_container.py
Normal file
309
prometheus/tests/docker/test_base_container.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.docker.base_container import BaseContainer
|
||||
|
||||
|
||||
class TestContainer(BaseContainer):
    """Concrete BaseContainer used only as a test double.

    Because the name starts with "Test", pytest would otherwise try to collect
    it as a test class (warning about its non-trivial constructor);
    ``__test__ = False`` opts it out of collection.
    """

    __test__ = False  # keep pytest from collecting this helper class

    def get_dockerfile_content(self) -> str:
        """Return a minimal three-step Dockerfile."""
        return "FROM python:3.9\nWORKDIR /app\nCOPY . /app/"
|
||||
|
||||
|
||||
@pytest.fixture
def temp_project_dir():
    """Throwaway project directory seeded with a single text file."""
    project_dir = Path(tempfile.mkdtemp())
    (project_dir / "test.txt").write_text("test content")

    yield project_dir

    # Remove everything once the test is done.
    shutil.rmtree(project_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_docker_client():
    """Swap BaseContainer's class-level docker client for a Mock."""
    with patch.object(BaseContainer, "client", new_callable=Mock) as fake_client:
        yield fake_client
|
||||
|
||||
|
||||
@pytest.fixture
def container(temp_project_dir, mock_docker_client):
    """TestContainer wired to the temp project and the mocked docker client."""
    instance = TestContainer(
        project_path=temp_project_dir,
        workdir="/app",
        build_commands=["pip install -r requirements.txt", "python setup.py build"],
        test_commands=["pytest tests/"],
    )
    instance.tag_name = "test_container_tag"
    return instance
|
||||
|
||||
|
||||
def test_get_dockerfile_content(container):
    """The test container's Dockerfile should contain the three expected steps."""
    dockerfile = container.get_dockerfile_content()

    assert "FROM python:3.9" in dockerfile
    assert "WORKDIR /app" in dockerfile
    assert "COPY . /app/" in dockerfile
|
||||
|
||||
|
||||
def test_build_docker_image(container, mock_docker_client):
    """build_docker_image should write the Dockerfile and drive the low-level build API."""
    # api.build yields decoded log entries like the real docker SDK.
    build_logs = [
        {"stream": "Step 1/3 : FROM python:3.9"},
        {"stream": "Step 2/3 : WORKDIR /app"},
        {"stream": "Step 3/3 : COPY . /app/"},
        {"stream": "Successfully built abc123"},
    ]
    mock_docker_client.api.build.return_value = iter(build_logs)

    container.build_docker_image()

    # The Dockerfile is materialized inside the project before building.
    assert (container.project_path / "prometheus.Dockerfile").exists()
    mock_docker_client.api.build.assert_called_once_with(
        path=str(container.project_path),
        dockerfile="prometheus.Dockerfile",
        tag=container.tag_name,
        rm=True,
        decode=True,
    )
|
||||
|
||||
|
||||
@patch("prometheus.docker.base_container.pexpect.spawn")
def test_start_container(mock_spawn, container, mock_docker_client):
    """start_container should run the image detached and attach a pexpect shell."""
    # Fake the interactive shell; expect() reports a matched prompt.
    mock_shell = Mock()
    mock_spawn.return_value = mock_shell
    mock_shell.expect.return_value = 0

    # Fake the docker client's container run.
    mock_containers = Mock()
    mock_docker_client.containers = mock_containers
    mock_container = Mock()
    mock_container.id = "test_container_id"
    mock_containers.run.return_value = mock_container

    container.start_container()

    # The image runs detached with a TTY, host networking, PYTHONPATH set,
    # and the docker socket mounted read-write.
    mock_containers.run.assert_called_once_with(
        container.tag_name,
        detach=True,
        tty=True,
        network_mode="host",
        environment={"PYTHONPATH": f"{container.workdir}:$PYTHONPATH"},
        volumes={"/var/run/docker.sock": {"bind": "/var/run/docker.sock", "mode": "rw"}},
    )
    # A persistent bash shell is attached to the started container.
    mock_spawn.assert_called_once_with(
        f"docker exec -it {mock_container.id} /bin/bash",
        encoding="utf-8",
        timeout=container.timeout,
    )
    mock_shell.expect.assert_called()
|
||||
|
||||
|
||||
def test_is_running(container):
    """is_running should be False before start and True once a container handle exists."""
    # No container handle yet.
    assert not container.is_running()

    # Any non-None handle counts as running.
    container.container = Mock()
    assert container.is_running()
|
||||
|
||||
|
||||
def test_update_files(container, temp_project_dir):
    """Test updating files in container.

    Removed files should be deleted inside the container with `rm`;
    parent directories of updated files should be created with
    `mkdir -p`; updated file contents are shipped via put_archive.
    """
    # Setup
    container.container = Mock()
    container.execute_command = Mock()

    # Create test files on the host side
    test_file1 = temp_project_dir / "dir1" / "test1.txt"
    test_file2 = temp_project_dir / "dir2" / "test2.txt"
    test_file1.parent.mkdir(parents=True)
    test_file2.parent.mkdir(parents=True)
    test_file1.write_text("test1")
    test_file2.write_text("test2")

    # Paths are container-relative
    updated_files = [Path("dir1/test1.txt"), Path("dir2/test2.txt")]
    removed_files = [Path("dir3/old.txt")]

    # Execute
    container.update_files(temp_project_dir, updated_files, removed_files)

    # Verify: removal first, then directory creation for each updated file
    container.execute_command.assert_has_calls(
        [call("rm dir3/old.txt"), call("mkdir -p dir1"), call("mkdir -p dir2")]
    )
    # Updated content is uploaded as an archive (tar) into the container.
    assert container.container.put_archive.called
|
||||
|
||||
|
||||
@patch("prometheus.docker.base_container.pexpect.spawn")
def test_execute_command(mock_spawn, container):
    """Test executing command in container using persistent shell.

    The implementation sends the command plus a marker command (two
    sendline calls), waits for the marker prompt, reads the exit code
    from the regex match, and strips the echoed command from the output.
    """
    # Setup mock shell
    mock_shell = Mock()
    mock_spawn.return_value = mock_shell

    # Setup container and shell
    container.container = Mock()
    container.container.id = "test_container_id"
    container.shell = mock_shell
    mock_shell.isalive.return_value = True

    # Mock the shell interactions
    mock_shell.match = Mock()
    mock_shell.match.group.return_value = "0" # Exit code 0
    # `before` contains the echoed command line followed by its output.
    mock_shell.before = "test command\ncommand output"

    # Execute
    result = container.execute_command("test command")

    # Verify shell interactions
    assert mock_shell.sendline.call_count == 2 # Command + marker command
    mock_shell.expect.assert_called()

    # The result should contain the cleaned output (echo stripped)
    assert "command output" in result
|
||||
|
||||
|
||||
def test_execute_command_with_mock(container):
    """Sanity check: a directly mocked execute_command passes args through."""
    cmd = "test command"
    expected = "mocked output"

    container.execute_command = Mock(return_value=expected)
    container.container = Mock()

    output = container.execute_command(cmd)

    container.execute_command.assert_called_once_with(cmd)
    assert output == expected
|
||||
|
||||
|
||||
def test_reset_repository(container):
    """reset_repository() hard-resets and cleans the repo, in that order."""
    container.execute_command = Mock(return_value="Command output")
    container.container = Mock()

    container.reset_repository()

    # Exactly two git commands, issued in this order.
    expected_sequence = [call("git reset --hard"), call("git clean -fd")]
    container.execute_command.assert_has_calls(expected_sequence, any_order=False)
    assert container.execute_command.call_count == 2
|
||||
|
||||
|
||||
@patch("prometheus.docker.base_container.pexpect.spawn")
def test_cleanup(mock_spawn, container, mock_docker_client):
    """Test cleanup of container resources.

    cleanup() must close the pexpect shell, stop and remove the
    container, remove the built image, and delete the copied project
    directory on the host.
    """
    # Setup: a running (mock) container
    mock_container = Mock()
    container.container = mock_container

    # Setup mock shell that is still alive
    mock_shell = Mock()
    mock_shell.isalive.return_value = True
    container.shell = mock_shell

    # Execute
    container.cleanup()

    # Verify shell cleanup (forced close)
    mock_shell.close.assert_called_once_with(force=True)

    # Verify container cleanup: graceful stop, forced removal, image removal
    mock_container.stop.assert_called_once_with(timeout=10)
    mock_container.remove.assert_called_once_with(force=True)
    mock_docker_client.images.remove.assert_called_once_with(container.tag_name, force=True)
    # The host-side project copy is deleted too.
    assert not container.project_path.exists()
|
||||
|
||||
|
||||
def test_run_build(container):
    """run_build() executes each build command and concatenates annotated output."""
    container.execute_command = Mock(side_effect=["Output 1", "Output 2"])

    build_output = container.run_build()

    # One execute_command call per configured build command.
    assert container.execute_command.call_count == 2
    for cmd in ("pip install -r requirements.txt", "python setup.py build"):
        container.execute_command.assert_any_call(cmd)

    # Each command is echoed with a "$ " prefix, followed by its output.
    expected_output = (
        "$ pip install -r requirements.txt\nOutput 1\n$ python setup.py build\nOutput 2\n"
    )
    assert build_output == expected_output
|
||||
|
||||
|
||||
def test_run_test(container):
    """run_test() executes the configured test command and annotates output."""
    container.execute_command = Mock(return_value="Test passed")

    test_output = container.run_test()

    container.execute_command.assert_called_once_with("pytest tests/")

    # Command echoed with "$ " prefix, then its output.
    assert test_output == "$ pytest tests/\nTest passed\n"
|
||||
|
||||
|
||||
def test_run_build_no_commands(container):
    """run_build() is a no-op returning '' when no build commands are set."""
    container.build_commands = None
    assert container.run_build() == ""
|
||||
|
||||
|
||||
def test_run_test_no_commands(container):
    """run_test() is a no-op returning '' when no test commands are set."""
    container.test_commands = None
    assert container.run_test() == ""
|
||||
|
||||
|
||||
@patch("prometheus.docker.base_container.pexpect.spawn")
def test_restart_shell_if_needed(mock_spawn, container):
    """Test shell restart functionality.

    When the persistent shell has died (isalive() is False),
    _restart_shell_if_needed() must force-close the old shell and spawn
    a replacement attached to the same container.
    """
    # Setup: a dead shell
    mock_shell_dead = Mock()
    mock_shell_dead.isalive.return_value = False

    # The replacement shell that pexpect.spawn will return
    mock_shell_new = Mock()
    mock_shell_new.expect.return_value = 0
    mock_spawn.return_value = mock_shell_new

    container.container = Mock()
    container.container.id = "test_container_id"
    container.shell = mock_shell_dead

    # Execute
    container._restart_shell_if_needed()

    # Verify old shell was closed and new one started
    mock_shell_dead.close.assert_called_once_with(force=True)
    mock_spawn.assert_called_once()
    assert container.shell == mock_shell_new
|
||||
44
prometheus/tests/docker/test_general_container.py
Normal file
44
prometheus/tests/docker/test_general_container.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.docker.general_container import GeneralContainer
|
||||
|
||||
|
||||
@pytest.fixture
def temp_project_dir():
    """Yield a throwaway project directory seeded with one text file."""
    project_root = Path(tempfile.mkdtemp())
    (project_root / "test.txt").write_text("test content")

    yield project_root

    # Teardown: remove the scratch directory.
    shutil.rmtree(project_root)
|
||||
|
||||
|
||||
@pytest.fixture
def container(temp_project_dir):
    """Provide a GeneralContainer rooted at the temporary project directory."""
    return GeneralContainer(temp_project_dir)
|
||||
|
||||
|
||||
def test_initialization(container, temp_project_dir):
    """The container gets a unique tag and a private copy of the project."""
    tag = container.tag_name
    assert isinstance(tag, str)
    assert tag.startswith("prometheus_general_container_")

    # The container works on a copy, never the original project path.
    assert container.project_path != temp_project_dir
    assert (container.project_path / "test.txt").exists()
|
||||
|
||||
|
||||
def test_get_dockerfile_content(container):
    """The generated Dockerfile contains the expected base image and steps."""
    content = container.get_dockerfile_content()
    assert content

    expected_fragments = (
        "FROM ubuntu:24.04",
        "WORKDIR /app",
        "RUN apt-get update",
        "COPY . /app/",
    )
    for fragment in expected_fragments:
        assert fragment in content
|
||||
51
prometheus/tests/docker/test_user_defined_container.py
Normal file
51
prometheus/tests/docker/test_user_defined_container.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.docker.user_defined_container import UserDefinedContainer
|
||||
|
||||
|
||||
@pytest.fixture
def temp_project_dir():
    """Yield a scratch project directory containing a single seed file."""
    scratch = Path(tempfile.mkdtemp())
    seed_file = scratch / "test.txt"
    seed_file.write_text("test content")

    yield scratch

    # Teardown: delete the scratch tree.
    shutil.rmtree(scratch)
|
||||
|
||||
|
||||
@pytest.fixture
def container(temp_project_dir):
    """Provide a UserDefinedContainer built from an inline Dockerfile.

    Positional arguments, in order: project path, container workdir,
    Dockerfile content, a fourth argument passed as None (its meaning is
    not visible here -- verify against UserDefinedContainer's signature),
    build commands, and test commands.
    """
    return UserDefinedContainer(
        temp_project_dir,
        "/app",
        "FROM python:3.9\nWORKDIR /app\nCOPY . /app/",
        None,
        ["pip install -r requirements.txt", "python setup.py build"],
        ["pytest tests/"],
    )
|
||||
|
||||
|
||||
def test_initialization(container, temp_project_dir):
    """The container gets a unique tag and an isolated copy of the project."""
    tag = container.tag_name
    assert isinstance(tag, str)
    assert tag.startswith("prometheus_user_defined_container_")

    # Works on a copy of the project, never the original path.
    assert container.project_path != temp_project_dir
    assert (container.project_path / "test.txt").exists()
|
||||
|
||||
|
||||
def test_get_dockerfile_content(container):
    """The user-supplied Dockerfile content is reflected back (key lines present)."""
    content = container.get_dockerfile_content()
    assert content

    # Check for key elements in the Dockerfile
    for fragment in ("FROM python:3.9", "WORKDIR /app", "COPY . /app/"):
        assert fragment in content
|
||||
0
prometheus/tests/git/__init__.py
Normal file
0
prometheus/tests/git/__init__.py
Normal file
164
prometheus/tests/git/test_git_repository.py
Normal file
164
prometheus/tests/git/test_git_repository.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.git.git_repository import GitRepository
|
||||
from tests.test_utils import test_project_paths
|
||||
from tests.test_utils.fixtures import git_repo_fixture # noqa: F401
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform.startswith("win"),
    reason="Test fails on Windows because of cptree in git_repo_fixture",
)
@pytest.mark.git
def test_init_with_https_url(git_repo_fixture): # noqa: F811
    """Cloning over HTTPS embeds the GitHub access token into the clone URL.

    git.Repo.clone_from is patched, so no network access happens; the
    test only verifies URL rewriting and target-directory computation.
    """
    with mock.patch("git.Repo.clone_from") as mock_clone_from, mock.patch("shutil.rmtree"):
        repo = git_repo_fixture
        mock_clone_from.return_value = repo

        access_token = "access_token"
        https_url = "https://github.com/foo/bar.git"
        target_directory = test_project_paths.TEST_PROJECT_PATH

        git_repo = GitRepository()
        git_repo.from_clone_repository(
            https_url=https_url, target_directory=target_directory, github_access_token=access_token
        )

        # Token injected as URL userinfo; repo cloned into a subdirectory
        # named after the repository ("bar").
        mock_clone_from.assert_called_once_with(
            f"https://{access_token}@github.com/foo/bar.git", target_directory / "bar"
        )
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform.startswith("win"),
    reason="Test fails on Windows because of cptree in git_repo_fixture",
)
@pytest.mark.git
def test_checkout_commit(git_repo_fixture): # noqa: F811
    """checkout_commit() moves HEAD to the requested commit."""
    repository = GitRepository()
    repository.from_local_repository(git_repo_fixture.working_dir)

    target_sha = "293551b7bd9572b63018c9ed2bccea0f37726805"
    # Precondition: HEAD is somewhere else before the checkout.
    assert repository.repo.head.commit.hexsha != target_sha

    repository.checkout_commit(target_sha)

    assert repository.repo.head.commit.hexsha == target_sha
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform.startswith("win"),
    reason="Test fails on Windows because of cptree in git_repo_fixture",
)
@pytest.mark.git
def test_switch_branch(git_repo_fixture): # noqa: F811
    """switch_branch() changes the active branch to the requested one."""
    repository = GitRepository()
    repository.from_local_repository(git_repo_fixture.working_dir)

    target_branch = "dev"
    # Precondition: not already on the target branch.
    assert repository.repo.active_branch.name != target_branch

    repository.switch_branch(target_branch)

    assert repository.repo.active_branch.name == target_branch
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform.startswith("win"),
    reason="Test fails on Windows because of cptree in git_repo_fixture",
)
@pytest.mark.git
def test_get_diff(git_repo_fixture): # noqa: F811
    """get_diff() reports working-tree changes and honors excluded_files."""
    local_path = Path(git_repo_fixture.working_dir).absolute()
    test_file = local_path / "test.c"

    # Initialize repository
    git_repo = GitRepository()
    git_repo.from_local_repository(local_path)

    # Create a change by modifying test.c
    original_content = test_file.read_text()
    new_content = "int main() { return 0; }\n"
    test_file.write_text(new_content)

    # Get diff without exclusions
    diff = git_repo.get_diff()
    assert diff is not None
    # Exact unified diff expected from git for the rewrite of test.c.
    expected_diff = """\
diff --git a/test.c b/test.c
index 79a1160..76e8197 100644
--- a/test.c
+++ b/test.c
@@ -1,6 +1 @@
-#include <stdio.h>
-
-int main() {
- printf("Hello world!");
- return 0;
-}
\ No newline at end of file
+int main() { return 0; }
"""
    assert diff == expected_diff

    # Test with excluded files: the only change is excluded -> empty diff
    diff_with_exclusion = git_repo.get_diff(excluded_files=["test.c"])
    assert diff_with_exclusion == ""

    # Cleanup - restore original content
    test_file.write_text(original_content)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform.startswith("win"),
    reason="Test fails on Windows because of cptree in git_repo_fixture",
)
@pytest.mark.git
def test_apply_patch(git_repo_fixture): # noqa: F811
    """apply_patch() applies a unified diff to the working tree."""
    local_path = Path(git_repo_fixture.working_dir).absolute()

    # Initialize repository
    git_repo = GitRepository()
    git_repo.from_local_repository(local_path)

    # Apply a patch that modifies test.c (note: the patch intentionally
    # leaves the file without a trailing newline).
    patch = """\
diff --git a/test.c b/test.c
--- a/test.c
+++ b/test.c
@@ -1,6 +1,1 @@
-#include <stdio.h>
-
-int main() {
- printf("Hello world!");
- return 0;
-}
\ No newline at end of file
+int main() { return 0; }
\ No newline at end of file

"""
    git_repo.apply_patch(patch)
    # Verify the change: file rewritten, no trailing newline.
    test_file = local_path / "test.c"
    assert test_file.exists()
    assert test_file.read_text() == "int main() { return 0; }"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform.startswith("win"),
    reason="Test fails on Windows because of cptree in git_repo_fixture",
)
@pytest.mark.git
def test_remove_repository(git_repo_fixture): # noqa: F811
    """remove_repository() deletes the working tree and clears the handle."""
    with mock.patch("shutil.rmtree") as rmtree_spy:
        workdir = git_repo_fixture.working_dir

        repository = GitRepository()
        repository.from_local_repository(workdir)
        assert repository.repo is not None

        repository.remove_repository()

        # Working tree removed from disk, repo handle dropped.
        rmtree_spy.assert_called_once_with(workdir)
        assert repository.repo is None
|
||||
0
prometheus/tests/graph/__init__.py
Normal file
0
prometheus/tests/graph/__init__.py
Normal file
105
prometheus/tests/graph/test_file_graph_builder.py
Normal file
105
prometheus/tests/graph/test_file_graph_builder.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from prometheus.graph.file_graph_builder import FileGraphBuilder
|
||||
from prometheus.graph.graph_types import (
|
||||
ASTNode,
|
||||
KnowledgeGraphEdgeType,
|
||||
KnowledgeGraphNode,
|
||||
TextNode,
|
||||
)
|
||||
from tests.test_utils import test_project_paths
|
||||
|
||||
|
||||
def test_supports_file():
    """FileGraphBuilder recognizes parseable source/text files, rejects others."""
    builder = FileGraphBuilder(0, 0, 0)

    supported_paths = (
        test_project_paths.C_FILE,
        test_project_paths.JAVA_FILE,
        test_project_paths.MD_FILE,
        test_project_paths.PYTHON_FILE,
    )
    for path in supported_paths:
        assert builder.supports_file(path)

    assert builder.supports_file(test_project_paths.DUMMY_FILE) is False
|
||||
|
||||
|
||||
def test_build_python_file_graph():
    """Building a graph for a Python file yields the expected AST nodes/edges.

    Checks total counts, spot-checks two AST nodes, and verifies the
    parent_of edge between them.
    """
    file_graph_builder = FileGraphBuilder(1000, 1000, 100)

    parent_kg_node = KnowledgeGraphNode(0, None)
    next_node_id, kg_nodes, kg_edges = file_graph_builder.build_file_graph(
        parent_kg_node, test_project_paths.PYTHON_FILE, 0
    )

    assert next_node_id == 11
    assert len(kg_nodes) == 11
    assert len(kg_edges) == 11

    # Nodes that must be present in the built graph.
    argument_list_ast_node = ASTNode(
        type="argument_list", start_line=1, end_line=1, text='("Hello world!")'
    )
    string_ast_node = ASTNode(type="string", start_line=1, end_line=1, text='"Hello world!"')

    # Idiom fix: any() replaces the original manual found-flag loops.
    assert any(kg_node.node == argument_list_ast_node for kg_node in kg_nodes)
    assert any(kg_node.node == string_ast_node for kg_node in kg_nodes)

    # The argument list must be the AST parent of the string literal.
    assert any(
        kg_edge.source.node == argument_list_ast_node
        and kg_edge.target.node == string_ast_node
        and kg_edge.type == KnowledgeGraphEdgeType.parent_of
        for kg_edge in kg_edges
    )
|
||||
|
||||
|
||||
def test_build_text_file_graph():
    """Building a graph for a Markdown file chunks the text and links chunks.

    Checks chunk counts, spot-checks the two expected text chunks, and
    verifies the next_chunk edge between consecutive chunks.
    """
    file_graph_builder = FileGraphBuilder(1000, 100, 10)

    parent_kg_node = KnowledgeGraphNode(0, None)
    next_node_id, kg_nodes, kg_edges = file_graph_builder.build_file_graph(
        parent_kg_node, test_project_paths.MD_FILE, 0
    )

    assert next_node_id == 2
    assert len(kg_nodes) == 2
    assert len(kg_edges) == 3

    # The two expected text chunks.
    text_node_1 = TextNode(
        text="# A\n\nText under header A.\n\n## B\n\nText under header B.\n\n## C\n\nText under header C.\n\n### D",
        start_line=1,
        end_line=13,
    )
    text_node_2 = TextNode(text="### D\n\nText under header D.", start_line=13, end_line=15)

    # Idiom fix: any() replaces the original manual found-flag loops.
    assert any(kg_node.node == text_node_1 for kg_node in kg_nodes)
    assert any(kg_node.node == text_node_2 for kg_node in kg_nodes)

    # Consecutive chunks are linked with a next_chunk edge.
    assert any(
        kg_edge.source.node == text_node_1
        and kg_edge.target.node == text_node_2
        and kg_edge.type == KnowledgeGraphEdgeType.next_chunk
        for kg_edge in kg_edges
    )
|
||||
272
prometheus/tests/graph/test_graph_types.py
Normal file
272
prometheus/tests/graph/test_graph_types.py
Normal file
@@ -0,0 +1,272 @@
|
||||
from prometheus.graph.graph_types import (
|
||||
ASTNode,
|
||||
FileNode,
|
||||
KnowledgeGraphEdge,
|
||||
KnowledgeGraphEdgeType,
|
||||
KnowledgeGraphNode,
|
||||
Neo4jASTNode,
|
||||
Neo4jFileNode,
|
||||
Neo4jTextNode,
|
||||
TextNode,
|
||||
)
|
||||
|
||||
|
||||
def test_to_neo4j_file_node():
    """A KnowledgeGraphNode wrapping a FileNode serializes all file fields."""
    basename = "foo"
    relative_path = "foo/bar/baz.py"
    node_id = 1

    kg_node = KnowledgeGraphNode(node_id, FileNode(basename, relative_path))
    serialized = kg_node.to_neo4j_node()

    assert isinstance(serialized, dict)
    for key in ("node_id", "basename", "relative_path"):
        assert key in serialized

    assert serialized["node_id"] == node_id
    assert serialized["basename"] == basename
    assert serialized["relative_path"] == relative_path
|
||||
|
||||
|
||||
def test_to_neo4j_ast_node():
    """A KnowledgeGraphNode wrapping an ASTNode serializes all AST fields."""
    # Renamed from `type` to avoid shadowing the builtin.
    node_type = "method_declaration"
    start_line = 1
    end_line = 5
    text = "print('Hello world')"
    node_id = 1

    ast_node = ASTNode(node_type, start_line, end_line, text)
    knowledge_graph_node = KnowledgeGraphNode(node_id, ast_node)
    neo4j_ast_node = knowledge_graph_node.to_neo4j_node()

    assert isinstance(neo4j_ast_node, dict)

    for key in ("node_id", "type", "start_line", "end_line", "text"):
        assert key in neo4j_ast_node

    assert neo4j_ast_node["node_id"] == node_id
    assert neo4j_ast_node["type"] == node_type
    assert neo4j_ast_node["start_line"] == start_line
    assert neo4j_ast_node["end_line"] == end_line
    assert neo4j_ast_node["text"] == text
|
||||
|
||||
|
||||
def test_to_neo4j_text_node():
    """A KnowledgeGraphNode wrapping a TextNode serializes all text fields."""
    text = "Hello world"
    node_id = 1
    start_line = 1
    end_line = 1

    kg_node = KnowledgeGraphNode(node_id, TextNode(text, start_line, end_line))
    serialized = kg_node.to_neo4j_node()

    assert isinstance(serialized, dict)
    for key in ("node_id", "text", "start_line", "end_line"):
        assert key in serialized

    assert serialized["node_id"] == node_id
    assert serialized["text"] == text
    assert serialized["start_line"] == start_line
    assert serialized["end_line"] == end_line
|
||||
|
||||
|
||||
def test_to_neo4j_has_file_edge():
    """A has_file edge serializes both endpoints as their Neo4j node dicts."""
    source_node = KnowledgeGraphNode(1, FileNode("source", "foo/bar/source.py"))
    target_node = KnowledgeGraphNode(10, FileNode("target", "foo/bar/target.py"))

    edge = KnowledgeGraphEdge(source_node, target_node, KnowledgeGraphEdgeType.has_file)
    serialized = edge.to_neo4j_edge()

    assert isinstance(serialized, dict)
    assert "source" in serialized
    assert "target" in serialized
    assert serialized["source"] == source_node.to_neo4j_node()
    assert serialized["target"] == target_node.to_neo4j_node()
|
||||
|
||||
|
||||
def test_to_neo4j_has_ast_edge():
    """A has_ast edge (file -> AST node) serializes both endpoints."""
    source_node = KnowledgeGraphNode(1, FileNode("source", "foo/bar/source.py"))
    target_node = KnowledgeGraphNode(10, ASTNode("return_statement", 7, 9, "return True"))

    edge = KnowledgeGraphEdge(source_node, target_node, KnowledgeGraphEdgeType.has_ast)
    serialized = edge.to_neo4j_edge()

    assert isinstance(serialized, dict)
    assert "source" in serialized
    assert "target" in serialized
    assert serialized["source"] == source_node.to_neo4j_node()
    assert serialized["target"] == target_node.to_neo4j_node()
|
||||
|
||||
|
||||
def test_to_neo4j_parent_of_edge():
    """A parent_of edge (AST node -> child AST node) serializes both endpoints."""
    source_node = KnowledgeGraphNode(
        1, ASTNode("method_declaration", 1, 5, "print('Hello world')")
    )
    target_node = KnowledgeGraphNode(10, ASTNode("return_statement", 7, 9, "return True"))

    edge = KnowledgeGraphEdge(source_node, target_node, KnowledgeGraphEdgeType.parent_of)
    serialized = edge.to_neo4j_edge()

    assert isinstance(serialized, dict)
    assert "source" in serialized
    assert "target" in serialized
    assert serialized["source"] == source_node.to_neo4j_node()
    assert serialized["target"] == target_node.to_neo4j_node()
|
||||
|
||||
|
||||
def test_to_neo4j_has_text_edge():
    """A has_text edge (file -> text chunk) serializes both endpoints."""
    source_node = KnowledgeGraphNode(1, FileNode("source", "foo/bar/source.py"))
    target_node = KnowledgeGraphNode(10, TextNode("Hello world", 1, 1))

    edge = KnowledgeGraphEdge(source_node, target_node, KnowledgeGraphEdgeType.has_text)
    serialized = edge.to_neo4j_edge()

    assert isinstance(serialized, dict)
    assert "source" in serialized
    assert "target" in serialized
    assert serialized["source"] == source_node.to_neo4j_node()
    assert serialized["target"] == target_node.to_neo4j_node()
|
||||
|
||||
|
||||
def test_to_neo4j_next_chunk_edge():
    """A next_chunk edge between TextNodes serializes both endpoints.

    Bug fix: the original constructed a has_text edge even though this
    test is about next_chunk edges (copy-paste from the has_text test).
    """
    source_node = KnowledgeGraphNode(1, TextNode("Hello", 1, 1))
    target_node = KnowledgeGraphNode(10, TextNode("world", 1, 1))

    edge = KnowledgeGraphEdge(source_node, target_node, KnowledgeGraphEdgeType.next_chunk)
    serialized = edge.to_neo4j_edge()

    assert isinstance(serialized, dict)
    assert "source" in serialized
    assert "target" in serialized
    assert serialized["source"] == source_node.to_neo4j_node()
    assert serialized["target"] == target_node.to_neo4j_node()
|
||||
|
||||
|
||||
def test_from_neo4j_file_node():
    """Round-trip: Neo4jFileNode -> KnowledgeGraphNode preserves all fields."""
    node_id = 10
    basename = "foo.py"
    relative_path = "bar/baz/foo.py"

    neo4j_node = Neo4jFileNode(node_id=node_id, basename=basename, relative_path=relative_path)
    result = KnowledgeGraphNode.from_neo4j_file_node(neo4j_node)

    expected = KnowledgeGraphNode(node_id, FileNode(basename, relative_path))
    assert result == expected
|
||||
|
||||
|
||||
def test_from_neo4j_ast_node():
    """Round-trip: Neo4jASTNode -> KnowledgeGraphNode preserves all fields."""
    node_id = 15
    node_type = "string_literal"  # renamed from `type` to avoid shadowing the builtin
    start_line = 5
    end_line = 6
    text = '"hello world"'
    neo4j_ast_node = Neo4jASTNode(
        node_id=node_id, type=node_type, start_line=start_line, end_line=end_line, text=text
    )
    knowledge_graph_node = KnowledgeGraphNode.from_neo4j_ast_node(neo4j_ast_node)

    expected_ast_node = ASTNode(node_type, start_line, end_line, text)
    expected_knowledge_graph_node = KnowledgeGraphNode(node_id, expected_ast_node)
    assert knowledge_graph_node == expected_knowledge_graph_node
|
||||
|
||||
|
||||
def test_from_neo4j_text_node():
    """Round-trip: Neo4jTextNode -> KnowledgeGraphNode preserves all fields."""
    node_id = 20
    text = "hello world"
    start_line = 1
    end_line = 1

    neo4j_node = Neo4jTextNode(
        node_id=node_id, text=text, start_line=start_line, end_line=end_line
    )
    result = KnowledgeGraphNode.from_neo4j_text_node(neo4j_node)

    expected = KnowledgeGraphNode(node_id, TextNode(text, start_line, end_line))
    assert result == expected
|
||||
92
prometheus/tests/graph/test_knowledge_graph.py
Normal file
92
prometheus/tests/graph/test_knowledge_graph.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import pytest
|
||||
|
||||
from prometheus.app.services.neo4j_service import Neo4jService
|
||||
from prometheus.graph.knowledge_graph import KnowledgeGraph
|
||||
from prometheus.neo4j.knowledge_graph_handler import KnowledgeGraphHandler
|
||||
from tests.test_utils import test_project_paths
|
||||
from tests.test_utils.fixtures import (
|
||||
NEO4J_PASSWORD,
|
||||
NEO4J_USERNAME,
|
||||
neo4j_container_with_kg_fixture, # noqa: F401
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
async def mock_neo4j_service(neo4j_container_with_kg_fixture): # noqa: F811
    """Fixture: provide a started Neo4jService backed by the Neo4j test container.

    (The previous docstring wrongly referred to a DatabaseService and a
    Postgres container.)  Yields (service, knowledge_graph); the service
    is closed on teardown.
    """
    neo4j_container, kg = neo4j_container_with_kg_fixture
    service = Neo4jService(neo4j_container.get_connection_url(), NEO4J_USERNAME, NEO4J_PASSWORD)
    service.start()
    yield service, kg
    await service.close()
|
||||
|
||||
|
||||
async def test_build_graph():
    """Building the KG over the test project yields the expected counts."""
    knowledge_graph = KnowledgeGraph(1, 1000, 100, 0)
    await knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH)

    assert knowledge_graph._next_node_id == 15
    # 15 nodes total: 7 FileNode + 7 ASTNode + 1 TextNode per the asserts
    # below.  (A stale "7 FileNode / 84 ASTnode / 2 TextNode" note used to
    # live here and did not match the assertions.)
    assert len(knowledge_graph._knowledge_graph_nodes) == 15
    assert len(knowledge_graph._knowledge_graph_edges) == 14

    assert len(knowledge_graph.get_file_nodes()) == 7
    assert len(knowledge_graph.get_ast_nodes()) == 7
    assert len(knowledge_graph.get_text_nodes()) == 1
    assert len(knowledge_graph.get_parent_of_edges()) == 4
    assert len(knowledge_graph.get_has_file_edges()) == 6
    assert len(knowledge_graph.get_has_ast_edges()) == 3
    assert len(knowledge_graph.get_has_text_edges()) == 1
    assert len(knowledge_graph.get_next_chunk_edges()) == 0
|
||||
|
||||
|
||||
async def test_get_file_tree():
    """Full-depth file tree rendering of the sample project."""
    kg = KnowledgeGraph(1000, 1000, 100, 0)
    await kg.build_graph(test_project_paths.TEST_PROJECT_PATH)

    rendered = kg.get_file_tree()

    # NOTE(review): inner-line spacing of this literal was mangled by the export
    # that produced this copy — confirm against the renderer's actual output.
    expected = """\
test_project
├── bar
| ├── test.java
| └── test.py
├── foo
| └── test.md
└── test.c"""
    assert rendered == expected
|
||||
|
||||
|
||||
async def test_get_file_tree_depth_one():
    """At max_depth=1 only top-level entries are rendered (directories collapsed)."""
    kg = KnowledgeGraph(1000, 1000, 100, 0)
    await kg.build_graph(test_project_paths.TEST_PROJECT_PATH)

    rendered = kg.get_file_tree(max_depth=1)

    expected = """\
test_project
├── bar
├── foo
└── test.c"""
    assert rendered == expected
|
||||
|
||||
|
||||
async def test_get_file_tree_depth_two_max_seven_lines():
    """Depth-2 tree truncated by max_lines.

    NOTE(review): the test name says "seven lines" but the call passes
    max_lines=6 — confirm which is intended.
    """
    kg = KnowledgeGraph(1000, 1000, 100, 0)
    await kg.build_graph(test_project_paths.TEST_PROJECT_PATH)

    rendered = kg.get_file_tree(max_depth=2, max_lines=6)

    expected = """\
test_project
├── bar
| ├── test.java
| └── test.py
├── foo
| └── test.md"""
    assert rendered == expected
|
||||
|
||||
|
||||
@pytest.mark.slow
async def test_from_neo4j(mock_neo4j_service):
    """Round-trip: the KG stored in Neo4j reads back equal to the original."""
    service, original_kg = mock_neo4j_service
    handler = KnowledgeGraphHandler(service.neo4j_driver, 100)

    restored_kg = await handler.read_knowledge_graph(0, 1000, 100, 10)

    assert restored_kg == original_kg
|
||||
0
prometheus/tests/lang_graph/__init__.py
Normal file
0
prometheus/tests/lang_graph/__init__.py
Normal file
0
prometheus/tests/lang_graph/graphs/__init__.py
Normal file
0
prometheus/tests/lang_graph/graphs/__init__.py
Normal file
60
prometheus/tests/lang_graph/graphs/test_issue_graph.py
Normal file
60
prometheus/tests/lang_graph/graphs/test_issue_graph.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
|
||||
from prometheus.docker.base_container import BaseContainer
|
||||
from prometheus.git.git_repository import GitRepository
|
||||
from prometheus.graph.knowledge_graph import KnowledgeGraph
|
||||
from prometheus.lang_graph.graphs.issue_graph import IssueGraph
|
||||
|
||||
|
||||
@pytest.fixture
def mock_advanced_model():
    """Chat-model stand-in for the 'advanced' model role."""
    model = Mock(spec=BaseChatModel)
    return model
|
||||
|
||||
|
||||
@pytest.fixture
def mock_base_model():
    """Chat-model stand-in for the 'base' model role."""
    model = Mock(spec=BaseChatModel)
    return model
|
||||
|
||||
|
||||
@pytest.fixture
def mock_kg():
    """Knowledge-graph mock exposing the attributes IssueGraph reads."""
    kg = Mock(spec=KnowledgeGraph)
    kg.get_all_ast_node_types.return_value = [
        "FunctionDef",
        "ClassDef",
        "Module",
        "Import",
        "Call",
    ]
    kg.root_node_id = 0
    return kg
|
||||
|
||||
|
||||
@pytest.fixture
def mock_git_repo():
    """Git-repository mock with a fake playground path."""
    repo = Mock(spec=GitRepository)
    repo.playground_path = "mock/playground/path"
    return repo
|
||||
|
||||
|
||||
@pytest.fixture
def mock_container():
    """Container stand-in; no methods are exercised by these tests."""
    container = Mock(spec=BaseContainer)
    return container
|
||||
|
||||
|
||||
def test_issue_graph_basic_initialization(
    mock_advanced_model,
    mock_base_model,
    mock_kg,
    mock_git_repo,
    mock_container,
):
    """IssueGraph builds a compiled graph and keeps a handle to the git repo."""
    issue_graph = IssueGraph(
        advanced_model=mock_advanced_model,
        base_model=mock_base_model,
        kg=mock_kg,
        git_repo=mock_git_repo,
        container=mock_container,
        repository_id=1,
    )

    assert issue_graph.graph is not None
    assert issue_graph.git_repo == mock_git_repo
|
||||
0
prometheus/tests/lang_graph/nodes/__init__.py
Normal file
0
prometheus/tests/lang_graph/nodes/__init__.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from prometheus.lang_graph.nodes.add_context_refined_query_message_node import (
|
||||
AddContextRefinedQueryMessageNode,
|
||||
)
|
||||
from prometheus.models.query import Query
|
||||
|
||||
|
||||
def test_add_context_refined_query_message_node_with_all_fields():
    """All three query fields populated -> all three rendered into the message."""
    node = AddContextRefinedQueryMessageNode()
    query = Query(
        essential_query="Find all authentication logic",
        extra_requirements="Include error handling and validation",
        purpose="Security audit",
    )

    result = node({"refined_query": query})

    assert "context_provider_messages" in result
    messages = result["context_provider_messages"]
    assert len(messages) == 1
    assert isinstance(messages[0], HumanMessage)

    content = messages[0].content
    assert "Essential query: Find all authentication logic" in content
    assert "Extra requirements: Include error handling and validation" in content
    assert "Purpose: Security audit" in content
|
||||
|
||||
|
||||
def test_add_context_refined_query_message_node_essential_query_only():
    """Empty optional fields are omitted from the rendered message entirely."""
    node = AddContextRefinedQueryMessageNode()
    query = Query(
        essential_query="Locate the main entry point",
        extra_requirements="",
        purpose="",
    )

    result = node({"refined_query": query})

    assert "context_provider_messages" in result
    messages = result["context_provider_messages"]
    assert len(messages) == 1

    content = messages[0].content
    assert "Essential query: Locate the main entry point" in content
    # Empty fields must not leave behind their section headers.
    assert "Extra requirements:" not in content
    assert "Purpose:" not in content
|
||||
|
||||
|
||||
def test_add_context_refined_query_message_node_with_extra_requirements_only():
    """Essential query + extra requirements render; empty purpose is omitted."""
    node = AddContextRefinedQueryMessageNode()
    query = Query(
        essential_query="Find database queries",
        extra_requirements="Focus on SQL injection vulnerabilities",
        purpose="",
    )

    result = node({"refined_query": query})

    assert "context_provider_messages" in result
    content = result["context_provider_messages"][0].content

    assert "Essential query: Find database queries" in content
    assert "Extra requirements: Focus on SQL injection vulnerabilities" in content
    assert "Purpose:" not in content
|
||||
|
||||
|
||||
def test_add_context_refined_query_message_node_with_purpose_only():
    """Essential query + purpose render; empty extra requirements is omitted."""
    node = AddContextRefinedQueryMessageNode()
    query = Query(
        essential_query="Identify all API endpoints",
        extra_requirements="",
        purpose="Documentation generation",
    )

    result = node({"refined_query": query})

    assert "context_provider_messages" in result
    content = result["context_provider_messages"][0].content

    assert "Essential query: Identify all API endpoints" in content
    assert "Extra requirements:" not in content
    assert "Purpose: Documentation generation" in content
|
||||
|
||||
|
||||
def test_add_context_refined_query_message_node_returns_list():
    """The node always returns a list containing exactly one HumanMessage."""
    node = AddContextRefinedQueryMessageNode()
    query = Query(
        essential_query="Test query",
        extra_requirements="Test requirements",
        purpose="Test purpose",
    )

    result = node({"refined_query": query})

    messages = result["context_provider_messages"]
    assert isinstance(messages, list)
    assert len(messages) == 1
    assert isinstance(messages[0], HumanMessage)
|
||||
|
||||
|
||||
def test_add_context_refined_query_message_node_message_format():
    """Pin the exact newline-separated layout of the constructed message."""
    node = AddContextRefinedQueryMessageNode()
    query = Query(
        essential_query="Query text",
        extra_requirements="Requirements text",
        purpose="Purpose text",
    )

    result = node({"refined_query": query})

    expected_content = (
        "Essential query: Query text\nExtra requirements: Requirements text\nPurpose: Purpose text"
    )
    assert result["context_provider_messages"][0].content == expected_content
|
||||
@@ -0,0 +1,52 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from prometheus.docker.base_container import BaseContainer
|
||||
from prometheus.lang_graph.nodes.bug_fix_verify_node import BugFixVerifyNode
|
||||
from prometheus.lang_graph.subgraphs.bug_fix_verification_state import BugFixVerificationState
|
||||
from tests.test_utils.util import FakeListChatWithToolsModel
|
||||
|
||||
|
||||
@pytest.fixture
def mock_container():
    """Container stand-in for BugFixVerifyNode construction."""
    container = Mock(spec=BaseContainer)
    return container
|
||||
|
||||
|
||||
@pytest.fixture
def test_state():
    """Verification state with a reproduced-bug file, run commands and history."""
    state = {
        "reproduced_bug_file": "test_bug.py",
        "reproduced_bug_commands": ["python test_bug.py", "./run_test.sh"],
        "bug_fix_verify_messages": [AIMessage(content="Previous verification result")],
    }
    return BugFixVerificationState(state)
|
||||
|
||||
|
||||
@pytest.fixture
def fake_llm():
    """Scripted LLM that always answers with a fixed completion."""
    return FakeListChatWithToolsModel(responses=["Test execution completed"])
|
||||
|
||||
|
||||
def test_format_human_message(mock_container, fake_llm, test_state):
    """The prompt names the reproduced-bug file and every run command."""
    node = BugFixVerifyNode(fake_llm, mock_container)

    message = node.format_human_message(test_state)

    assert isinstance(message, HumanMessage)
    for fragment in ("test_bug.py", "python test_bug.py", "./run_test.sh"):
        assert fragment in message.content
|
||||
|
||||
|
||||
def test_call_method(mock_container, fake_llm, test_state):
    """__call__ returns the LLM reply under bug_fix_verify_messages."""
    node = BugFixVerifyNode(fake_llm, mock_container)

    result = node(test_state)

    assert "bug_fix_verify_messages" in result
    messages = result["bug_fix_verify_messages"]
    assert len(messages) == 1
    assert messages[0].content == "Test execution completed"
|
||||
@@ -0,0 +1,74 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from prometheus.docker.base_container import BaseContainer
|
||||
from prometheus.lang_graph.nodes.bug_reproducing_execute_node import BugReproducingExecuteNode
|
||||
from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState
|
||||
from tests.test_utils.util import FakeListChatWithToolsModel
|
||||
|
||||
|
||||
@pytest.fixture
def mock_container():
    """Container stand-in for BugReproducingExecuteNode construction."""
    container = Mock(spec=BaseContainer)
    return container
|
||||
|
||||
|
||||
@pytest.fixture
def test_state():
    """Reproduction state with issue details and a ready-to-apply bug patch."""
    state = {
        "issue_title": "Test Bug",
        "issue_body": "Bug description",
        "issue_comments": ["Comment 1", "Comment 2"],
        "bug_context": "Context of the bug",
        "bug_reproducing_write_messages": [AIMessage(content="patch")],
        "bug_reproducing_file_messages": [AIMessage(content="path")],
        "bug_reproducing_execute_messages": [],
        # Unified-diff patch creating a new file; reproduced byte-for-byte.
        "bug_reproducing_patch": "--- /dev/null\n+++ b/newfile\n@@ -0,0 +1 @@\n+content",
    }
    return BugReproductionState(state)
|
||||
|
||||
|
||||
def test_format_human_message_with_test_commands(mock_container, test_state):
    """User-supplied test commands and the bug-file path appear in the prompt."""
    llm = FakeListChatWithToolsModel(responses=["test"])
    commands = ["pytest", "python -m unittest"]
    node = BugReproducingExecuteNode(llm, mock_container, commands)

    message = node.format_human_message(test_state, "/foo/bar/test.py")

    assert isinstance(message, HumanMessage)
    for fragment in (
        "Test Bug",
        "Bug description",
        "Comment 1",
        "Comment 2",
        "pytest",
        "/foo/bar/test.py",
    ):
        assert fragment in message.content
|
||||
|
||||
|
||||
def test_format_human_message_without_test_commands(mock_container, test_state):
    """With no commands configured, the command section is rendered empty."""
    llm = FakeListChatWithToolsModel(responses=["test"])
    node = BugReproducingExecuteNode(llm, mock_container)

    message = node.format_human_message(test_state, "/foo/bar/test.py")

    assert isinstance(message, HumanMessage)
    assert "Test Bug" in message.content
    assert "Bug description" in message.content
    # Header is present but followed immediately by a newline (no commands).
    assert "User provided test commands:\n" in message.content
    assert "/foo/bar/test.py" in message.content
|
||||
|
||||
|
||||
def test_call_method(mock_container, test_state):
    """__call__ appends the LLM reply under bug_reproducing_execute_messages."""
    reply = "Test execution completed"
    node = BugReproducingExecuteNode(
        FakeListChatWithToolsModel(responses=[reply]), mock_container
    )

    result = node(test_state)

    assert "bug_reproducing_execute_messages" in result
    messages = result["bug_reproducing_execute_messages"]
    assert len(messages) == 1
    assert messages[0].content == reply
|
||||
@@ -0,0 +1,68 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from prometheus.graph.knowledge_graph import KnowledgeGraph
|
||||
from prometheus.lang_graph.nodes.bug_reproducing_file_node import BugReproducingFileNode
|
||||
from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState
|
||||
from tests.test_utils.util import FakeListChatWithToolsModel
|
||||
|
||||
|
||||
@pytest.fixture
def mock_kg():
    """Knowledge-graph mock with a tiny canned file tree."""
    kg = Mock(spec=KnowledgeGraph)
    kg.get_file_tree.return_value = "test_dir/\n test_file.py"
    return kg
|
||||
|
||||
|
||||
@pytest.fixture
def fake_llm():
    """Scripted LLM whose single reply is a file-name answer."""
    return FakeListChatWithToolsModel(responses=["test_output.py"])
|
||||
|
||||
|
||||
@pytest.fixture
def basic_state():
    """Fresh reproduction state holding one written-test message and empty results."""
    return BugReproductionState(
        issue_title="mock issue title",
        issue_body="mock issue body",
        issue_comments=[],
        max_refined_query_loop=3,
        bug_reproducing_query="mock query",
        bug_reproducing_context=[],
        bug_reproducing_patch="",
        bug_reproducing_write_messages=[AIMessage(content="def test_bug():\n    assert 1 == 2")],
        bug_reproducing_file_messages=[],
        bug_reproducing_execute_messages=[],
        reproduced_bug=False,
        reproduced_bug_failure_log="",
        reproduced_bug_file="",
        reproduced_bug_commands=[],
    )
|
||||
|
||||
|
||||
def test_initialization(mock_kg, fake_llm):
    """Construction sets up a system prompt and the two file tools."""
    node = BugReproducingFileNode(fake_llm, mock_kg, "test/path")

    assert isinstance(node.system_prompt, SystemMessage)
    # Expected tools: read_file and create_file.
    assert len(node.tools) == 2
|
||||
|
||||
|
||||
def test_format_human_message(mock_kg, fake_llm, basic_state):
    """The written test's source code is embedded in the human message."""
    node = BugReproducingFileNode(fake_llm, mock_kg, "test/path")

    message = node.format_human_message(basic_state)

    assert isinstance(message, HumanMessage)
    assert "def test_bug():" in message.content
|
||||
|
||||
|
||||
def test_call_method(mock_kg, fake_llm, basic_state):
    """__call__ returns the LLM reply under bug_reproducing_file_messages."""
    node = BugReproducingFileNode(fake_llm, mock_kg, "test/path")

    result = node(basic_state)

    assert "bug_reproducing_file_messages" in result
    messages = result["bug_reproducing_file_messages"]
    assert len(messages) == 1
    assert messages[0].content == "test_output.py"
|
||||
@@ -0,0 +1,49 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from prometheus.graph.knowledge_graph import KnowledgeGraph
|
||||
from prometheus.lang_graph.nodes.bug_reproducing_write_node import BugReproducingWriteNode
|
||||
from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState
|
||||
from tests.test_utils.fixtures import temp_test_dir # noqa: F401
|
||||
from tests.test_utils.util import FakeListChatWithToolsModel
|
||||
|
||||
|
||||
@pytest.fixture
def mock_kg():
    """Bare knowledge-graph mock; no canned behavior needed here."""
    return Mock(spec=KnowledgeGraph)
|
||||
|
||||
|
||||
@pytest.fixture
def test_state():
    """Reproduction state with issue details and a prior write message."""
    return BugReproductionState(
        issue_title="Test Bug",
        issue_body="Bug description",
        issue_comments=[{"user1": "Comment 1"}, {"user2": "Comment 2"}],
        max_refined_query_loop=3,
        bug_reproducing_query="mock query",
        bug_reproducing_context=[],
        bug_reproducing_write_messages=[HumanMessage("assert x == 10")],
        bug_reproducing_file_messages=[],
        bug_reproducing_execute_messages=[],
        bug_reproducing_patch="",
        reproduced_bug=False,
        reproduced_bug_failure_log="Test failure log",
        reproduced_bug_file="test/file.py",
        reproduced_bug_commands=[],
    )
|
||||
|
||||
|
||||
def test_call_method(mock_kg, test_state, temp_test_dir):  # noqa: F811
    """__call__ returns the LLM reply under bug_reproducing_write_messages."""
    reply = "Created test file"
    node = BugReproducingWriteNode(
        FakeListChatWithToolsModel(responses=[reply]), temp_test_dir, mock_kg
    )

    result = node(test_state)

    assert "bug_reproducing_write_messages" in result
    messages = result["bug_reproducing_write_messages"]
    assert len(messages) == 1
    assert messages[0].content == reply
|
||||
@@ -0,0 +1,41 @@
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from prometheus.graph.knowledge_graph import KnowledgeGraph
|
||||
from prometheus.lang_graph.nodes.context_provider_node import ContextProviderNode
|
||||
from tests.test_utils import test_project_paths
|
||||
from tests.test_utils.util import FakeListChatWithToolsModel
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
async def knowledge_graph_fixture():
    """Build a fresh KnowledgeGraph over the sample project for each test."""
    kg = KnowledgeGraph(1000, 100, 10, 0)
    await kg.build_graph(test_project_paths.TEST_PROJECT_PATH)
    return kg
|
||||
|
||||
|
||||
@pytest.mark.slow
async def test_context_provider_node_basic_query(knowledge_graph_fixture):
    """The node replaces the message history with the model's single reply."""
    fake_response = "Fake response"
    node = ContextProviderNode(
        model=FakeListChatWithToolsModel(responses=[fake_response]),
        kg=knowledge_graph_fixture,
        local_path=test_project_paths.TEST_PROJECT_PATH,
    )

    state = {
        "original_query": "How does the error handling work?",
        "context_provider_messages": [
            AIMessage(content="This code handles file processing"),
            ToolMessage(
                content="Found implementation in utils.py", tool_call_id="test_tool_call_1"
            ),
        ],
    }

    result = node(state)

    assert "context_provider_messages" in result
    messages = result["context_provider_messages"]
    assert len(messages) == 1
    assert messages[0].content == fake_response
|
||||
127
prometheus/tests/lang_graph/nodes/test_edit_message_node.py
Normal file
127
prometheus/tests/lang_graph/nodes/test_edit_message_node.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from prometheus.lang_graph.nodes.edit_message_node import EditMessageNode
|
||||
from prometheus.models.context import Context
|
||||
|
||||
|
||||
@pytest.fixture
def edit_node():
    """EditMessageNode wired to the bug-fix context and analyzer message keys."""
    return EditMessageNode(
        context_key="bug_fix_context",
        analyzer_message_key="issue_bug_analyzer_messages",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def base_state():
    """Baseline state: issue info, one context snippet and one analyzer message."""
    context = Context(
        relative_path="foobar.py",
        content="# Context 1",
        start_line_number=1,
        end_line_number=1,
    )
    return {
        "issue_title": "Test Bug",
        "issue_body": "This is a test bug description",
        "issue_comments": [
            {"username": "user1", "comment": "Comment 1"},
            {"username": "user2", "comment": "Comment 2"},
        ],
        "bug_fix_context": [context],
        "issue_bug_analyzer_messages": ["Analysis message"],
    }
|
||||
|
||||
|
||||
def test_first_message_formatting(edit_node, base_state):
    """First invocation builds one HumanMessage from issue info, context and analysis."""
    with patch(
        "prometheus.lang_graph.nodes.edit_message_node.format_issue_info"
    ) as mock_format_issue, patch(
        "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content"
    ) as mock_last_message:
        mock_format_issue.return_value = "Formatted Issue Info"
        mock_last_message.return_value = "Last Analysis Message"

        result = edit_node(base_state)

    assert isinstance(result, dict)
    assert "edit_messages" in result
    edit_messages = result["edit_messages"]
    assert len(edit_messages) == 1
    assert isinstance(edit_messages[0], HumanMessage)

    content = edit_messages[0].content
    assert "Formatted Issue Info" in content
    assert "# Context 1" in content
    assert "Last Analysis Message" in content
|
||||
|
||||
|
||||
def test_followup_message_with_build_fail(edit_node, base_state):
    """A build-failure log switches the node into follow-up phrasing."""
    base_state["build_fail_log"] = "Build failed: error in compilation"

    with patch(
        "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content"
    ) as mock_last_message:
        mock_last_message.return_value = "Last Analysis Message"
        result = edit_node(base_state)

    content = result["edit_messages"][0].content
    assert "Build failed: error in compilation" in content
    assert "Please implement these revised changes carefully" in content
|
||||
|
||||
|
||||
def test_followup_message_with_test_fail(edit_node, base_state):
    """A reproducing-test failure log switches the node into follow-up phrasing."""
    base_state["reproducing_test_fail_log"] = "Test failed: assertion error"

    with patch(
        "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content"
    ) as mock_last_message:
        mock_last_message.return_value = "Last Analysis Message"
        result = edit_node(base_state)

    content = result["edit_messages"][0].content
    assert "Test failed: assertion error" in content
    assert "Please implement these revised changes carefully" in content
|
||||
|
||||
|
||||
def test_followup_message_with_existing_test_fail(edit_node, base_state):
    """An existing-test failure log switches the node into follow-up phrasing."""
    base_state["existing_test_fail_log"] = "Existing test failed"

    with patch(
        "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content"
    ) as mock_last_message:
        mock_last_message.return_value = "Last Analysis Message"
        result = edit_node(base_state)

    content = result["edit_messages"][0].content
    assert "Existing test failed" in content
    assert "Please implement these revised changes carefully" in content
|
||||
|
||||
|
||||
def test_error_priority(edit_node, base_state):
    """With all three failure logs set, only the reproducing-test one is used."""
    base_state["reproducing_test_fail_log"] = "Test failed"
    base_state["build_fail_log"] = "Build failed"
    base_state["existing_test_fail_log"] = "Existing test failed"

    with patch(
        "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content"
    ) as mock_last_message:
        mock_last_message.return_value = "Last Analysis Message"
        result = edit_node(base_state)

    content = result["edit_messages"][0].content

    # Reproducing-test failure wins; the others must not leak into the message.
    assert "Test failed" in content
    assert "Build failed" not in content
    assert "Existing test failed" not in content
|
||||
41
prometheus/tests/lang_graph/nodes/test_edit_node.py
Normal file
41
prometheus/tests/lang_graph/nodes/test_edit_node.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from prometheus.graph.knowledge_graph import KnowledgeGraph
|
||||
from prometheus.lang_graph.nodes.edit_node import EditNode
|
||||
from tests.test_utils.fixtures import temp_test_dir # noqa: F401
|
||||
from tests.test_utils.util import FakeListChatWithToolsModel
|
||||
|
||||
|
||||
@pytest.fixture
def mock_kg():
    """Bare knowledge-graph mock; EditNode does not read from it here."""
    return Mock(spec=KnowledgeGraph)
|
||||
|
||||
|
||||
@pytest.fixture
def fake_llm():
    """Scripted LLM that reports a successful file edit."""
    return FakeListChatWithToolsModel(responses=["File edit completed successfully"])
|
||||
|
||||
|
||||
def test_init_edit_node(mock_kg, fake_llm, temp_test_dir):  # noqa: F811
    """Construction sets up a system prompt, five file tools and a bound model."""
    node = EditNode(fake_llm, temp_test_dir, mock_kg)

    assert isinstance(node.system_prompt, SystemMessage)
    # Five file-operation tools are expected on the node.
    assert len(node.tools) == 5
    assert node.model_with_tools is not None
|
||||
|
||||
|
||||
def test_call_method_basic(mock_kg, fake_llm, temp_test_dir):  # noqa: F811
    """__call__ returns the LLM reply under edit_messages (no tool round-trip)."""
    node = EditNode(fake_llm, temp_test_dir, mock_kg)

    result = node({"edit_messages": [HumanMessage(content="Make the following changes: ...")]})

    assert "edit_messages" in result
    messages = result["edit_messages"]
    assert len(messages) == 1
    assert messages[0].content == "File edit completed successfully"
|
||||
64
prometheus/tests/lang_graph/nodes/test_general_build_node.py
Normal file
64
prometheus/tests/lang_graph/nodes/test_general_build_node.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from prometheus.docker.base_container import BaseContainer
|
||||
from prometheus.graph.knowledge_graph import KnowledgeGraph
|
||||
from prometheus.lang_graph.nodes.general_build_node import GeneralBuildNode
|
||||
from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState
|
||||
from tests.test_utils.util import FakeListChatWithToolsModel
|
||||
|
||||
|
||||
@pytest.fixture
def mock_container():
    """Container stand-in for GeneralBuildNode construction."""
    container = Mock(spec=BaseContainer)
    return container
|
||||
|
||||
|
||||
@pytest.fixture
def mock_kg():
    """Knowledge-graph mock with a canned Gradle-project file tree."""
    kg = Mock(spec=KnowledgeGraph)
    kg.get_file_tree.return_value = ".\n├── src\n│ └── main.py\n└── build.gradle"
    return kg
|
||||
|
||||
|
||||
@pytest.fixture
def fake_llm():
    """Scripted LLM that reports a successful build command."""
    return FakeListChatWithToolsModel(responses=["Build command executed successfully"])
|
||||
|
||||
|
||||
def test_format_human_message_basic(mock_container, mock_kg, fake_llm):
    """The prompt embeds the project structure taken from the knowledge graph."""
    node = GeneralBuildNode(fake_llm, mock_container, mock_kg)

    message = node.format_human_message(BuildAndTestState({}))

    assert isinstance(message, HumanMessage)
    assert "project structure is:" in message.content
    assert mock_kg.get_file_tree() in message.content
|
||||
|
||||
|
||||
def test_format_human_message_with_build_summary(mock_container, mock_kg, fake_llm):
    """A previous build summary, when present, is included in the prompt."""
    node = GeneralBuildNode(fake_llm, mock_container, mock_kg)
    state = BuildAndTestState({"build_command_summary": "Previous build used gradle"})

    message = node.format_human_message(state)

    assert "Previous build used gradle" in message.content
    assert "The previous build summary is:" in message.content
|
||||
|
||||
|
||||
def test_call_method_with_no_build(mock_container, mock_kg, fake_llm):
    """When exist_build is False the node short-circuits with a fixed notice."""
    node = GeneralBuildNode(fake_llm, mock_container, mock_kg)

    result = node(BuildAndTestState({"exist_build": False}))

    assert "build_messages" in result
    messages = result["build_messages"]
    assert len(messages) == 1
    assert "Previous agent determined there is no build system" in messages[0].content
|
||||
75
prometheus/tests/lang_graph/nodes/test_general_test_node.py
Normal file
75
prometheus/tests/lang_graph/nodes/test_general_test_node.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from prometheus.docker.base_container import BaseContainer
|
||||
from prometheus.graph.knowledge_graph import KnowledgeGraph
|
||||
from prometheus.lang_graph.nodes.general_test_node import GeneralTestNode
|
||||
from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState
|
||||
from tests.test_utils.util import FakeListChatWithToolsModel
|
||||
|
||||
|
||||
@pytest.fixture
def mock_container():
    """Container stand-in for GeneralTestNode construction."""
    container = Mock(spec=BaseContainer)
    return container
|
||||
|
||||
|
||||
@pytest.fixture
def mock_kg():
    """Knowledge-graph mock with a canned tests-directory file tree."""
    kg = Mock(spec=KnowledgeGraph)
    kg.get_file_tree.return_value = "./\n├── tests/\n│ └── test_main.py"
    return kg
|
||||
|
||||
|
||||
@pytest.fixture
def fake_llm():
    """Scripted LLM that reports a successful test run."""
    return FakeListChatWithToolsModel(responses=["Tests executed successfully"])
|
||||
|
||||
|
||||
@pytest.fixture
def basic_state():
    """State with tests present, no history and no prior test summary."""
    return BuildAndTestState(
        {"exist_test": True, "test_messages": [], "test_command_summary": None}
    )
|
||||
|
||||
|
||||
def test_format_human_message(mock_container, mock_kg, fake_llm, basic_state):
    """The prompt embeds the file tree taken from the knowledge graph."""
    node = GeneralTestNode(fake_llm, mock_container, mock_kg)

    message = node.format_human_message(basic_state)

    assert isinstance(message, HumanMessage)
    assert mock_kg.get_file_tree() in message.content
|
||||
|
||||
|
||||
def test_call_with_no_tests(mock_container, mock_kg, fake_llm):
    """When exist_test is False the node short-circuits with a fixed notice.

    NOTE(review): the notice lands under "build_messages" even though this is
    the test node — confirm against GeneralTestNode whether that key is intended.
    """
    node = GeneralTestNode(fake_llm, mock_container, mock_kg)

    result = node(BuildAndTestState({"exist_test": False}))

    assert "build_messages" in result
    assert "no test framework" in result["build_messages"][0].content
|
||||
|
||||
|
||||
def test_call_normal_execution(mock_container, mock_kg, fake_llm, basic_state):
|
||||
"""Test normal execution flow."""
|
||||
node = GeneralTestNode(fake_llm, mock_container, mock_kg)
|
||||
|
||||
result = node(basic_state)
|
||||
|
||||
assert "test_messages" in result
|
||||
assert len(result["test_messages"]) == 1
|
||||
assert result["test_messages"][0].content == "Tests executed successfully"
|
||||
|
||||
|
||||
def test_format_human_message_with_summary(mock_container, mock_kg, fake_llm):
|
||||
"""Test message formatting with test summary."""
|
||||
node = GeneralTestNode(fake_llm, mock_container, mock_kg)
|
||||
state = BuildAndTestState({"test_command_summary": "Previous test used pytest"})
|
||||
|
||||
message = node.format_human_message(state)
|
||||
|
||||
assert "Previous test used pytest" in message.content
|
||||
35
prometheus/tests/lang_graph/nodes/test_git_diff_node.py
Normal file
35
prometheus/tests/lang_graph/nodes/test_git_diff_node.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Tests for GitDiffNode diff retrieval and file exclusion."""
from unittest.mock import Mock

import pytest

from prometheus.git.git_repository import GitRepository
from prometheus.lang_graph.nodes.git_diff_node import GitDiffNode


@pytest.fixture
def mock_git_repo():
    """Git-repository double that returns a canned diff."""
    repo = Mock(spec=GitRepository)
    repo.get_diff.return_value = "sample diff content"
    return repo


def test_git_diff_node(mock_git_repo):
    """Without an exclusion key, the diff is fetched with no excluded files."""
    node = GitDiffNode(mock_git_repo, "patch")

    result = node({})

    assert result == {"patch": "sample diff content"}
    mock_git_repo.get_diff.assert_called_with(None)


def test_git_diff_node_with_excluded_files(mock_git_repo):
    """The excluded-file value from state is forwarded as a list to get_diff."""
    node = GitDiffNode(mock_git_repo, "patch", "excluded_file")

    result = node({"excluded_file": "/foo/bar.py"})

    assert result == {"patch": "sample diff content"}
    mock_git_repo.get_diff.assert_called_with(["/foo/bar.py"])
||||
@@ -0,0 +1,135 @@
|
||||
"""Tests for IssueBugAnalyzerNode and its web_search tool wiring."""
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.messages.tool import ToolCall

from prometheus.lang_graph.nodes.issue_bug_analyzer_node import IssueBugAnalyzerNode
from tests.test_utils.util import FakeListChatWithToolsModel


@pytest.fixture
def fake_llm():
    """LLM stub replying with a fixed analysis-complete message."""
    return FakeListChatWithToolsModel(responses=["Bug analysis completed successfully"])


@pytest.fixture
def fake_llm_with_tool_call():
    """LLM that simulates making a web_search tool call."""
    return FakeListChatWithToolsModel(
        responses=["I need to search for information about this error."]
    )


def test_init_issue_bug_analyzer_node(fake_llm):
    """Construction wires a system prompt and exactly one web_search tool."""
    node = IssueBugAnalyzerNode(fake_llm)

    assert node.system_prompt is not None
    # The node's only tool should be web search.
    assert len(node.tools) == 1
    assert node.tools[0].name == "web_search"
    assert node.model_with_tools is not None


def test_call_method_basic(fake_llm):
    """A single human message yields a single analyzer reply."""
    node = IssueBugAnalyzerNode(fake_llm)
    state = {"issue_bug_analyzer_messages": [HumanMessage(content="Please analyze this bug: ...")]}

    result = node(state)

    assert "issue_bug_analyzer_messages" in result
    assert len(result["issue_bug_analyzer_messages"]) == 1
    assert result["issue_bug_analyzer_messages"][0].content == "Bug analysis completed successfully"


def test_web_search_tool_integration(fake_llm_with_tool_call):
    """The node passes through a model response that announces a search."""
    node = IssueBugAnalyzerNode(fake_llm_with_tool_call)
    state = {
        "issue_bug_analyzer_messages": [
            HumanMessage(
                content="I'm getting a ValueError in my Python code. Can you help analyze it?"
            )
        ]
    }

    result = node(state)

    assert "issue_bug_analyzer_messages" in result
    assert len(result["issue_bug_analyzer_messages"]) == 1
    assert (
        result["issue_bug_analyzer_messages"][0].content
        == "I need to search for information about this error."
    )


def test_web_search_tool_call_with_correct_parameters(fake_llm):
    """The web_search tool is named and described as expected."""
    node = IssueBugAnalyzerNode(fake_llm)

    search_tool = node.tools[0]
    assert search_tool.name == "web_search"
    assert "technical information" in search_tool.description.lower()

    # The tool must expose an input schema for argument validation.
    assert hasattr(search_tool, "args_schema")
    assert search_tool.args_schema is not None


def test_system_prompt_contains_web_search_info(fake_llm):
    """The system prompt advertises the web_search capability."""
    node = IssueBugAnalyzerNode(fake_llm)

    prompt_text = node.system_prompt.content.lower()
    assert "web_search" in prompt_text
    assert "technical information" in prompt_text


def test_web_search_tool_schema_validation(fake_llm):
    """The web_search args schema accepts a well-formed query."""
    node = IssueBugAnalyzerNode(fake_llm)
    search_tool = node.tools[0]

    assert hasattr(search_tool, "args_schema")
    assert search_tool.args_schema is not None

    # Validation of a correct payload must succeed and preserve the query.
    parsed = search_tool.args_schema(query="Python debugging techniques")
    assert parsed.query == "Python debugging techniques"


def test_multiple_tool_calls_in_conversation(fake_llm):
    """A history containing prior tool calls still yields one new reply."""
    node = IssueBugAnalyzerNode(fake_llm)

    history = [
        HumanMessage(content="Analyze this bug: ImportError in my application"),
        AIMessage(
            content="Let me search for information about this error.",
            tool_calls=[
                ToolCall(
                    name="web_search",
                    args={"query": "Python ImportError debugging"},
                    id="call_1",
                )
            ],
        ),
        ToolMessage(
            content="Search results: ImportError occurs when...", tool_call_id="call_1"
        ),
        HumanMessage(content="The error still persists after trying the suggested fixes"),
    ]

    result = node({"issue_bug_analyzer_messages": history})

    assert "issue_bug_analyzer_messages" in result
    assert len(result["issue_bug_analyzer_messages"]) == 1
    # The new response is appended on top of the existing conversation.
    assert result["issue_bug_analyzer_messages"][0].content == "Bug analysis completed successfully"
||||
@@ -0,0 +1,113 @@
|
||||
"""Tests for IssueBugResponderNode response formatting and generation."""
import pytest

from prometheus.lang_graph.nodes.issue_bug_responder_node import IssueBugResponderNode
from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState
from tests.test_utils.util import FakeListChatWithToolsModel


@pytest.fixture
def fake_llm():
    """LLM stub replying with a fixed issue-closing response."""
    return FakeListChatWithToolsModel(
        responses=["Thank you for reporting this issue. The fix has been implemented and verified."]
    )


@pytest.fixture
def basic_state():
    """Fully populated bug state: patch produced, every check run and passed."""
    return IssueBugState(
        issue_title="Test Bug",
        issue_body="Found a bug in the code",
        issue_comments=[
            {"username": "user1", "comment": "This affects my workflow"},
            {"username": "user2", "comment": "Same issue here"},
        ],
        edit_patch="Fixed array index calculation",
        passed_reproducing_test=True,
        passed_regression_test=True,
        passed_existing_test=True,
        run_build=True,
        run_existing_test=True,
        run_regression_test=True,
        run_reproduce_test=True,
        number_of_candidate_patch=6,
        reproduced_bug=True,
        reproduced_bug_file="mock.py",
        reproduced_bug_patch="mock patch to reproduce the bug",
        reproduced_bug_commands="pytest test_bug.py",
        selected_regression_tests=["tests:tests"],
        issue_response="Mock Response",
    )


def test_format_human_message_basic(fake_llm, basic_state):
    """Issue title, body, commenters, and the patch all reach the prompt."""
    node = IssueBugResponderNode(fake_llm)

    prompt = node.format_human_message(basic_state)

    assert "Test Bug" in prompt.content
    assert "Found a bug in the code" in prompt.content
    assert "user1" in prompt.content
    assert "user2" in prompt.content
    assert "Fixed array index" in prompt.content


def test_format_human_message_verification(fake_llm, basic_state):
    """All verification check-marks appear when every check passed."""
    node = IssueBugResponderNode(fake_llm)

    prompt = node.format_human_message(basic_state)

    assert "✓ The bug reproducing test passed" in prompt.content
    assert "✓ All selected regression tests passes successfully" in prompt.content
    assert "✓ All existing tests pass successfully" in prompt.content


def test_format_human_message_no_verification(fake_llm):
    """No check-marks appear when all passed_* flags are False."""
    failed_state = IssueBugState(
        issue_title="Test Bug",
        issue_body="Bug description",
        issue_comments=[],
        edit_patch="Fixed array index calculation",
        passed_reproducing_test=False,
        passed_existing_test=False,
        passed_regression_test=False,
    )

    node = IssueBugResponderNode(fake_llm)
    prompt = node.format_human_message(failed_state)

    assert "✓ The bug reproducing test passed" not in prompt.content
    assert "✓ All selected regression tests passes successfully" not in prompt.content
    assert "✓ All existing tests pass successfully" not in prompt.content


def test_format_human_message_partial_verification(fake_llm):
    """Passed checks show their marks, but the build mark stays absent
    (the run_build flag is not set on this state)."""
    partial_state = IssueBugState(
        issue_title="Test Bug",
        issue_body="Bug description",
        issue_comments=[],
        edit_patch="Fixed array index calculation",
        passed_reproducing_test=True,
        passed_existing_test=True,
        passed_regression_test=True,
    )

    node = IssueBugResponderNode(fake_llm)
    prompt = node.format_human_message(partial_state)

    assert "✓ The bug reproducing test passed" in prompt.content
    assert "✓ Build passes successfully" not in prompt.content
    assert "✓ All existing tests pass successfully" in prompt.content


def test_call_method(fake_llm, basic_state):
    """Invoking the node stores the LLM reply under issue_response."""
    node = IssueBugResponderNode(fake_llm)

    result = node(basic_state)

    assert "issue_response" in result
    assert (
        result["issue_response"]
        == "Thank you for reporting this issue. The fix has been implemented and verified."
    )
||||
@@ -0,0 +1,98 @@
|
||||
"""Tests for IssueDocumentationAnalyzerMessageNode prompt construction."""
import pytest
from langchain_core.messages import HumanMessage

from prometheus.lang_graph.nodes.issue_documentation_analyzer_message_node import (
    IssueDocumentationAnalyzerMessageNode,
)
from prometheus.lang_graph.subgraphs.issue_documentation_state import IssueDocumentationState
from prometheus.models.context import Context


@pytest.fixture
def basic_state():
    """Documentation state with one comment and one context document."""
    return IssueDocumentationState(
        issue_title="Update API documentation",
        issue_body="The API documentation needs to be updated",
        issue_comments=[
            {"username": "user1", "comment": "Please add examples"},
        ],
        max_refined_query_loop=3,
        documentation_query="Find API docs",
        documentation_context=[
            Context(
                relative_path="/docs/api.md",
                content="# API Documentation\n\nAPI Documentation content",
            )
        ],
        issue_documentation_analyzer_messages=[],
        edit_messages=[],
        edit_patch="",
        issue_response="",
    )


def test_init_issue_documentation_analyzer_message_node():
    """The node can be constructed without arguments."""
    assert IssueDocumentationAnalyzerMessageNode() is not None


def test_call_method_creates_message(basic_state):
    """Calling the node produces exactly one HumanMessage."""
    result = IssueDocumentationAnalyzerMessageNode()(basic_state)

    assert "issue_documentation_analyzer_messages" in result
    assert len(result["issue_documentation_analyzer_messages"]) == 1
    assert isinstance(result["issue_documentation_analyzer_messages"][0], HumanMessage)


def test_message_contains_issue_info(basic_state):
    """The generated message carries the issue title."""
    result = IssueDocumentationAnalyzerMessageNode()(basic_state)

    text = result["issue_documentation_analyzer_messages"][0].content
    assert "Update API documentation" in text


def test_message_contains_context(basic_state):
    """The generated message references the documentation context."""
    result = IssueDocumentationAnalyzerMessageNode()(basic_state)

    text = result["issue_documentation_analyzer_messages"][0].content
    # Either an explicit "context" mention or the context body itself.
    assert "context" in text.lower() or "API Documentation" in text


def test_message_includes_analysis_instructions(basic_state):
    """The generated message asks the analyzer to plan or analyze."""
    result = IssueDocumentationAnalyzerMessageNode()(basic_state)

    text = result["issue_documentation_analyzer_messages"][0].content
    assert "plan" in text.lower() or "analyze" in text.lower()


def test_call_with_empty_context():
    """A state with no context documents still yields one message."""
    empty_context_state = IssueDocumentationState(
        issue_title="Create new docs",
        issue_body="Create documentation for new feature",
        issue_comments=[],
        max_refined_query_loop=3,
        documentation_query="",
        documentation_context=[],
        issue_documentation_analyzer_messages=[],
        edit_messages=[],
        edit_patch="",
        issue_response="",
    )

    result = IssueDocumentationAnalyzerMessageNode()(empty_context_state)

    assert "issue_documentation_analyzer_messages" in result
    assert len(result["issue_documentation_analyzer_messages"]) == 1
||||
@@ -0,0 +1,81 @@
|
||||
"""Tests for IssueDocumentationAnalyzerNode setup and invocation."""
import pytest
from langchain_core.messages import HumanMessage

from prometheus.lang_graph.nodes.issue_documentation_analyzer_node import (
    IssueDocumentationAnalyzerNode,
)
from prometheus.lang_graph.subgraphs.issue_documentation_state import IssueDocumentationState
from tests.test_utils.util import FakeListChatWithToolsModel


@pytest.fixture
def fake_llm():
    """LLM stub that answers with a canned documentation plan."""
    return FakeListChatWithToolsModel(
        responses=[
            "Documentation Plan:\n1. Update README.md with new API documentation\n"
            "2. Add code examples\n3. Update table of contents"
        ]
    )


@pytest.fixture
def basic_state():
    """State with two comments and one pending analyzer message."""
    return IssueDocumentationState(
        issue_title="Update API documentation",
        issue_body="The API documentation needs to be updated with the new endpoints",
        issue_comments=[
            {"username": "user1", "comment": "Please include examples"},
            {"username": "user2", "comment": "Add authentication details"},
        ],
        max_refined_query_loop=3,
        documentation_query="Find API documentation files",
        documentation_context=[],
        issue_documentation_analyzer_messages=[
            HumanMessage(content="Please analyze the documentation request and provide a plan")
        ],
        edit_messages=[],
        edit_patch="",
        issue_response="",
    )


def test_init_issue_documentation_analyzer_node(fake_llm):
    """Construction wires a system prompt and a single web-search tool."""
    node = IssueDocumentationAnalyzerNode(fake_llm)

    assert node.system_prompt is not None
    assert node.web_search_tool is not None
    assert len(node.tools) == 1
    assert node.model_with_tools is not None


def test_call_method_basic(fake_llm, basic_state):
    """Invocation returns one analyzer message containing the plan."""
    node = IssueDocumentationAnalyzerNode(fake_llm)

    result = node(basic_state)

    assert "issue_documentation_analyzer_messages" in result
    assert len(result["issue_documentation_analyzer_messages"]) == 1
    assert "Documentation Plan" in result["issue_documentation_analyzer_messages"][0].content


def test_call_method_with_empty_messages(fake_llm):
    """An empty message history still yields exactly one reply."""
    empty_history_state = IssueDocumentationState(
        issue_title="Test",
        issue_body="Test body",
        issue_comments=[],
        max_refined_query_loop=3,
        documentation_query="",
        documentation_context=[],
        issue_documentation_analyzer_messages=[],
        edit_messages=[],
        edit_patch="",
        issue_response="",
    )

    result = IssueDocumentationAnalyzerNode(fake_llm)(empty_history_state)

    assert "issue_documentation_analyzer_messages" in result
    assert len(result["issue_documentation_analyzer_messages"]) == 1
||||
@@ -0,0 +1,80 @@
|
||||
"""Tests for IssueDocumentationContextMessageNode query generation."""
import pytest

from prometheus.lang_graph.nodes.issue_documentation_context_message_node import (
    IssueDocumentationContextMessageNode,
)
from prometheus.lang_graph.subgraphs.issue_documentation_state import IssueDocumentationState


@pytest.fixture
def basic_state():
    """State with one comment and no documentation query yet."""
    return IssueDocumentationState(
        issue_title="Update API documentation",
        issue_body="The API documentation needs to be updated with new endpoints",
        issue_comments=[
            {"username": "user1", "comment": "Please include examples"},
        ],
        max_refined_query_loop=3,
        documentation_query="",
        documentation_context=[],
        issue_documentation_analyzer_messages=[],
        edit_messages=[],
        edit_patch="",
        issue_response="",
    )


def test_init_issue_documentation_context_message_node():
    """The node can be constructed without arguments."""
    assert IssueDocumentationContextMessageNode() is not None


def test_call_method_generates_query(basic_state):
    """Calling the node yields a non-empty documentation query."""
    result = IssueDocumentationContextMessageNode()(basic_state)

    assert "documentation_query" in result
    assert len(result["documentation_query"]) > 0


def test_query_contains_issue_info(basic_state):
    """The generated query references the issue's subject."""
    result = IssueDocumentationContextMessageNode()(basic_state)

    query = result["documentation_query"]
    assert "Update API documentation" in query or "API documentation" in query


def test_query_includes_instructions(basic_state):
    """The generated query mentions finding documentation."""
    result = IssueDocumentationContextMessageNode()(basic_state)

    query = result["documentation_query"]
    assert "documentation" in query.lower() or "find" in query.lower()


def test_call_with_empty_comments():
    """A state with no comments still yields a non-empty query."""
    no_comment_state = IssueDocumentationState(
        issue_title="Test title",
        issue_body="Test body",
        issue_comments=[],
        max_refined_query_loop=3,
        documentation_query="",
        documentation_context=[],
        issue_documentation_analyzer_messages=[],
        edit_messages=[],
        edit_patch="",
        issue_response="",
    )

    result = IssueDocumentationContextMessageNode()(no_comment_state)

    assert "documentation_query" in result
    assert len(result["documentation_query"]) > 0
||||
@@ -0,0 +1,117 @@
|
||||
"""Tests for IssueDocumentationEditMessageNode edit-prompt construction."""
import pytest
from langchain_core.messages import AIMessage, HumanMessage

from prometheus.lang_graph.nodes.issue_documentation_edit_message_node import (
    IssueDocumentationEditMessageNode,
)
from prometheus.lang_graph.subgraphs.issue_documentation_state import IssueDocumentationState
from prometheus.models.context import Context


@pytest.fixture
def basic_state():
    """State holding an analyzer plan and one context document."""
    return IssueDocumentationState(
        issue_title="Update API documentation",
        issue_body="The API documentation needs to be updated",
        issue_comments=[],
        max_refined_query_loop=3,
        documentation_query="Find API docs",
        documentation_context=[
            Context(
                relative_path="/docs/api.md",
                content="# API Documentation\n\nAPI Documentation content",
            )
        ],
        issue_documentation_analyzer_messages=[
            AIMessage(content="Plan:\n1. Update README.md\n2. Add new examples\n3. Fix typos")
        ],
        edit_messages=[],
        edit_patch="",
        issue_response="",
    )


def test_init_issue_documentation_edit_message_node():
    """The node can be constructed without arguments."""
    assert IssueDocumentationEditMessageNode() is not None


def test_call_method_creates_message(basic_state):
    """Calling the node produces exactly one HumanMessage in edit_messages."""
    result = IssueDocumentationEditMessageNode()(basic_state)

    assert "edit_messages" in result
    assert len(result["edit_messages"]) == 1
    assert isinstance(result["edit_messages"][0], HumanMessage)


def test_message_contains_plan(basic_state):
    """The edit prompt embeds the analyzer's documentation plan."""
    result = IssueDocumentationEditMessageNode()(basic_state)

    text = result["edit_messages"][0].content
    assert "Plan:" in text or "Update README.md" in text


def test_message_contains_context(basic_state):
    """The edit prompt references the documentation context."""
    result = IssueDocumentationEditMessageNode()(basic_state)

    text = result["edit_messages"][0].content
    assert "context" in text.lower() or "API Documentation" in text


def test_message_includes_edit_instructions(basic_state):
    """The edit prompt asks for concrete file changes."""
    result = IssueDocumentationEditMessageNode()(basic_state)

    text = result["edit_messages"][0].content
    assert any(
        keyword in text.lower() for keyword in ["implement", "edit", "changes", "file"]
    )


def test_call_with_empty_context():
    """A state with no context documents still yields one edit message."""
    empty_context_state = IssueDocumentationState(
        issue_title="Create docs",
        issue_body="Create new documentation",
        issue_comments=[],
        max_refined_query_loop=3,
        documentation_query="",
        documentation_context=[],
        issue_documentation_analyzer_messages=[AIMessage(content="Create new documentation files")],
        edit_messages=[],
        edit_patch="",
        issue_response="",
    )

    result = IssueDocumentationEditMessageNode()(empty_context_state)

    assert "edit_messages" in result
    assert len(result["edit_messages"]) == 1


def test_extracts_last_analyzer_message(basic_state):
    """Only the most recent analyzer message feeds the edit prompt."""
    basic_state["issue_documentation_analyzer_messages"] = [
        AIMessage(content="First message"),
        AIMessage(content="Second message"),
        AIMessage(content="Final plan: Update docs"),
    ]

    result = IssueDocumentationEditMessageNode()(basic_state)

    text = result["edit_messages"][0].content
    assert "Final plan" in text
||||
@@ -0,0 +1,96 @@
|
||||
"""Tests for IssueDocumentationResponderNode response generation."""
import pytest
from langchain_core.messages import AIMessage

from prometheus.lang_graph.nodes.issue_documentation_responder_node import (
    IssueDocumentationResponderNode,
)
from prometheus.lang_graph.subgraphs.issue_documentation_state import IssueDocumentationState
from tests.test_utils.util import FakeListChatWithToolsModel


@pytest.fixture
def fake_llm():
    """LLM stub that answers with a fixed update-confirmation message."""
    return FakeListChatWithToolsModel(
        responses=[
            "The documentation has been successfully updated. "
            "I've added new API endpoint documentation and included examples as requested."
        ]
    )


@pytest.fixture
def basic_state():
    """State with an analyzer plan and a non-empty edit patch."""
    return IssueDocumentationState(
        issue_title="Update API documentation",
        issue_body="The API documentation needs to be updated with the new endpoints",
        issue_comments=[
            {"username": "user1", "comment": "Please include examples"},
        ],
        max_refined_query_loop=3,
        documentation_query="Find API documentation",
        documentation_context=[],
        issue_documentation_analyzer_messages=[
            AIMessage(content="Plan: Update README.md with new API endpoints and add examples")
        ],
        edit_messages=[],
        edit_patch="diff --git a/README.md b/README.md\n+New API documentation",
        issue_response="",
    )


def test_init_issue_documentation_responder_node(fake_llm):
    """Construction stores the model."""
    node = IssueDocumentationResponderNode(fake_llm)

    assert node.model is not None


def test_call_method_basic(fake_llm, basic_state):
    """Invocation returns the LLM response under issue_response."""
    result = IssueDocumentationResponderNode(fake_llm)(basic_state)

    assert "issue_response" in result
    assert "successfully updated" in result["issue_response"]
    assert len(result["issue_response"]) > 0


def test_call_method_with_patch(fake_llm, basic_state):
    """With a patch present, the response is a plain string."""
    result = IssueDocumentationResponderNode(fake_llm)(basic_state)

    assert "issue_response" in result
    assert isinstance(result["issue_response"], str)


def test_call_method_without_patch(fake_llm):
    """An empty edit patch still produces a non-empty response."""
    patchless_state = IssueDocumentationState(
        issue_title="Update docs",
        issue_body="Please update the documentation",
        issue_comments=[],
        max_refined_query_loop=3,
        documentation_query="",
        documentation_context=[],
        issue_documentation_analyzer_messages=[AIMessage(content="Documentation plan created")],
        edit_messages=[],
        edit_patch="",
        issue_response="",
    )

    result = IssueDocumentationResponderNode(fake_llm)(patchless_state)

    assert "issue_response" in result
    assert len(result["issue_response"]) > 0


def test_response_includes_issue_details(fake_llm, basic_state):
    """The generated response has meaningful length."""
    result = IssueDocumentationResponderNode(fake_llm)(basic_state)

    assert "issue_response" in result
    # A trivial/empty response would indicate the prompt was not built.
    assert len(result["issue_response"]) > 10
||||
@@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
|
||||
from prometheus.lang_graph.nodes.issue_feature_responder_node import IssueFeatureResponderNode
|
||||
from prometheus.lang_graph.subgraphs.issue_feature_state import IssueFeatureState
|
||||
from tests.test_utils.util import FakeListChatWithToolsModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_llm():
|
||||
return FakeListChatWithToolsModel(
|
||||
responses=[
|
||||
"Thank you for requesting this feature. The implementation has been completed and is ready for review."
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_state():
|
||||
return IssueFeatureState(
|
||||
issue_title="Add dark mode support",
|
||||
issue_body="Please add dark mode to the application",
|
||||
issue_comments=[
|
||||
{"username": "user1", "comment": "This would be great!"},
|
||||
{"username": "user2", "comment": "I need this feature"},
|
||||
],
|
||||
final_patch="Added dark mode theme switching functionality",
|
||||
run_regression_test=True,
|
||||
number_of_candidate_patch=3,
|
||||
selected_regression_tests=["tests:tests"],
|
||||
issue_response="Mock Response",
|
||||
)
|
||||
|
||||
|
||||
def test_format_human_message_basic(fake_llm, basic_state):
|
||||
"""Test basic human message formatting."""
|
||||
node = IssueFeatureResponderNode(fake_llm)
|
||||
message = node.format_human_message(basic_state)
|
||||
|
||||
assert "Add dark mode support" in message.content
|
||||
assert "Please add dark mode to the application" in message.content
|
||||
assert "user1" in message.content
|
||||
assert "user2" in message.content
|
||||
assert "Added dark mode theme switching functionality" in message.content
|
||||
|
||||
|
||||
def test_format_human_message_with_regression_tests(fake_llm, basic_state):
|
||||
"""Test message formatting with regression tests."""
|
||||
# Add tested_patch_result to simulate passed tests
|
||||
from prometheus.models.test_patch_result import TestedPatchResult
|
||||
|
||||
basic_state["tested_patch_result"] = [
|
||||
TestedPatchResult(patch="test patch", passed=True, regression_test_failure_log="")
|
||||
]
|
||||
|
||||
node = IssueFeatureResponderNode(fake_llm)
|
||||
message = node.format_human_message(basic_state)
|
||||
|
||||
assert "✓ All selected regression tests passed successfully" in message.content
|
||||
|
||||
|
||||
def test_format_human_message_no_tests(fake_llm):
    """Without regression tests the message notes that none were run."""
    state = IssueFeatureState(
        issue_title="Add feature",
        issue_body="Feature description",
        issue_comments=[],
        final_patch="Implementation patch",
        run_regression_test=False,
        number_of_candidate_patch=1,
        selected_regression_tests=[],
        issue_response="",
    )

    responder = IssueFeatureResponderNode(fake_llm)
    content = responder.format_human_message(state).content

    assert "No automated tests were run for this feature implementation." in content
|
||||
|
||||
|
||||
def test_call_method(fake_llm, basic_state):
    """Invoking the node returns the LLM's canned reply as issue_response."""
    responder = IssueFeatureResponderNode(fake_llm)
    result = responder(basic_state)

    expected_reply = (
        "Thank you for requesting this feature. The implementation has been completed and is ready for review."
    )
    assert "issue_response" in result
    assert result["issue_response"] == expected_reply
|
||||
@@ -0,0 +1,20 @@
|
||||
from langchain_core.messages import HumanMessage

from prometheus.lang_graph.nodes.reset_messages_node import ResetMessagesNode


def test_reset_messages_node():
    """ResetMessagesNode empties only the message list it was configured with."""
    node = ResetMessagesNode("build_messages")

    state = {
        "build_messages": [HumanMessage(content="message 1"), HumanMessage(content="message 2")],
        "test_messages": [HumanMessage(content="message 3"), HumanMessage(content="message 4")],
    }

    node(state)

    # The configured key is cleared in place...
    assert "build_messages" in state
    assert len(state["build_messages"]) == 0

    # ...while unrelated message lists are left untouched.
    assert "test_messages" in state
    assert len(state["test_messages"]) == 2
|
||||
@@ -0,0 +1,24 @@
|
||||
from pathlib import Path
from unittest.mock import Mock

from prometheus.docker.general_container import GeneralContainer
from prometheus.git.git_repository import GitRepository
from prometheus.lang_graph.nodes.update_container_node import UpdateContainerNode


def test_update_container_node():
    """The node diffs the repo and pushes newly added files into the running container."""
    container = Mock(spec=GeneralContainer)
    container.is_running.return_value = True

    repo = Mock(spec=GitRepository)
    # Diff describing a single newly added file named "newfile".
    repo.get_diff.return_value = "--- /dev/null\n+++ b/newfile\n@@ -0,0 +1 @@\n+content"
    repo.get_working_directory.return_value = Path("/test/working/dir/repositories/repo")

    node = UpdateContainerNode(container, repo)
    node(None)

    assert repo.get_diff.call_count == 1
    assert container.is_running.call_count == 1
    assert container.update_files.call_count == 1
    container.update_files.assert_called_with(
        Path("/test/working/dir/repositories/repo"), [Path("newfile")], []
    )
|
||||
@@ -0,0 +1,32 @@
|
||||
from unittest.mock import Mock

import pytest
from langchain_core.messages import ToolMessage

from prometheus.docker.base_container import BaseContainer
from prometheus.lang_graph.nodes.user_defined_build_node import UserDefinedBuildNode


@pytest.fixture
def mock_container():
    """A container mock whose build command reports success."""
    container = Mock(spec=BaseContainer)
    container.run_build.return_value = "Build successful"
    return container


@pytest.fixture
def build_node(mock_container):
    """A UserDefinedBuildNode wired to the mocked container."""
    return UserDefinedBuildNode(container=mock_container)


def test_successful_build(build_node, mock_container):
    """A successful build yields exactly one ToolMessage carrying the build output."""
    result = build_node(None)

    assert isinstance(result, dict)
    assert "build_messages" in result

    messages = result["build_messages"]
    assert len(messages) == 1
    assert isinstance(messages[0], ToolMessage)
    assert messages[0].content == "Build successful"
    mock_container.run_build.assert_called_once()
|
||||
@@ -0,0 +1,32 @@
|
||||
from unittest.mock import Mock

import pytest
from langchain_core.messages import ToolMessage

from prometheus.docker.base_container import BaseContainer
from prometheus.lang_graph.nodes.user_defined_test_node import UserDefinedTestNode


@pytest.fixture
def mock_container():
    """A container mock whose test command reports success."""
    container = Mock(spec=BaseContainer)
    container.run_test.return_value = "Test successful"
    return container


@pytest.fixture
def test_node(mock_container):
    """A UserDefinedTestNode wired to the mocked container."""
    return UserDefinedTestNode(container=mock_container)


def test_successful_test(test_node, mock_container):
    """A successful test run yields exactly one ToolMessage carrying the test output."""
    result = test_node(None)

    assert isinstance(result, dict)
    assert "test_messages" in result

    messages = result["test_messages"]
    assert len(messages) == 1
    assert isinstance(messages[0], ToolMessage)
    assert messages[0].content == "Test successful"
    mock_container.run_test.assert_called_once()
|
||||
0
prometheus/tests/lang_graph/subgraphs/__init__.py
Normal file
0
prometheus/tests/lang_graph/subgraphs/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from unittest.mock import Mock

import pytest
from langgraph.checkpoint.base import BaseCheckpointSaver

from prometheus.docker.base_container import BaseContainer
from prometheus.git.git_repository import GitRepository
from prometheus.lang_graph.subgraphs.bug_fix_verification_subgraph import BugFixVerificationSubgraph
from tests.test_utils.util import FakeListChatWithToolsModel


@pytest.fixture
def mock_container():
    """A bare BaseContainer mock."""
    return Mock(spec=BaseContainer)


@pytest.fixture
def mock_checkpointer():
    """A bare checkpoint-saver mock (currently unused by the test below)."""
    return Mock(spec=BaseCheckpointSaver)


@pytest.fixture
def mock_git_repo():
    """A GitRepository mock exposing a playground path."""
    repo = Mock(spec=GitRepository)
    repo.playground_path = "mock/playground/path"
    return repo


def test_bug_fix_verification_subgraph_basic_initialization(
    mock_container,
    mock_git_repo,
):
    """BugFixVerificationSubgraph wires itself up from a model, container, and repo."""
    model = FakeListChatWithToolsModel(responses=[])

    subgraph = BugFixVerificationSubgraph(model, mock_container, mock_git_repo)

    assert subgraph.subgraph is not None
|
||||
@@ -0,0 +1,49 @@
|
||||
from unittest.mock import Mock

import pytest

from prometheus.docker.base_container import BaseContainer
from prometheus.git.git_repository import GitRepository
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.subgraphs.bug_reproduction_subgraph import BugReproductionSubgraph
from tests.test_utils.util import FakeListChatWithToolsModel


@pytest.fixture
def mock_container():
    """A bare BaseContainer mock."""
    return Mock(spec=BaseContainer)


@pytest.fixture
def mock_kg():
    """A KnowledgeGraph mock reporting a few AST node types and a root id."""
    kg = Mock(spec=KnowledgeGraph)
    kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"]
    kg.root_node_id = 0
    return kg


@pytest.fixture
def mock_git_repo():
    """A GitRepository mock exposing a playground path."""
    repo = Mock(spec=GitRepository)
    repo.playground_path = "mock/playground/path"
    return repo


def test_bug_reproduction_subgraph_basic_initialization(mock_container, mock_kg, mock_git_repo):
    """BugReproductionSubgraph builds successfully from mocked collaborators."""
    advanced_model = FakeListChatWithToolsModel(responses=[])
    base_model = FakeListChatWithToolsModel(responses=[])

    subgraph = BugReproductionSubgraph(
        advanced_model,
        base_model,
        mock_container,
        mock_kg,
        mock_git_repo,
        None,
    )

    assert subgraph.subgraph is not None
|
||||
@@ -0,0 +1,47 @@
|
||||
from unittest.mock import Mock

import pytest

from prometheus.docker.base_container import BaseContainer
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.subgraphs.build_and_test_subgraph import BuildAndTestSubgraph
from tests.test_utils.util import FakeListChatWithToolsModel


@pytest.fixture
def mock_container():
    """A bare BaseContainer mock."""
    return Mock(spec=BaseContainer)


@pytest.fixture
def mock_kg():
    """A bare KnowledgeGraph mock."""
    return Mock(spec=KnowledgeGraph)


def test_build_and_test_subgraph_basic_initialization(mock_container, mock_kg):
    """The subgraph builds from just a container, a model, and a knowledge graph."""
    model = FakeListChatWithToolsModel(responses=[])

    subgraph = BuildAndTestSubgraph(container=mock_container, model=model, kg=mock_kg)

    assert subgraph.subgraph is not None


def test_build_and_test_subgraph_with_commands(mock_container, mock_kg):
    """The subgraph also accepts explicit build and test command lists."""
    model = FakeListChatWithToolsModel(responses=[])

    subgraph = BuildAndTestSubgraph(
        container=mock_container,
        model=model,
        kg=mock_kg,
        build_commands=["make build"],
        test_commands=["make test"],
    )

    assert subgraph.subgraph is not None
|
||||
@@ -0,0 +1,69 @@
|
||||
from unittest.mock import Mock

import pytest

from prometheus.docker.base_container import BaseContainer
from prometheus.git.git_repository import GitRepository
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.subgraphs.issue_bug_subgraph import IssueBugSubgraph
from tests.test_utils.util import FakeListChatWithToolsModel


@pytest.fixture
def mock_container():
    """A bare BaseContainer mock."""
    return Mock(spec=BaseContainer)


@pytest.fixture
def mock_kg():
    """A KnowledgeGraph mock reporting a few AST node types and a root id."""
    kg = Mock(spec=KnowledgeGraph)
    kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"]
    kg.root_node_id = 0
    return kg


@pytest.fixture
def mock_git_repo():
    """A GitRepository mock exposing a playground path."""
    repo = Mock(spec=GitRepository)
    repo.playground_path = "mock/playground/path"
    return repo


def test_issue_bug_subgraph_basic_initialization(mock_container, mock_kg, mock_git_repo):
    """IssueBugSubgraph builds successfully from mocked collaborators."""
    advanced_model = FakeListChatWithToolsModel(responses=[])
    base_model = FakeListChatWithToolsModel(responses=[])

    subgraph = IssueBugSubgraph(
        advanced_model=advanced_model,
        base_model=base_model,
        container=mock_container,
        kg=mock_kg,
        git_repo=mock_git_repo,
        repository_id=1,
    )

    assert subgraph.subgraph is not None


def test_issue_bug_subgraph_with_commands(mock_container, mock_kg, mock_git_repo):
    """IssueBugSubgraph also accepts an explicit list of test commands."""
    advanced_model = FakeListChatWithToolsModel(responses=[])
    base_model = FakeListChatWithToolsModel(responses=[])

    subgraph = IssueBugSubgraph(
        advanced_model=advanced_model,
        base_model=base_model,
        container=mock_container,
        kg=mock_kg,
        git_repo=mock_git_repo,
        repository_id=1,
        test_commands=["make test"],
    )

    assert subgraph.subgraph is not None
|
||||
@@ -0,0 +1,45 @@
|
||||
from unittest.mock import Mock

import pytest

from prometheus.git.git_repository import GitRepository
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.subgraphs.issue_classification_subgraph import (
    IssueClassificationSubgraph,
)
from tests.test_utils.util import FakeListChatWithToolsModel


@pytest.fixture
def mock_kg():
    """A KnowledgeGraph mock reporting a few AST node types and a root id."""
    kg = Mock(spec=KnowledgeGraph)
    kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"]
    kg.root_node_id = 0
    return kg


@pytest.fixture
def mock_git_repo():
    """A GitRepository mock exposing a playground path."""
    repo = Mock(spec=GitRepository)
    repo.playground_path = "mock/playground/path"
    return repo


def test_issue_classification_subgraph_basic_initialization(mock_kg, mock_git_repo):
    """IssueClassificationSubgraph builds successfully from mocked collaborators."""
    base_model = FakeListChatWithToolsModel(responses=[])
    advanced_model = FakeListChatWithToolsModel(responses=[])

    subgraph = IssueClassificationSubgraph(
        advanced_model=advanced_model,
        model=base_model,
        kg=mock_kg,
        local_path=mock_git_repo.playground_path,
        repository_id=1,
    )

    assert subgraph.subgraph is not None
|
||||
@@ -0,0 +1,51 @@
|
||||
from unittest.mock import Mock

import pytest

from prometheus.git.git_repository import GitRepository
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.subgraphs.issue_documentation_subgraph import (
    IssueDocumentationSubgraph,
)
from tests.test_utils.util import FakeListChatWithToolsModel


@pytest.fixture
def mock_kg():
    """A KnowledgeGraph mock reporting a few AST node types and a root id."""
    kg = Mock(spec=KnowledgeGraph)
    kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"]
    kg.root_node_id = 0
    return kg


@pytest.fixture
def mock_git_repo():
    """A GitRepository mock exposing a playground path."""
    repo = Mock(spec=GitRepository)
    repo.playground_path = "mock/playground/path"
    return repo


def test_issue_documentation_subgraph_basic_initialization(mock_kg, mock_git_repo):
    """IssueDocumentationSubgraph builds successfully from mocked collaborators."""
    advanced_model = FakeListChatWithToolsModel(responses=[])
    base_model = FakeListChatWithToolsModel(responses=[])

    subgraph = IssueDocumentationSubgraph(
        advanced_model=advanced_model,
        base_model=base_model,
        kg=mock_kg,
        git_repo=mock_git_repo,
        repository_id=1,
    )

    assert subgraph.subgraph is not None
|
||||
@@ -0,0 +1,49 @@
|
||||
from unittest.mock import Mock

import pytest

from prometheus.docker.base_container import BaseContainer
from prometheus.git.git_repository import GitRepository
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.subgraphs.issue_question_subgraph import IssueQuestionSubgraph
from tests.test_utils.util import FakeListChatWithToolsModel


@pytest.fixture
def mock_container():
    """A bare BaseContainer mock (currently unused by the test below)."""
    return Mock(spec=BaseContainer)


@pytest.fixture
def mock_kg():
    """A KnowledgeGraph mock reporting a few AST node types and a root id."""
    kg = Mock(spec=KnowledgeGraph)
    kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"]
    kg.root_node_id = 0
    return kg


@pytest.fixture
def mock_git_repo():
    """A GitRepository mock exposing a playground path."""
    repo = Mock(spec=GitRepository)
    repo.playground_path = "mock/playground/path"
    return repo


def test_issue_question_subgraph_basic_initialization(mock_container, mock_kg, mock_git_repo):
    """IssueQuestionSubgraph builds successfully from mocked collaborators."""
    advanced_model = FakeListChatWithToolsModel(responses=[])
    base_model = FakeListChatWithToolsModel(responses=[])

    subgraph = IssueQuestionSubgraph(
        advanced_model=advanced_model,
        base_model=base_model,
        kg=mock_kg,
        git_repo=mock_git_repo,
        repository_id=1,
    )

    assert subgraph.subgraph is not None
|
||||
0
prometheus/tests/neo4j/__init__.py
Normal file
0
prometheus/tests/neo4j/__init__.py
Normal file
131
prometheus/tests/neo4j/test_knowledge_graph_handler.py
Normal file
131
prometheus/tests/neo4j/test_knowledge_graph_handler.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import pytest

from prometheus.app.services.neo4j_service import Neo4jService
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.neo4j.knowledge_graph_handler import KnowledgeGraphHandler
from tests.test_utils import test_project_paths
from tests.test_utils.fixtures import (  # noqa: F401
    NEO4J_PASSWORD,
    NEO4J_USERNAME,
    empty_neo4j_container_fixture,
    neo4j_container_with_kg_fixture,
)


def _make_handler(service: Neo4jService) -> KnowledgeGraphHandler:
    """Build a KnowledgeGraphHandler over the service's driver with batch size 100."""
    return KnowledgeGraphHandler(service.neo4j_driver, 100)


@pytest.fixture
async def mock_neo4j_service(neo4j_container_with_kg_fixture):  # noqa: F811
    """Fixture: a started Neo4jService backed by the Neo4j test container that
    already holds the test-project knowledge graph.

    (The original docstring said "DatabaseService using the Postgres test
    container", which was a copy-paste error — this is a Neo4j service.)
    """
    neo4j_container, _kg = neo4j_container_with_kg_fixture
    service = Neo4jService(neo4j_container.get_connection_url(), NEO4J_USERNAME, NEO4J_PASSWORD)
    service.start()
    yield service
    await service.close()


@pytest.fixture
async def mock_empty_neo4j_service(empty_neo4j_container_fixture):  # noqa: F811
    """Fixture: a started Neo4jService backed by an empty Neo4j test container."""
    neo4j_container = empty_neo4j_container_fixture
    service = Neo4jService(neo4j_container.get_connection_url(), NEO4J_USERNAME, NEO4J_PASSWORD)
    service.start()
    yield service
    await service.close()


@pytest.mark.slow
async def test_num_ast_nodes(mock_neo4j_service):
    """The test-project graph contains exactly 7 AST nodes."""
    handler = _make_handler(mock_neo4j_service)

    async with mock_neo4j_service.neo4j_driver.session() as session:
        read_ast_nodes = await session.execute_read(handler._read_ast_nodes, root_node_id=0)
        assert len(read_ast_nodes) == 7


@pytest.mark.slow
async def test_num_file_nodes(mock_neo4j_service):
    """The test-project graph contains exactly 7 file nodes."""
    handler = _make_handler(mock_neo4j_service)

    async with mock_neo4j_service.neo4j_driver.session() as session:
        read_file_nodes = await session.execute_read(handler._read_file_nodes, root_node_id=0)
        assert len(read_file_nodes) == 7


@pytest.mark.slow
async def test_num_text_nodes(mock_neo4j_service):
    """The test-project graph contains exactly 1 text node."""
    handler = _make_handler(mock_neo4j_service)

    async with mock_neo4j_service.neo4j_driver.session() as session:
        read_text_nodes = await session.execute_read(handler._read_text_nodes, root_node_id=0)
        assert len(read_text_nodes) == 1


@pytest.mark.slow
async def test_num_parent_of_edges(mock_neo4j_service):
    """The test-project graph contains exactly 4 parent-of edges."""
    handler = _make_handler(mock_neo4j_service)

    async with mock_neo4j_service.neo4j_driver.session() as session:
        read_parent_of_edges = await session.execute_read(
            handler._read_parent_of_edges, root_node_id=0
        )
        assert len(read_parent_of_edges) == 4


@pytest.mark.slow
async def test_num_has_file_edges(mock_neo4j_service):
    """The test-project graph contains exactly 6 has-file edges."""
    handler = _make_handler(mock_neo4j_service)

    async with mock_neo4j_service.neo4j_driver.session() as session:
        read_has_file_edges = await session.execute_read(
            handler._read_has_file_edges, root_node_id=0
        )
        assert len(read_has_file_edges) == 6


@pytest.mark.slow
async def test_num_has_ast_edges(mock_neo4j_service):
    """The test-project graph contains exactly 3 has-AST edges."""
    handler = _make_handler(mock_neo4j_service)

    async with mock_neo4j_service.neo4j_driver.session() as session:
        read_has_ast_edges = await session.execute_read(handler._read_has_ast_edges, root_node_id=0)
        assert len(read_has_ast_edges) == 3


@pytest.mark.slow
async def test_num_has_text_edges(mock_neo4j_service):
    """The test-project graph contains exactly 1 has-text edge."""
    handler = _make_handler(mock_neo4j_service)

    async with mock_neo4j_service.neo4j_driver.session() as session:
        read_has_text_edges = await session.execute_read(
            handler._read_has_text_edges, root_node_id=0
        )
        assert len(read_has_text_edges) == 1


@pytest.mark.slow
async def test_num_next_chunk_edges(mock_neo4j_service):
    """The test-project graph contains no next-chunk edges."""
    handler = _make_handler(mock_neo4j_service)

    async with mock_neo4j_service.neo4j_driver.session() as session:
        read_next_chunk_edges = await session.execute_read(
            handler._read_next_chunk_edges, root_node_id=0
        )
        assert len(read_next_chunk_edges) == 0


@pytest.mark.slow
async def test_clear_knowledge_graph(mock_empty_neo4j_service):
    """Writing a graph and then clearing it leaves the database empty."""
    kg = KnowledgeGraph(1, 1000, 100, 0)
    await kg.build_graph(test_project_paths.TEST_PROJECT_PATH)

    driver = mock_empty_neo4j_service.neo4j_driver
    handler = KnowledgeGraphHandler(driver, 100)
    await handler.write_knowledge_graph(kg)

    await handler.clear_knowledge_graph(0)

    # Verify that the graph is cleared.
    async with driver.session() as session:
        result = await session.run("MATCH (n) RETURN COUNT(n) AS node_count")
        record = await result.single()
        node_count = record["node_count"]
        assert node_count == 0
|
||||
0
prometheus/tests/parser/__init__.py
Normal file
0
prometheus/tests/parser/__init__.py
Normal file
34
prometheus/tests/parser/test_file_types.py
Normal file
34
prometheus/tests/parser/test_file_types.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from pathlib import Path

import pytest

from prometheus.parser.file_types import FileType


@pytest.mark.parametrize(
    "file_path,expected_type",
    [
        ("script.sh", FileType.BASH),
        ("program.c", FileType.C),
        ("app.cs", FileType.CSHARP),
        ("class.cpp", FileType.CPP),
        ("header.cc", FileType.CPP),
        ("file.cxx", FileType.CPP),
        ("main.go", FileType.GO),
        ("Class.java", FileType.JAVA),
        ("app.js", FileType.JAVASCRIPT),
        ("Service.kt", FileType.KOTLIN),
        ("index.php", FileType.PHP),
        ("script.py", FileType.PYTHON),
        ("query.sql", FileType.SQL),
        ("config.yaml", FileType.YAML),
        ("docker-compose.yml", FileType.YAML),
        ("readme.md", FileType.UNKNOWN),
        ("Makefile", FileType.UNKNOWN),
        ("", FileType.UNKNOWN),
    ],
)
def test_file_type_from_path(file_path: str, expected_type: FileType):
    """Each known extension maps to its FileType; unknown or empty paths map to UNKNOWN."""
    assert FileType.from_path(Path(file_path)) == expected_type
|
||||
68
prometheus/tests/parser/test_tree_sitter_parser.py
Normal file
68
prometheus/tests/parser/test_tree_sitter_parser.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from pathlib import Path
from unittest.mock import MagicMock, create_autospec, mock_open, patch

import pytest
from tree_sitter._binding import Tree

from prometheus.parser.file_types import FileType
from prometheus.parser.tree_sitter_parser import (
    FILE_TYPE_TO_LANG,
    parse,
    supports_file,
)


@pytest.fixture
def mock_python_file():
    """Path to a (non-existent) Python source file used as parser input."""
    return Path("test.py")


@pytest.fixture
def mock_tree():
    """An autospec'd tree-sitter Tree instance."""
    return create_autospec(Tree, instance=True)


def test_supports_file_with_supported_type(mock_python_file):
    """supports_file is True when FileType.from_path yields a supported type."""
    with patch("prometheus.parser.file_types.FileType.from_path") as mock_from_path:
        mock_from_path.return_value = FileType.PYTHON
        assert supports_file(mock_python_file) is True
        mock_from_path.assert_called_once_with(mock_python_file)


def test_parse_python_file_successfully(mock_python_file, mock_tree):
    """parse() reads the file's bytes and hands them to the language parser."""
    mock_content = b'print("hello")'
    opener = mock_open(read_data=mock_content)

    with (
        patch("prometheus.parser.file_types.FileType.from_path") as mock_from_path,
        patch("prometheus.parser.tree_sitter_parser.get_parser") as mock_get_parser,
        patch.object(Path, "open", opener),
    ):
        # Arrange: file type resolves to Python and the parser yields mock_tree.
        mock_from_path.return_value = FileType.PYTHON
        parser = mock_get_parser.return_value
        parser.parse.return_value = mock_tree

        result = parse(mock_python_file)

        # The tree from the parser is returned unchanged, and the collaborators
        # were each invoked exactly once with the expected arguments.
        assert result == mock_tree
        mock_from_path.assert_called_once_with(mock_python_file)
        mock_get_parser.assert_called_once_with("python")
        parser.parse.assert_called_once_with(mock_content)


def test_parse_all_supported_languages():
    """Every language name in FILE_TYPE_TO_LANG can be passed to get_parser.

    NOTE(review): get_parser is patched here, so this loop only exercises the
    mock with each language name — it never loads a real grammar. Confirm
    whether an unpatched smoke test against tree_sitter_language_pack is wanted.
    """
    with patch("tree_sitter_language_pack.get_parser") as mock_get_parser:
        mock_get_parser.return_value = MagicMock()

        for lang in FILE_TYPE_TO_LANG.values():
            mock_get_parser(lang)
            mock_get_parser.assert_called_with(lang)
            mock_get_parser.reset_mock()
|
||||
1
prometheus/tests/test_project/.gitignore
vendored
Normal file
1
prometheus/tests/test_project/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
.dummy_git/
|
||||
5
prometheus/tests/test_project/bar/test.java
Normal file
5
prometheus/tests/test_project/bar/test.java
Normal file
@@ -0,0 +1,5 @@
|
||||
public class test {
|
||||
public static void main(String[] args) {
|
||||
System.out.println("Hello world!");
|
||||
}
|
||||
}
|
||||
1
prometheus/tests/test_project/bar/test.py
Normal file
1
prometheus/tests/test_project/bar/test.py
Normal file
@@ -0,0 +1 @@
|
||||
print("Hello world!")
|
||||
11
prometheus/tests/test_project/foo/test.dummy
Normal file
11
prometheus/tests/test_project/foo/test.dummy
Normal file
@@ -0,0 +1,11 @@
|
||||
@program Start
|
||||
|
||||
use library console_output
|
||||
|
||||
start {
|
||||
declare text message = "Hello world!";
|
||||
print_to_console(message);
|
||||
halt;
|
||||
}
|
||||
|
||||
end @program
|
||||
15
prometheus/tests/test_project/foo/test.md
Normal file
15
prometheus/tests/test_project/foo/test.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# A
|
||||
|
||||
Text under header A.
|
||||
|
||||
## B
|
||||
|
||||
Text under header B.
|
||||
|
||||
## C
|
||||
|
||||
Text under header C.
|
||||
|
||||
### D
|
||||
|
||||
Text under header D.
|
||||
6
prometheus/tests/test_project/test.c
Normal file
6
prometheus/tests/test_project/test.c
Normal file
@@ -0,0 +1,6 @@
|
||||
#include <stdio.h>
|
||||
|
||||
int main() {
|
||||
printf("Hello world!");
|
||||
return 0;
|
||||
}
|
||||
0
prometheus/tests/test_utils/__init__.py
Normal file
0
prometheus/tests/test_utils/__init__.py
Normal file
90
prometheus/tests/test_utils/fixtures.py
Normal file
90
prometheus/tests/test_utils/fixtures.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from git import Repo
|
||||
from testcontainers.neo4j import Neo4jContainer
|
||||
from testcontainers.postgres import PostgresContainer
|
||||
|
||||
from prometheus.app.services.neo4j_service import Neo4jService
|
||||
from prometheus.graph.knowledge_graph import KnowledgeGraph
|
||||
from prometheus.neo4j.knowledge_graph_handler import KnowledgeGraphHandler
|
||||
from tests.test_utils import test_project_paths
|
||||
|
||||
NEO4J_IMAGE = "neo4j:5.20.0"
|
||||
NEO4J_USERNAME = "neo4j"
|
||||
NEO4J_PASSWORD = "password"
|
||||
|
||||
POSTGRES_IMAGE = "postgres"
|
||||
POSTGRES_USERNAME = "postgres"
|
||||
POSTGRES_PASSWORD = "password"
|
||||
POSTGRES_DB = "postgres"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def neo4j_container_with_kg_fixture():
|
||||
kg = KnowledgeGraph(1, 1000, 100, 0)
|
||||
await kg.build_graph(test_project_paths.TEST_PROJECT_PATH)
|
||||
container = (
|
||||
Neo4jContainer(image=NEO4J_IMAGE, username=NEO4J_USERNAME, password=NEO4J_PASSWORD)
|
||||
.with_env("NEO4J_PLUGINS", '["apoc"]')
|
||||
.with_name(f"neo4j_container_with_kg_{uuid.uuid4().hex[:12]}")
|
||||
)
|
||||
with container as neo4j_container:
|
||||
neo4j_service = Neo4jService(container.get_connection_url(), NEO4J_USERNAME, NEO4J_PASSWORD)
|
||||
handler = KnowledgeGraphHandler(neo4j_service.neo4j_driver, 100)
|
||||
await handler.write_knowledge_graph(kg)
|
||||
yield neo4j_container, kg
|
||||
await neo4j_service.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def empty_neo4j_container_fixture():
|
||||
container = (
|
||||
Neo4jContainer(image=NEO4J_IMAGE, username=NEO4J_USERNAME, password=NEO4J_PASSWORD)
|
||||
.with_env("NEO4J_PLUGINS", '["apoc"]')
|
||||
.with_name(f"empty_neo4j_container_{uuid.uuid4().hex[:12]}")
|
||||
)
|
||||
with container as neo4j_container:
|
||||
yield neo4j_container
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def postgres_container_fixture():
|
||||
container = PostgresContainer(
|
||||
image=POSTGRES_IMAGE,
|
||||
username=POSTGRES_USERNAME,
|
||||
password=POSTGRES_PASSWORD,
|
||||
dbname=POSTGRES_DB,
|
||||
port=5432,
|
||||
driver="asyncpg",
|
||||
).with_name(f"postgres_container_{uuid.uuid4().hex[:12]}")
|
||||
with container as postgres_container:
|
||||
yield postgres_container
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def git_repo_fixture():
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
temp_project_dir = temp_dir / "test_project"
|
||||
original_project_path = test_project_paths.TEST_PROJECT_PATH
|
||||
|
||||
try:
|
||||
shutil.copytree(original_project_path, temp_project_dir)
|
||||
shutil.move(temp_project_dir / test_project_paths.GIT_DIR.name, temp_project_dir / ".git")
|
||||
|
||||
repo = Repo(temp_project_dir)
|
||||
yield repo
|
||||
finally:
|
||||
shutil.rmtree(temp_project_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_test_dir(tmp_path):
    """Yield a fresh ``test_files`` directory under pytest's tmp_path."""
    directory = tmp_path / "test_files"
    directory.mkdir()
    yield directory
    # No explicit cleanup: pytest removes tmp_path (and this subdir) itself.
|
||||
11
prometheus/tests/test_utils/test_project_paths.py
Normal file
11
prometheus/tests/test_utils/test_project_paths.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from pathlib import Path

# Root of the checked-in sample project used by the graph/file-operation tests.
TEST_PROJECT_PATH = Path(__file__).parent.parent / "test_project"
# Git metadata is stored as ".dummy_git" so the outer repository does not see
# a nested repo; fixtures rename it to ".git" in a temporary copy before use.
GIT_DIR = TEST_PROJECT_PATH / ".dummy_git"
# Sample files covering several languages/formats parsed by the graph builder.
C_FILE = TEST_PROJECT_PATH / "test.c"
BAR_DIR = TEST_PROJECT_PATH / "bar"
JAVA_FILE = BAR_DIR / "test.java"
PYTHON_FILE = BAR_DIR / "test.py"
FOO_DIR = TEST_PROJECT_PATH / "foo"
MD_FILE = FOO_DIR / "test.md"
# File with an unrecognized extension — presumably exercises the fallback
# handling for unknown file types; confirm against the graph builder.
DUMMY_FILE = FOO_DIR / "test.dummy"
|
||||
13
prometheus/tests/test_utils/util.py
Normal file
13
prometheus/tests/test_utils/util.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from langchain_core.language_models.fake_chat_models import FakeListChatModel
|
||||
from testcontainers.neo4j import Neo4jContainer
|
||||
|
||||
|
||||
class FakeListChatWithToolsModel(FakeListChatModel):
    """``FakeListChatModel`` that tolerates ``bind_tools`` calls.

    Used where code under test binds tools to its chat model; the override
    is a no-op so the canned response list is still served unchanged.
    """

    def bind_tools(self, tools=None, tool_choice=None, **kwargs):
        # Ignore all tool arguments and keep serving the scripted responses.
        return self
|
||||
|
||||
|
||||
def clean_neo4j_container(neo4j_container: Neo4jContainer):
    """Delete every node (and all its relationships) from the container."""
    with neo4j_container.get_driver() as driver, driver.session() as session:
        session.run("MATCH (n) DETACH DELETE n")
|
||||
0
prometheus/tests/tools/__init__.py
Normal file
0
prometheus/tests/tools/__init__.py
Normal file
141
prometheus/tests/tools/test_file_operation.py
Normal file
141
prometheus/tests/tools/test_file_operation.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.graph.knowledge_graph import KnowledgeGraph
|
||||
from prometheus.tools.file_operation import FileOperationTool
|
||||
from tests.test_utils import test_project_paths
|
||||
from tests.test_utils.fixtures import temp_test_dir # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
async def knowledge_graph_fixture(temp_test_dir):  # noqa: F811
    """Build a KnowledgeGraph over a fresh copy of the checked-in test project."""
    # Start from a clean copy so earlier test output cannot leak in.
    if temp_test_dir.exists():
        shutil.rmtree(temp_test_dir)
    shutil.copytree(test_project_paths.TEST_PROJECT_PATH, temp_test_dir)

    graph = KnowledgeGraph(1, 1000, 100, 0)
    await graph.build_graph(temp_test_dir)
    return graph
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def file_operation_tool(temp_test_dir, knowledge_graph_fixture):  # noqa: F811
    """FileOperationTool rooted at the temporary project copy."""
    return FileOperationTool(temp_test_dir, knowledge_graph_fixture)
|
||||
|
||||
|
||||
def test_read_file_with_knowledge_graph_data(file_operation_tool):
    """Reading a file via the KG yields preview rows tied to its FileNode."""
    relative_path = str(
        test_project_paths.PYTHON_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix()
    )

    rows = file_operation_tool.read_file_with_knowledge_graph_data(relative_path)[1]

    assert len(rows) > 0
    for row in rows:
        assert "preview" in row
        assert 'print("Hello world!")' in row["preview"].get("text", "")
        assert "FileNode" in row
        assert row["FileNode"].get("relative_path", "") == relative_path
|
||||
|
||||
|
||||
def test_create_and_read_file(temp_test_dir, file_operation_tool):  # noqa: F811
    """Creating a file writes it to disk; read_file returns numbered lines."""
    target = temp_test_dir / "test.txt"
    content = "line 1\nline 2\nline 3"

    # create_file writes the content and reports success.
    message = file_operation_tool.create_file("test.txt", content)
    assert target.exists()
    assert target.read_text() == content
    assert message == "The file test.txt has been created."

    # read_file prefixes every line with its 1-based number.
    assert file_operation_tool.read_file("test.txt") == "1. line 1\n2. line 2\n3. line 3"
|
||||
|
||||
|
||||
def test_read_file_nonexistent(file_operation_tool):
    """Reading a missing file yields an explanatory message, not an error."""
    assert (
        file_operation_tool.read_file("nonexistent_file.txt")
        == "The file nonexistent_file.txt does not exist."
    )
|
||||
|
||||
|
||||
def test_read_file_with_line_numbers(file_operation_tool):
    """Reading a line range returns numbered lines; bad ranges are rejected."""
    file_operation_tool.create_file("test_lines.txt", "line 1\nline 2\nline 3\nline 4\nline 5")

    # NOTE(review): the range looks end-exclusive — asking for 2..4 returns
    # lines 2 and 3 only; confirm against FileOperationTool's contract.
    assert (
        file_operation_tool.read_file_with_line_numbers("test_lines.txt", 2, 4)
        == "2. line 2\n3. line 3"
    )

    # end < start is reported as an error message rather than raising.
    assert (
        file_operation_tool.read_file_with_line_numbers("test_lines.txt", 4, 2)
        == "The end line number 2 must be greater than the start line number 4."
    )
|
||||
|
||||
|
||||
def test_delete(file_operation_tool, temp_test_dir):  # noqa: F811
    """delete removes single files as well as whole directories."""
    # Single file.
    file_path = temp_test_dir / "to_delete.txt"
    file_operation_tool.create_file("to_delete.txt", "content")
    assert file_path.exists()
    assert file_operation_tool.delete("to_delete.txt") == "The file to_delete.txt has been deleted."
    assert not file_path.exists()

    # Directory containing a file.
    dir_path = temp_test_dir / "subdir"
    dir_path.mkdir()
    file_operation_tool.create_file("subdir/file.txt", "content")
    assert file_operation_tool.delete("subdir") == "The directory subdir has been deleted."
    assert not dir_path.exists()
|
||||
|
||||
|
||||
def test_delete_nonexistent(file_operation_tool):
    """Deleting a missing path yields an explanatory message."""
    assert (
        file_operation_tool.delete("nonexistent_path")
        == "The file nonexistent_path does not exist."
    )
|
||||
|
||||
|
||||
def test_edit_file(file_operation_tool):
    """edit_file replaces a unique match and reports every failure mode."""
    # Unique occurrence: edit succeeds.
    file_operation_tool.create_file("edit_test.txt", "line 1\nline 2\nline 3\nline 4\nline 5")
    assert (
        file_operation_tool.edit_file("edit_test.txt", "line 2", "new line 2")
        == "Successfully edited edit_test.txt."
    )

    # Absolute paths are rejected.
    assert (
        file_operation_tool.edit_file("/edit_test.txt", "line 2", "new line 2")
        == "relative_path: /edit_test.txt is a absolute path, not relative path."
    )

    # Missing file.
    assert (
        file_operation_tool.edit_file("nonexistent.txt", "line 2", "new line 2")
        == "The file nonexistent.txt does not exist."
    )

    # No occurrence of the target text.
    assert (
        file_operation_tool.edit_file("edit_test.txt", "nonexistent line", "new content")
        == "No match found for the specified content in edit_test.txt. Please verify the content to replace."
    )

    # Ambiguous: multiple occurrences of the target text.
    file_operation_tool.create_file("duplicate_test.txt", "line 1\nline 2\nline 2\nline 3")
    assert (
        file_operation_tool.edit_file("duplicate_test.txt", "line 2", "new line 2")
        == "Found 2 occurrences of the specified content in duplicate_test.txt. Please provide more context to ensure a unique match."
    )
|
||||
|
||||
|
||||
def test_create_file_already_exists(file_operation_tool):
    """Creating a file that already exists is refused with a message."""
    file_operation_tool.create_file("existing.txt", "content")
    assert (
        file_operation_tool.create_file("existing.txt", "new content")
        == "The file existing.txt already exists."
    )
|
||||
189
prometheus/tests/tools/test_graph_traversal.py
Normal file
189
prometheus/tests/tools/test_graph_traversal.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import pytest
|
||||
|
||||
from prometheus.graph.knowledge_graph import KnowledgeGraph
|
||||
from prometheus.tools.graph_traversal import GraphTraversalTool
|
||||
from tests.test_utils import test_project_paths
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
async def knowledge_graph_fixture():
    """KnowledgeGraph built directly over the checked-in test project."""
    graph = KnowledgeGraph(1000, 100, 10, 0)
    await graph.build_graph(test_project_paths.TEST_PROJECT_PATH)
    return graph
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def graph_traversal_tool(knowledge_graph_fixture):
    """GraphTraversalTool over the shared knowledge-graph fixture."""
    return GraphTraversalTool(knowledge_graph_fixture)
|
||||
|
||||
|
||||
@pytest.mark.slow
async def test_find_file_node_with_basename(graph_traversal_tool):
    """Looking up a file by basename returns exactly its FileNode."""
    basename = test_project_paths.PYTHON_FILE.name
    relative_path = str(
        test_project_paths.PYTHON_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix()
    )

    rows = graph_traversal_tool.find_file_node_with_basename(basename)[1]

    assert len(rows) == 1
    assert "FileNode" in rows[0]
    assert rows[0]["FileNode"].get("basename", "") == basename
    assert rows[0]["FileNode"].get("relative_path", "") == relative_path
|
||||
|
||||
|
||||
@pytest.mark.slow
async def test_find_file_node_with_relative_path(graph_traversal_tool):
    """Looking up a file by relative path returns exactly its FileNode."""
    relative_path = str(
        test_project_paths.MD_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix()
    )
    basename = test_project_paths.MD_FILE.name

    rows = graph_traversal_tool.find_file_node_with_relative_path(relative_path)[1]

    assert len(rows) == 1
    assert "FileNode" in rows[0]
    assert rows[0]["FileNode"].get("basename", "") == basename
    assert rows[0]["FileNode"].get("relative_path", "") == relative_path
|
||||
|
||||
|
||||
@pytest.mark.slow
async def test_find_ast_node_with_text_in_file_with_basename(graph_traversal_tool):
    """AST nodes containing a text snippet are found in a file by basename."""
    basename = test_project_paths.PYTHON_FILE.name

    rows = graph_traversal_tool.find_ast_node_with_text_in_file_with_basename(
        "Hello world!", basename
    )[1]

    assert len(rows) > 0
    for row in rows:
        assert "ASTNode" in row
        assert "Hello world!" in row["ASTNode"].get("text", "")
        assert "FileNode" in row
        assert row["FileNode"].get("basename", "") == basename
|
||||
|
||||
|
||||
@pytest.mark.slow
async def test_find_ast_node_with_text_in_file_with_relative_path(graph_traversal_tool):
    """AST nodes containing a text snippet are found in a file by relative path."""
    relative_path = str(
        test_project_paths.C_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix()
    )

    rows = graph_traversal_tool.find_ast_node_with_text_in_file_with_relative_path(
        "Hello world!", relative_path
    )[1]

    assert len(rows) > 0
    for row in rows:
        assert "ASTNode" in row
        assert "Hello world!" in row["ASTNode"].get("text", "")
        assert "FileNode" in row
        assert row["FileNode"].get("relative_path", "") == relative_path
|
||||
|
||||
|
||||
@pytest.mark.slow
async def test_find_ast_node_with_type_in_file_with_basename(graph_traversal_tool):
    """AST nodes of a given type are found in a file by basename."""
    basename = test_project_paths.C_FILE.name
    node_type = "function_definition"

    rows = graph_traversal_tool.find_ast_node_with_type_in_file_with_basename(node_type, basename)[1]

    assert len(rows) > 0
    for row in rows:
        assert "ASTNode" in row
        assert row["ASTNode"].get("type", "") == node_type
        assert "FileNode" in row
        assert row["FileNode"].get("basename", "") == basename
|
||||
|
||||
|
||||
@pytest.mark.slow
async def test_find_ast_node_with_type_in_file_with_relative_path(graph_traversal_tool):
    """AST nodes of a given type are found in a file by relative path."""
    relative_path = str(
        test_project_paths.JAVA_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix()
    )
    node_type = "string_literal"

    rows = graph_traversal_tool.find_ast_node_with_type_in_file_with_relative_path(
        node_type, relative_path
    )[1]

    assert len(rows) > 0
    for row in rows:
        assert "ASTNode" in row
        assert row["ASTNode"].get("type", "") == node_type
        assert "FileNode" in row
        assert row["FileNode"].get("relative_path", "") == relative_path
|
||||
|
||||
|
||||
@pytest.mark.slow
async def test_find_text_node_with_text(graph_traversal_tool):
    """Text nodes are matched by content and carry location metadata."""
    text = "Text under header C"

    rows = graph_traversal_tool.find_text_node_with_text(text)[1]

    assert len(rows) > 0
    for row in rows:
        assert "TextNode" in row
        node = row["TextNode"]
        assert text in node.get("text", "")
        assert "start_line" in node and node["start_line"] == 1
        assert "end_line" in node and node["end_line"] == 13
        assert "FileNode" in row
        assert row["FileNode"].get("relative_path", "") == "foo/test.md"
|
||||
|
||||
|
||||
@pytest.mark.slow
async def test_find_text_node_with_text_in_file(graph_traversal_tool):
    """Text nodes are matched by content within a specific file."""
    basename = test_project_paths.MD_FILE.name
    text = "Text under header B"

    rows = graph_traversal_tool.find_text_node_with_text_in_file(text, basename)[1]

    assert len(rows) > 0
    for row in rows:
        assert "TextNode" in row
        node = row["TextNode"]
        assert text in node.get("text", "")
        assert "start_line" in node and node["start_line"] == 1
        assert "end_line" in node and node["end_line"] == 13
        assert "FileNode" in row
        assert row["FileNode"].get("basename", "") == basename
|
||||
|
||||
|
||||
@pytest.mark.slow
async def test_get_next_text_node_with_node_id(graph_traversal_tool):
    """The successor of node 34 is the 'Text under header D' section."""
    # NOTE(review): node id 34 depends on graph build order — fragile if the
    # test project or the builder's numbering changes; confirm stability.
    rows = graph_traversal_tool.get_next_text_node_with_node_id(34)[1]

    assert len(rows) > 0
    for row in rows:
        assert "TextNode" in row
        node = row["TextNode"]
        assert "Text under header D" in node.get("text", "")
        assert "start_line" in node and node["start_line"] == 13
        assert "end_line" in node and node["end_line"] == 15
        assert "FileNode" in row
        assert row["FileNode"].get("relative_path", "") == "foo/test.md"
|
||||
|
||||
|
||||
@pytest.mark.slow
async def test_read_code_with_relative_path(graph_traversal_tool):
    """Reading a line span of a source file returns the selected text."""
    relative_path = str(
        test_project_paths.C_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix()
    )

    rows = graph_traversal_tool.read_code_with_relative_path(relative_path, 5, 6)[1]

    assert len(rows) > 0
    for row in rows:
        assert "SelectedLines" in row
        assert "return 0;" in row["SelectedLines"].get("text", "")
        assert "FileNode" in row
        assert row["FileNode"].get("relative_path", "") == relative_path
|
||||
551
prometheus/tests/tools/test_web_search.py
Normal file
551
prometheus/tests/tools/test_web_search.py
Normal file
@@ -0,0 +1,551 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from tavily import InvalidAPIKeyError, UsageLimitExceededError
|
||||
|
||||
from prometheus.exceptions.web_search_tool_exception import WebSearchToolException
|
||||
from prometheus.tools.web_search import WebSearchInput, WebSearchTool, format_results
|
||||
|
||||
|
||||
class TestFormatResults:
    """Test suite for format_results function."""

    def test_format_results_basic(self):
        """Test basic formatting of search results without answer."""
        payload = {
            "results": [
                {
                    "title": "How to fix Python import error",
                    "url": "https://stackoverflow.com/questions/12345",
                    "content": "This is how you fix import errors in Python...",
                },
                {
                    "title": "Python Documentation",
                    "url": "https://docs.python.org/3/",
                    "content": "Official Python documentation...",
                },
            ]
        }

        formatted = format_results(payload)

        # Every result's title, URL and content must appear in the output.
        for fragment in (
            "Detailed Results:",
            "How to fix Python import error",
            "https://stackoverflow.com/questions/12345",
            "This is how you fix import errors in Python...",
            "Python Documentation",
            "https://docs.python.org/3/",
        ):
            assert fragment in formatted

    def test_format_results_with_answer(self):
        """Test formatting with answer included."""
        payload = {
            "answer": "To fix import errors, check your PYTHONPATH and ensure modules are installed.",
            "results": [
                {
                    "title": "Python Import Guide",
                    "url": "https://docs.python.org/import",
                    "content": "Guide to Python imports...",
                },
            ],
        }

        formatted = format_results(payload)

        for fragment in (
            "Answer:",
            "To fix import errors",
            "Sources:",
            "Python Import Guide",
            "https://docs.python.org/import",
            "Detailed Results:",
        ):
            assert fragment in formatted

    def test_format_results_with_published_date(self):
        """Test formatting with published date."""
        payload = {
            "results": [
                {
                    "title": "Article Title",
                    "url": "https://example.com/article",
                    "content": "Article content...",
                    "published_date": "2024-01-15",
                },
            ]
        }

        assert "Published: 2024-01-15" in format_results(payload)

    def test_format_results_with_included_domains(self):
        """Test formatting with included domains filter."""
        payload = {
            "included_domains": ["stackoverflow.com", "github.com"],
            "results": [
                {
                    "title": "Test Result",
                    "url": "https://stackoverflow.com/test",
                    "content": "Test content",
                },
            ],
        }

        formatted = format_results(payload)

        assert "Search Filters:" in formatted
        assert "Including domains: stackoverflow.com, github.com" in formatted

    def test_format_results_with_excluded_domains(self):
        """Test formatting with excluded domains filter."""
        payload = {
            "excluded_domains": ["pinterest.com", "reddit.com"],
            "results": [
                {
                    "title": "Test Result",
                    "url": "https://example.com/test",
                    "content": "Test content",
                },
            ],
        }

        formatted = format_results(payload)

        assert "Search Filters:" in formatted
        assert "Excluding domains: pinterest.com, reddit.com" in formatted

    def test_format_results_with_both_domain_filters(self):
        """Test formatting with both included and excluded domains."""
        payload = {
            "included_domains": ["stackoverflow.com"],
            "excluded_domains": ["pinterest.com"],
            "results": [
                {
                    "title": "Test Result",
                    "url": "https://stackoverflow.com/test",
                    "content": "Test content",
                },
            ],
        }

        formatted = format_results(payload)

        assert "Including domains: stackoverflow.com" in formatted
        assert "Excluding domains: pinterest.com" in formatted

    def test_format_results_empty(self):
        """Test formatting with no results."""
        formatted = format_results({"results": []})

        assert "Detailed Results:" in formatted
        # No per-result sections should be emitted for an empty result list.
        assert "Title:" not in formatted
        assert "URL:" not in formatted

    def test_format_results_multiple_results(self):
        """Test formatting with multiple results."""
        payload = {
            "results": [
                {
                    "title": f"Result {i}",
                    "url": f"https://example.com/{i}",
                    "content": f"Content {i}",
                }
                for i in range(5)
            ]
        }

        formatted = format_results(payload)

        for i in range(5):
            assert f"Result {i}" in formatted
            assert f"https://example.com/{i}" in formatted
            assert f"Content {i}" in formatted

    def test_format_results_special_characters(self):
        """Test formatting with special characters in content."""
        payload = {
            "results": [
                {
                    "title": "Special chars: @#$%^&*()",
                    "url": "https://example.com/special",
                    "content": "Content with 中文 and emojis 🚀",
                },
            ]
        }

        formatted = format_results(payload)

        # Non-ASCII text and symbols must survive formatting untouched.
        assert "@#$%^&*()" in formatted
        assert "中文" in formatted
        assert "🚀" in formatted
|
||||
|
||||
|
||||
class TestWebSearchInput:
    """Test suite for WebSearchInput model."""

    def test_web_search_input_valid(self):
        """Test valid WebSearchInput creation."""
        parsed = WebSearchInput(query="Python import error")
        assert parsed.query == "Python import error"

    def test_web_search_input_empty_query(self):
        """Test WebSearchInput with empty query."""
        assert WebSearchInput(query="").query == ""

    def test_web_search_input_long_query(self):
        """Test WebSearchInput with long query."""
        lengthy = "A" * 1000
        assert WebSearchInput(query=lengthy).query == lengthy
|
||||
|
||||
|
||||
class TestWebSearchTool:
    """Test suite for WebSearchTool class."""

    @staticmethod
    def _prime(mock_tavily_client, mock_settings, api_key="test_api_key", response=None):
        """Point the patched settings at *api_key* and return the mock client.

        Setting ``return_value`` does not register a call on the patched
        constructor, so this helper is safe for every test below.
        """
        mock_settings.TAVILY_API_KEY = api_key
        client = MagicMock()
        mock_tavily_client.return_value = client
        if response is not None:
            client.search.return_value = response
        return client

    @patch("prometheus.tools.web_search.settings")
    @patch("prometheus.tools.web_search.TavilyClient")
    def test_init_with_api_key(self, mock_tavily_client, mock_settings):
        """Test WebSearchTool initialization with API key."""
        mock_settings.TAVILY_API_KEY = "test_api_key"

        tool = WebSearchTool()

        mock_tavily_client.assert_called_once_with(api_key="test_api_key")
        assert tool.tavily_client is not None

    @patch("prometheus.tools.web_search.settings")
    @patch("prometheus.tools.web_search.TavilyClient")
    def test_init_without_api_key(self, mock_tavily_client, mock_settings):
        """Test WebSearchTool initialization without API key."""
        mock_settings.TAVILY_API_KEY = None

        tool = WebSearchTool()

        mock_tavily_client.assert_not_called()
        assert tool.tavily_client is None

    @patch("prometheus.tools.web_search.settings")
    @patch("prometheus.tools.web_search.TavilyClient")
    def test_web_search_success(self, mock_tavily_client, mock_settings):
        """Test successful web search."""
        client = self._prime(
            mock_tavily_client,
            mock_settings,
            response={
                "answer": "Test answer",
                "results": [
                    {
                        "title": "Test Result",
                        "url": "https://stackoverflow.com/test",
                        "content": "Test content",
                    }
                ],
            },
        )

        result = WebSearchTool().web_search(query="Python import error")

        # The query must be forwarded with the tool's fixed search settings.
        client.search.assert_called_once()
        kwargs = client.search.call_args[1]
        assert kwargs["query"] == "Python import error"
        assert kwargs["max_results"] == 5
        assert kwargs["search_depth"] == "advanced"
        assert kwargs["include_answer"] is True
        assert "stackoverflow.com" in kwargs["include_domains"]
        assert "github.com" in kwargs["include_domains"]

        # The formatted output contains both the answer and the result.
        assert "Test answer" in result
        assert "Test Result" in result

    @patch("prometheus.tools.web_search.settings")
    @patch("prometheus.tools.web_search.TavilyClient")
    def test_web_search_with_custom_params(self, mock_tavily_client, mock_settings):
        """Test web search with custom parameters."""
        client = self._prime(
            mock_tavily_client,
            mock_settings,
            response={
                "results": [
                    {
                        "title": "Custom Result",
                        "url": "https://example.com/test",
                        "content": "Custom content",
                    }
                ]
            },
        )

        WebSearchTool().web_search(
            query="test query",
            max_results=10,
            include_domains=["custom-domain.com"],
            exclude_domains=["excluded.com"],
        )

        kwargs = client.search.call_args[1]
        assert kwargs["max_results"] == 10
        assert kwargs["include_domains"] == ["custom-domain.com"]
        assert kwargs["exclude_domains"] == ["excluded.com"]

    @patch("prometheus.tools.web_search.settings")
    @patch("prometheus.tools.web_search.TavilyClient")
    def test_web_search_default_domains(self, mock_tavily_client, mock_settings):
        """Test web search uses default domains when not specified."""
        client = self._prime(mock_tavily_client, mock_settings, response={"results": []})

        WebSearchTool().web_search(query="test")

        defaults = client.search.call_args[1]["include_domains"]
        for domain in (
            "stackoverflow.com",
            "github.com",
            "developer.mozilla.org",
            "learn.microsoft.com",
            "docs.python.org",
            "pydantic.dev",
            "pypi.org",
            "readthedocs.org",
        ):
            assert domain in defaults

    @patch("prometheus.tools.web_search.settings")
    @patch("prometheus.tools.web_search.TavilyClient")
    def test_web_search_invalid_api_key(self, mock_tavily_client, mock_settings):
        """Test web search with invalid API key."""
        client = self._prime(mock_tavily_client, mock_settings, api_key="invalid_key")
        client.search.side_effect = InvalidAPIKeyError("Invalid API key")

        with pytest.raises(WebSearchToolException) as exc_info:
            WebSearchTool().web_search(query="test")

        assert "Invalid Tavily API key" in str(exc_info.value)

    @patch("prometheus.tools.web_search.settings")
    @patch("prometheus.tools.web_search.TavilyClient")
    def test_web_search_usage_limit_exceeded(self, mock_tavily_client, mock_settings):
        """Test web search when usage limit is exceeded."""
        client = self._prime(mock_tavily_client, mock_settings)
        client.search.side_effect = UsageLimitExceededError("Limit exceeded")

        with pytest.raises(WebSearchToolException) as exc_info:
            WebSearchTool().web_search(query="test")

        assert "Usage limit exceeded" in str(exc_info.value)

    @patch("prometheus.tools.web_search.settings")
    @patch("prometheus.tools.web_search.TavilyClient")
    def test_web_search_generic_exception(self, mock_tavily_client, mock_settings):
        """Test web search with generic exception."""
        client = self._prime(mock_tavily_client, mock_settings)
        client.search.side_effect = Exception("Network error")

        with pytest.raises(WebSearchToolException) as exc_info:
            WebSearchTool().web_search(query="test")

        assert "An error occurred: Network error" in str(exc_info.value)

    @patch("prometheus.tools.web_search.settings")
    @patch("prometheus.tools.web_search.TavilyClient")
    def test_web_search_empty_query(self, mock_tavily_client, mock_settings):
        """Test web search with empty query."""
        client = self._prime(mock_tavily_client, mock_settings, response={"results": []})

        result = WebSearchTool().web_search(query="")

        # An empty query is still forwarded to the search client.
        client.search.assert_called_once()
        assert "Detailed Results:" in result

    @patch("prometheus.tools.web_search.settings")
    @patch("prometheus.tools.web_search.TavilyClient")
    def test_web_search_with_none_exclude_domains(self, mock_tavily_client, mock_settings):
        """Test web search with None exclude_domains."""
        client = self._prime(mock_tavily_client, mock_settings, response={"results": []})

        WebSearchTool().web_search(query="test", exclude_domains=None)

        # None must be normalized to an empty list before the API call.
        assert client.search.call_args[1]["exclude_domains"] == []

    @patch("prometheus.tools.web_search.settings")
    @patch("prometheus.tools.web_search.TavilyClient")
    def test_web_search_complex_query(self, mock_tavily_client, mock_settings):
        """Test web search with complex query containing special characters."""
        client = self._prime(
            mock_tavily_client,
            mock_settings,
            response={
                "results": [
                    {
                        "title": "Result",
                        "url": "https://example.com",
                        "content": "Content",
                    }
                ]
            },
        )

        tricky = 'Python "ModuleNotFoundError" 中文 @#$%'
        WebSearchTool().web_search(query=tricky)

        assert client.search.call_args[1]["query"] == tricky
|
||||
|
||||
|
||||
class TestToolSpec:
    """Tests for the declarative ToolSpec attached to WebSearchTool."""

    def test_web_search_tool_spec_exists(self):
        """The class exposes a web_search_spec with a description and input schema."""
        assert hasattr(WebSearchTool, "web_search_spec")

        spec = WebSearchTool.web_search_spec
        assert spec.description is not None
        assert len(spec.description) > 0
        assert spec.input_schema == WebSearchInput

    def test_web_search_tool_spec_description(self):
        """The spec description advertises the tool's key use cases."""
        description = WebSearchTool.web_search_spec.description

        # Each keyword corresponds to a documented use case of the tool.
        assert "bug analysis" in description
        assert "error messages" in description
        assert "documentation" in description
        assert "library" in description
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Boundary-condition tests for WebSearchTool and format_results."""

    @patch("prometheus.tools.web_search.settings")
    @patch("prometheus.tools.web_search.TavilyClient")
    def test_web_search_max_results_boundary(self, mock_tavily_client, mock_settings):
        """max_results is forwarded verbatim at both extremes (0 and a large value)."""
        mock_settings.TAVILY_API_KEY = "test_api_key"

        client_stub = MagicMock()
        client_stub.search.return_value = {"results": []}
        mock_tavily_client.return_value = client_stub

        tool = WebSearchTool()

        # Lower boundary.
        tool.web_search(query="test", max_results=0)
        assert client_stub.search.call_args[1]["max_results"] == 0

        # A deliberately large value.
        tool.web_search(query="test", max_results=100)
        assert client_stub.search.call_args[1]["max_results"] == 100

    def test_format_results_missing_optional_fields(self):
        """Optional response fields may be absent without breaking formatting."""
        response = {
            "results": [
                {
                    "title": "Title",
                    "url": "https://example.com",
                    "content": "Content",
                    # published_date intentionally omitted
                }
            ]
            # answer / included_domains / excluded_domains intentionally omitted
        }

        formatted = format_results(response)

        # Sections for absent fields must not appear in the output.
        assert "Published:" not in formatted
        assert "Answer:" not in formatted
        assert "Search Filters:" not in formatted

    @patch("prometheus.tools.web_search.settings")
    @patch("prometheus.tools.web_search.TavilyClient")
    def test_web_search_with_empty_domains_lists(self, mock_tavily_client, mock_settings):
        """Explicit empty include/exclude domain lists are passed through as-is."""
        mock_settings.TAVILY_API_KEY = "test_api_key"

        client_stub = MagicMock()
        client_stub.search.return_value = {"results": []}
        mock_tavily_client.return_value = client_stub

        WebSearchTool().web_search(query="test", include_domains=[], exclude_domains=[])

        forwarded = client_stub.search.call_args[1]
        assert forwarded["include_domains"] == []
        assert forwarded["exclude_domains"] == []

    def test_format_results_with_long_content(self):
        """A very long (10k char) content field is formatted without errors."""
        huge_content = "A" * 10000
        response = {
            "results": [
                {
                    "title": "Long Content Result",
                    "url": "https://example.com",
                    "content": huge_content,
                }
            ]
        }

        formatted = format_results(response)

        # The full content and the title both survive formatting.
        assert huge_content in formatted
        assert "Long Content Result" in formatted
|
||||
0
prometheus/tests/utils/__init__.py
Normal file
0
prometheus/tests/utils/__init__.py
Normal file
35
prometheus/tests/utils/test_file_utils.py
Normal file
35
prometheus/tests/utils/test_file_utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
|
||||
from prometheus.exceptions.file_operation_exception import FileOperationException
|
||||
from prometheus.utils.file_utils import (
|
||||
read_file_with_line_numbers,
|
||||
)
|
||||
from tests.test_utils.test_project_paths import TEST_PROJECT_PATH
|
||||
|
||||
|
||||
def test_read_file_with_line_numbers():
    """Verify line-range extraction and range validation in read_file_with_line_numbers."""
    # Happy path: lines 1-15 come back prefixed with their line numbers.
    numbered = read_file_with_line_numbers("foo/test.md", TEST_PROJECT_PATH, 1, 15)
    expected = """1. # A
2.
3. Text under header A.
4.
5. ## B
6.
7. Text under header B.
8.
9. ## C
10.
11. Text under header C.
12.
13. ### D
14.
15. Text under header D."""
    assert numbered == expected

    # An end line before the start line must be rejected.
    with pytest.raises(FileOperationException) as exc_info:
        read_file_with_line_numbers("foo/test.md", TEST_PROJECT_PATH, 4, 2)

    assert str(exc_info.value) == "The end line number must be greater than the start line number."
|
||||
50
prometheus/tests/utils/test_issue_util.py
Normal file
50
prometheus/tests/utils/test_issue_util.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from prometheus.utils.issue_util import (
|
||||
format_issue_comments,
|
||||
format_issue_info,
|
||||
format_test_commands,
|
||||
)
|
||||
|
||||
|
||||
def test_format_issue_comments():
    """Comments are rendered as 'user: text' pairs separated by blank lines."""
    raw_comments = [
        {"username": "alice", "comment": "This looks good!"},
        {"username": "bob", "comment": "Can we add tests?"},
    ]

    rendered = format_issue_comments(raw_comments)

    assert rendered == "alice: This looks good!\n\nbob: Can we add tests?"
|
||||
|
||||
|
||||
def test_format_issue_info():
    """Title, body and comments are combined into one labelled summary block."""
    issue_title = "Bug in login flow"
    issue_body = "Users can't login on mobile devices"
    issue_comments = [
        {"username": "alice", "comment": "I can reproduce this"},
        {"username": "bob", "comment": "Working on a fix"},
    ]

    rendered = format_issue_info(issue_title, issue_body, issue_comments)

    expected = """\
Issue title:
Bug in login flow

Issue description:
Users can't login on mobile devices

Issue comments:
alice: I can reproduce this

bob: Working on a fix"""
    assert rendered == expected
|
||||
|
||||
|
||||
def test_format_test_commands():
    """Each command is shell-prompt-prefixed and joined with newlines."""
    shell_commands = ["pytest test_login.py", "pytest test_auth.py -v"]

    rendered = format_test_commands(shell_commands)

    assert rendered == "$ pytest test_login.py\n$ pytest test_auth.py -v"
|
||||
260
prometheus/tests/utils/test_knowledge_graph_utils.py
Normal file
260
prometheus/tests/utils/test_knowledge_graph_utils.py
Normal file
@@ -0,0 +1,260 @@
|
||||
from prometheus.models.context import Context
|
||||
from prometheus.utils.knowledge_graph_utils import (
|
||||
knowledge_graph_data_for_context_generator,
|
||||
sort_contexts,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_data():
    """None and an empty list both yield an empty context list."""
    assert knowledge_graph_data_for_context_generator(None) == []
    assert knowledge_graph_data_for_context_generator([]) == []
|
||||
|
||||
|
||||
def test_skip_empty_content():
    """Rows whose text is empty or whitespace-only are dropped."""

    # Local helper: one knowledge-graph row for test.py.
    def row(node_type, text, start, end):
        return {
            "FileNode": {"relative_path": "test.py"},
            node_type: {"text": text, "start_line": start, "end_line": end},
        }

    rows = [
        row("ASTNode", "", 1, 1),
        row("TextNode", " \n\t ", 2, 2),
        row("ASTNode", "valid content", 3, 3),
    ]

    contexts = knowledge_graph_data_for_context_generator(rows)

    # Only the single non-blank row survives.
    assert len(contexts) == 1
    assert contexts[0].content == "valid content"
|
||||
|
||||
|
||||
def test_deduplication_identical_content():
    """Two rows with identical text in the same file collapse into one context."""

    def row(node_type, text, start, end):
        return {
            "FileNode": {"relative_path": "test.py"},
            node_type: {"text": text, "start_line": start, "end_line": end},
        }

    rows = [
        row("ASTNode", "def hello():", 1, 1),
        row("TextNode", "def hello():", 1, 1),
    ]

    contexts = knowledge_graph_data_for_context_generator(rows)

    assert len(contexts) == 1
    assert contexts[0].content == "def hello():"
|
||||
|
||||
|
||||
def test_deduplication_content_containment():
    """When one row's text contains another's, only the larger row is kept."""

    def row(node_type, text, start, end):
        return {
            "FileNode": {"relative_path": "test.py"},
            node_type: {"text": text, "start_line": start, "end_line": end},
        }

    rows = [
        row("ASTNode", "def hello():\n    print('world')", 1, 2),
        row("TextNode", "print('world')", 2, 2),
    ]

    contexts = knowledge_graph_data_for_context_generator(rows)

    # The contained snippet is absorbed by the enclosing one.
    assert len(contexts) == 1
    assert contexts[0].content == "def hello():\n    print('world')"
|
||||
|
||||
|
||||
def test_deduplication_line_containment():
    """A row whose line range sits inside another row's range is dropped."""

    def row(node_type, text, start, end):
        return {
            "FileNode": {"relative_path": "test.py"},
            node_type: {"text": text, "start_line": start, "end_line": end},
        }

    rows = [
        row("ASTNode", "function body", 1, 5),
        row("TextNode", "inner content", 2, 3),  # lines 2-3 fall inside 1-5
    ]

    contexts = knowledge_graph_data_for_context_generator(rows)

    assert len(contexts) == 1
    assert contexts[0].content == "function body"
    assert contexts[0].start_line_number == 1
    assert contexts[0].end_line_number == 5
|
||||
|
||||
|
||||
def test_different_files_no_deduplication():
    """Identical text in two different files must yield two separate contexts."""

    def row(path, text, start, end):
        return {
            "FileNode": {"relative_path": path},
            "ASTNode": {"text": text, "start_line": start, "end_line": end},
        }

    rows = [
        row("file1.py", "def hello():", 1, 1),
        row("file2.py", "def hello():", 1, 1),
    ]

    contexts = knowledge_graph_data_for_context_generator(rows)

    # Deduplication is scoped per file, so both rows survive.
    assert len(contexts) == 2
    assert contexts[0].relative_path == "file1.py"
    assert contexts[1].relative_path == "file2.py"
|
||||
|
||||
|
||||
def test_content_stripping():
    """Leading and trailing whitespace is stripped from row text."""
    rows = [
        {
            "FileNode": {"relative_path": "test.py"},
            "ASTNode": {"text": "  \n  def hello():  \n  ", "start_line": 1, "end_line": 1},
        }
    ]

    contexts = knowledge_graph_data_for_context_generator(rows)

    assert len(contexts) == 1
    assert contexts[0].content == "def hello():"
|
||||
|
||||
|
||||
def test_complex_deduplication_scenario():
    """Multiple overlapping rows collapse correctly, leaving distinct spans intact."""

    def row(node_type, text, start, end):
        return {
            "FileNode": {"relative_path": "test.py"},
            node_type: {"text": text, "start_line": start, "end_line": end},
        }

    class_text = (
        "class MyClass:\n"
        "    def method1(self):\n"
        "        pass\n"
        "    def method2(self):\n"
        "        pass"
    )

    rows = [
        # Large context covering the whole class.
        row("ASTNode", class_text, 1, 5),
        # Smaller context fully contained within the class span.
        row("TextNode", "def method1(self):\n        pass", 2, 3),
        # Separate, non-overlapping context later in the same file.
        row("SelectedLines", "# Comment at end", 10, 10),
    ]

    contexts = knowledge_graph_data_for_context_generator(rows)

    # The contained method is absorbed; the trailing comment stands alone.
    assert len(contexts) == 2
    assert "class MyClass:" in contexts[0].content
    assert contexts[1].content == "# Comment at end"
|
||||
|
||||
|
||||
def test_sort_contexts_empty_list():
    """Sorting an empty context list returns an empty list."""
    assert sort_contexts([]) == []
|
||||
|
||||
|
||||
def test_sort_contexts_by_relative_path():
    """Contexts are ordered alphabetically by relative path."""

    def ctx(path, content):
        return Context(
            relative_path=path, content=content, start_line_number=1, end_line_number=5
        )

    unsorted = [
        ctx("src/z.py", "content z"),
        ctx("src/a.py", "content a"),
        ctx("src/m.py", "content m"),
    ]

    ordered = sort_contexts(unsorted)

    assert [c.relative_path for c in ordered] == ["src/a.py", "src/m.py", "src/z.py"]
|
||||
|
||||
|
||||
def test_sort_contexts_by_line_numbers():
    """Within one file, contexts are ordered by their starting line."""

    def ctx(content, start, end):
        return Context(
            relative_path="test.py", content=content, start_line_number=start, end_line_number=end
        )

    unsorted = [
        ctx("content 3", 20, 25),
        ctx("content 1", 1, 5),
        ctx("content 2", 10, 15),
    ]

    ordered = sort_contexts(unsorted)

    assert [c.start_line_number for c in ordered] == [1, 10, 20]
|
||||
|
||||
|
||||
def test_sort_contexts_none_line_numbers():
    """Contexts without line numbers sort after those that have them."""
    unsorted = [
        Context(
            relative_path="test.py",
            content="content with lines",
            start_line_number=10,
            end_line_number=15,
        ),
        Context(
            relative_path="test.py",
            content="content no lines",
            start_line_number=None,
            end_line_number=None,
        ),
        Context(
            relative_path="test.py",
            content="content first",
            start_line_number=1,
            end_line_number=5,
        ),
    ]

    ordered = sort_contexts(unsorted)

    # Numbered contexts come first in line order; the None entry trails.
    assert ordered[0].start_line_number == 1
    assert ordered[1].start_line_number == 10
    assert ordered[2].start_line_number is None
|
||||
|
||||
|
||||
def test_sort_contexts_mixed_files_and_lines():
    """Sorting is by path first, then by starting line within each file."""

    def ctx(path, content, start, end):
        return Context(
            relative_path=path, content=content, start_line_number=start, end_line_number=end
        )

    unsorted = [
        ctx("b.py", "b content 2", 20, 25),
        ctx("a.py", "a content 2", 10, 15),
        ctx("b.py", "b content 1", 5, 10),
        ctx("a.py", "a content 1", 1, 5),
    ]

    ordered = sort_contexts(unsorted)

    keys = [(c.relative_path, c.start_line_number) for c in ordered]
    assert keys == [("a.py", 1), ("a.py", 10), ("b.py", 5), ("b.py", 20)]
|
||||
|
||||
|
||||
def test_sort_contexts_end_line_number_tiebreaker():
    """Equal start lines fall back to end_line_number for ordering."""

    def ctx(content, end):
        return Context(
            relative_path="test.py", content=content, start_line_number=10, end_line_number=end
        )

    unsorted = [
        ctx("content 3", 30),
        ctx("content 1", 15),
        ctx("content 2", 20),
    ]

    ordered = sort_contexts(unsorted)

    assert [c.end_line_number for c in ordered] == [15, 20, 30]
|
||||
108
prometheus/tests/utils/test_lang_graph_util.py
Normal file
108
prometheus/tests/utils/test_lang_graph_util.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
from prometheus.utils.lang_graph_util import (
|
||||
check_remaining_steps,
|
||||
extract_ai_responses,
|
||||
extract_human_queries,
|
||||
extract_last_tool_messages,
|
||||
format_agent_tool_message_history,
|
||||
get_last_message_content,
|
||||
)
|
||||
|
||||
|
||||
def test_check_remaining_steps():
    """Routing proceeds when enough steps remain, otherwise short-circuits."""

    def mock_router(state):
        return "next_step"

    # Above the minimum, the wrapped router's answer is used.
    assert check_remaining_steps({"remaining_steps": 5}, mock_router, 3) == "next_step"
    # Below the minimum, the guard value is returned instead.
    assert check_remaining_steps({"remaining_steps": 2}, mock_router, 3) == "low_remaining_steps"
|
||||
|
||||
|
||||
def test_extract_ai_responses():
    """Only AIMessage contents are collected from a mixed history."""
    history = [
        HumanMessage(content="Human 1"),
        AIMessage(content="AI 1"),
        HumanMessage(content="Human 2"),
        AIMessage(content="AI 2"),
    ]

    ai_contents = extract_ai_responses(history)

    assert len(ai_contents) == 2
    assert "AI 1" in ai_contents
    assert "AI 2" in ai_contents
|
||||
|
||||
|
||||
def test_extract_human_queries():
    """System and AI messages are ignored; only human contents come back."""
    history = [
        SystemMessage(content="System"),
        HumanMessage(content="Human 1"),
        AIMessage(content="AI 1"),
        HumanMessage(content="Human 2"),
    ]

    human_contents = extract_human_queries(history)

    assert len(human_contents) == 2
    assert "Human 1" in human_contents
    assert "Human 2" in human_contents
|
||||
|
||||
|
||||
def test_extract_last_tool_messages():
    """Only the trailing run of ToolMessages is extracted, not earlier ones."""
    history = [
        HumanMessage(content="Human 1"),
        ToolMessage(content="Tool 1", tool_call_id="call_1"),
        AIMessage(content="AI 1"),
        HumanMessage(content="Human 2"),
        ToolMessage(content="Tool 2", tool_call_id="call_2"),
        ToolMessage(content="Tool 3", tool_call_id="call_3"),
    ]

    trailing_tools = extract_last_tool_messages(history)

    # Only the final consecutive pair (Tool 2, Tool 3) is returned.
    assert len(trailing_tools) == 2
    assert all(isinstance(msg, ToolMessage) for msg in trailing_tools)
    assert trailing_tools[-1].content == "Tool 3"
|
||||
|
||||
|
||||
def test_get_last_message_content():
    """The content of the final message is returned regardless of its type."""
    history = [
        HumanMessage(content="Human"),
        AIMessage(content="AI"),
        ToolMessage(content="Last message", tool_call_id="call_1"),
    ]

    assert get_last_message_content(history) == "Last message"
|
||||
|
||||
|
||||
def test_format_agent_tool_message_history():
    """AI thoughts, tool calls and tool outputs render as a labelled transcript."""
    history = [
        AIMessage(content="Let me analyze this"),
        AIMessage(
            content="I'll use a tool for this",
            additional_kwargs={"tool_calls": [{"function": "analyze_data"}]},
        ),
        ToolMessage(content="Analysis results: Success", tool_call_id="call_1"),
    ]

    transcript = format_agent_tool_message_history(history)

    expected = (
        "Assistant internal thought: Let me analyze this\n\n"
        "Assistant internal thought: I'll use a tool for this\n\n"
        "Assistant executed tool: analyze_data\n\n"
        "Tool output: Analysis results: Success"
    )
    assert transcript == expected
|
||||
391
prometheus/tests/utils/test_memory_utils.py
Normal file
391
prometheus/tests/utils/test_memory_utils.py
Normal file
@@ -0,0 +1,391 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from prometheus.models.context import Context
|
||||
from prometheus.models.query import Query
|
||||
from prometheus.utils.memory_utils import AthenaMemoryClient
|
||||
|
||||
|
||||
@pytest.fixture
def athena_client():
    """Provide an AthenaMemoryClient pointed at a dummy service URL."""
    return AthenaMemoryClient(base_url="http://test-athena-service.com")
|
||||
|
||||
|
||||
@pytest.fixture
def sample_contexts():
    """Two small Context records covering different files and line ranges."""
    specs = [
        ("src/main.py", "def main():\n    print('Hello')", 1, 2),
        ("src/utils.py", "def helper():\n    return True", 10, 11),
    ]
    return [
        Context(
            relative_path=path,
            content=content,
            start_line_number=start,
            end_line_number=end,
        )
        for path, content, start, end in specs
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def sample_query():
    """A representative Query with all three fields populated."""
    return Query(
        essential_query="How to implement authentication?",
        extra_requirements="Using JWT tokens",
        purpose="Feature implementation",
    )
|
||||
|
||||
|
||||
class TestAthenaMemoryClient:
    """Unit tests for AthenaMemoryClient's HTTP wrapper methods."""

    def test_init(self):
        """The constructor normalizes the base URL and applies the default timeout."""
        client = AthenaMemoryClient(base_url="http://example.com/")
        assert client.base_url == "http://example.com"
        assert client.timeout == 30

    def test_init_strips_trailing_slash(self):
        """Multiple trailing slashes are all stripped from base_url."""
        client = AthenaMemoryClient(base_url="http://example.com///")
        assert client.base_url == "http://example.com"

    @patch("prometheus.utils.memory_utils.requests.post")
    def test_store_memory_success(self, mock_post, athena_client, sample_contexts):
        """store_memory POSTs the expected payload and returns the JSON response."""
        fake_response = MagicMock()
        fake_response.json.return_value = {"status": "success", "memory_id": "123"}
        mock_post.return_value = fake_response

        stored = athena_client.store_memory(
            repository_id=42,
            essential_query="How to use async?",
            extra_requirements="Python 3.11",
            purpose="Learning",
            contexts=sample_contexts,
        )

        assert stored == {"status": "success", "memory_id": "123"}
        mock_post.assert_called_once()

        # One POST to the store endpoint with the configured timeout.
        endpoint, kwargs = mock_post.call_args[0][0], mock_post.call_args[1]
        assert endpoint == "http://test-athena-service.com/semantic-memory/store/"
        assert kwargs["timeout"] == 30

        # The payload carries the repository id, query fields and contexts.
        body = kwargs["json"]
        assert body["repository_id"] == 42
        assert body["query"]["essential_query"] == "How to use async?"
        assert body["query"]["extra_requirements"] == "Python 3.11"
        assert body["query"]["purpose"] == "Learning"
        assert len(body["contexts"]) == 2
        assert body["contexts"][0]["relative_path"] == "src/main.py"

    @patch("prometheus.utils.memory_utils.requests.post")
    def test_store_memory_request_exception(self, mock_post, athena_client, sample_contexts):
        """A transport failure propagates as requests.RequestException."""
        mock_post.side_effect = requests.RequestException("Connection error")

        with pytest.raises(requests.RequestException) as exc_info:
            athena_client.store_memory(
                repository_id=42,
                essential_query="test query",
                extra_requirements="",
                purpose="",
                contexts=sample_contexts,
            )

        assert "Connection error" in str(exc_info.value)

    @patch("prometheus.utils.memory_utils.requests.post")
    def test_store_memory_http_error(self, mock_post, athena_client, sample_contexts):
        """An HTTP error status propagates as requests.HTTPError."""
        fake_response = MagicMock()
        fake_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
        mock_post.return_value = fake_response

        with pytest.raises(requests.HTTPError):
            athena_client.store_memory(
                repository_id=42,
                essential_query="test query",
                extra_requirements="",
                purpose="",
                contexts=sample_contexts,
            )

    @patch("prometheus.utils.memory_utils.requests.get")
    def test_retrieve_memory_success(self, mock_get, athena_client, sample_query):
        """retrieve_memory GETs the per-repository endpoint and unwraps 'data'."""
        fake_response = MagicMock()
        fake_response.json.return_value = {
            "data": [
                {"content": "Authentication code snippet 1"},
                {"content": "Authentication code snippet 2"},
            ]
        }
        mock_get.return_value = fake_response

        retrieved = athena_client.retrieve_memory(repository_id=42, query=sample_query)

        assert len(retrieved) == 2
        assert retrieved[0]["content"] == "Authentication code snippet 1"
        mock_get.assert_called_once()

        # One GET to the retrieve endpoint with the configured timeout.
        endpoint, kwargs = mock_get.call_args[0][0], mock_get.call_args[1]
        assert endpoint == "http://test-athena-service.com/semantic-memory/retrieve/42/"
        assert kwargs["timeout"] == 30

        # The query fields travel as URL parameters.
        sent_params = kwargs["params"]
        assert sent_params["essential_query"] == "How to implement authentication?"
        assert sent_params["extra_requirements"] == "Using JWT tokens"
        assert sent_params["purpose"] == "Feature implementation"

    @patch("prometheus.utils.memory_utils.requests.get")
    def test_retrieve_memory_with_optional_fields(self, mock_get, athena_client):
        """Empty optional query fields are forwarded as empty strings."""
        fake_response = MagicMock()
        fake_response.json.return_value = {"data": []}
        mock_get.return_value = fake_response

        minimal_query = Query(essential_query="test query", extra_requirements="", purpose="")

        retrieved = athena_client.retrieve_memory(repository_id=42, query=minimal_query)

        sent_params = mock_get.call_args[1]["params"]
        assert sent_params["extra_requirements"] == ""
        assert sent_params["purpose"] == ""
        assert retrieved == []

    @patch("prometheus.utils.memory_utils.requests.get")
    def test_retrieve_memory_request_exception(self, mock_get, athena_client, sample_query):
        """A transport failure during retrieval propagates unchanged."""
        mock_get.side_effect = requests.RequestException("Timeout error")

        with pytest.raises(requests.RequestException) as exc_info:
            athena_client.retrieve_memory(repository_id=42, query=sample_query)

        assert "Timeout error" in str(exc_info.value)

    @patch("prometheus.utils.memory_utils.requests.get")
    def test_retrieve_memory_http_error(self, mock_get, athena_client, sample_query):
        """An HTTP error status during retrieval propagates as HTTPError."""
        fake_response = MagicMock()
        fake_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
        mock_get.return_value = fake_response

        with pytest.raises(requests.HTTPError):
            athena_client.retrieve_memory(repository_id=42, query=sample_query)

    @patch("prometheus.utils.memory_utils.requests.delete")
    def test_delete_repository_memory_success(self, mock_delete, athena_client):
        """delete_repository_memory DELETEs the repository endpoint and returns JSON."""
        fake_response = MagicMock()
        fake_response.json.return_value = {"status": "success", "deleted_count": 15}
        mock_delete.return_value = fake_response

        deleted = athena_client.delete_repository_memory(repository_id=42)

        assert deleted == {"status": "success", "deleted_count": 15}
        mock_delete.assert_called_once()

        endpoint, kwargs = mock_delete.call_args[0][0], mock_delete.call_args[1]
        assert endpoint == "http://test-athena-service.com/semantic-memory/42/"
        assert kwargs["timeout"] == 30

    @patch("prometheus.utils.memory_utils.requests.delete")
    def test_delete_repository_memory_request_exception(self, mock_delete, athena_client):
        """A transport failure during deletion propagates unchanged."""
        mock_delete.side_effect = requests.RequestException("Network error")

        with pytest.raises(requests.RequestException) as exc_info:
            athena_client.delete_repository_memory(repository_id=42)

        assert "Network error" in str(exc_info.value)

    @patch("prometheus.utils.memory_utils.requests.delete")
    def test_delete_repository_memory_http_error(self, mock_delete, athena_client):
        """An HTTP error status during deletion propagates as HTTPError."""
        fake_response = MagicMock()
        fake_response.raise_for_status.side_effect = requests.HTTPError("403 Forbidden")
        mock_delete.return_value = fake_response

        with pytest.raises(requests.HTTPError):
            athena_client.delete_repository_memory(repository_id=42)
|
||||
|
||||
|
||||
class TestModuleLevelFunctions:
    """Tests that the module-level helpers delegate to the shared athena_client."""

    @patch("prometheus.utils.memory_utils.athena_client.store_memory")
    def test_store_memory_function(self, mock_store, sample_contexts):
        """store_memory forwards every argument to the client unchanged."""
        from prometheus.utils.memory_utils import store_memory

        mock_store.return_value = {"status": "success"}

        outcome = store_memory(
            repository_id=123,
            essential_query="test query",
            extra_requirements="requirements",
            purpose="testing",
            contexts=sample_contexts,
        )

        assert outcome == {"status": "success"}
        mock_store.assert_called_once_with(
            repository_id=123,
            essential_query="test query",
            extra_requirements="requirements",
            purpose="testing",
            contexts=sample_contexts,
        )

    @patch("prometheus.utils.memory_utils.athena_client.retrieve_memory")
    def test_retrieve_memory_function(self, mock_retrieve, sample_query):
        """retrieve_memory forwards its arguments and returns the client's result."""
        from prometheus.utils.memory_utils import retrieve_memory

        mock_retrieve.return_value = [{"content": "test"}]

        outcome = retrieve_memory(repository_id=123, query=sample_query)

        assert outcome == [{"content": "test"}]
        mock_retrieve.assert_called_once_with(repository_id=123, query=sample_query)

    @patch("prometheus.utils.memory_utils.athena_client.delete_repository_memory")
    def test_delete_repository_memory_function(self, mock_delete):
        """delete_repository_memory forwards the repository id to the client."""
        from prometheus.utils.memory_utils import delete_repository_memory

        mock_delete.return_value = {"status": "deleted"}

        outcome = delete_repository_memory(repository_id=123)

        assert outcome == {"status": "deleted"}
        mock_delete.assert_called_once_with(repository_id=123)
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Boundary-condition tests for the Athena memory client."""

    @patch("prometheus.utils.memory_utils.requests.post")
    def test_store_memory_with_empty_contexts_list(self, mock_post, athena_client):
        """An empty contexts list must be serialized as an empty JSON array."""
        response = MagicMock()
        response.json.return_value = {"status": "success"}
        mock_post.return_value = response

        outcome = athena_client.store_memory(
            repository_id=1,
            essential_query="test",
            extra_requirements="",
            purpose="",
            contexts=[],
        )

        # The posted payload carries the empty list straight through.
        sent = mock_post.call_args[1]["json"]
        assert sent["contexts"] == []
        assert outcome == {"status": "success"}

    @patch("prometheus.utils.memory_utils.requests.get")
    def test_retrieve_memory_empty_results(self, mock_get, athena_client):
        """A query with no matches yields an empty list rather than an error."""
        response = MagicMock()
        response.json.return_value = {"data": []}
        mock_get.return_value = response

        no_hit_query = Query(
            essential_query="nonexistent query",
            extra_requirements="",
            purpose="",
        )

        assert athena_client.retrieve_memory(repository_id=999, query=no_hit_query) == []

    def test_context_serialization(self, sample_contexts):
        """Context.model_dump() must expose all four fields faithfully."""
        dumped = sample_contexts[0].model_dump()

        assert dumped["relative_path"] == "src/main.py"
        assert dumped["content"] == "def main():\n print('Hello')"
        assert dumped["start_line_number"] == 1
        assert dumped["end_line_number"] == 2

    @patch("prometheus.utils.memory_utils.requests.post")
    def test_store_memory_with_special_characters(self, mock_post, athena_client):
        """ASCII symbols, CJK text, and emoji must survive the round trip."""
        response = MagicMock()
        response.json.return_value = {"status": "success"}
        mock_post.return_value = response

        special_context = Context(
            relative_path="test/file.py",
            content="def test():\n # Special chars: @#$%^&*(){}[]|\\<>?~`",
            start_line_number=1,
            end_line_number=2,
        )

        outcome = athena_client.store_memory(
            repository_id=1,
            essential_query="test with special chars: @#$%",
            extra_requirements="requirements with 中文",
            purpose="purpose with émojis 🚀",
            contexts=[special_context],
        )

        assert outcome == {"status": "success"}
        # Non-ASCII content must be preserved verbatim in the JSON payload.
        sent = mock_post.call_args[1]["json"]
        assert "@#$%" in sent["query"]["essential_query"]
        assert "中文" in sent["query"]["extra_requirements"]
        assert "🚀" in sent["query"]["purpose"]
86
prometheus/tests/utils/test_neo4j_util.py
Normal file
86
prometheus/tests/utils/test_neo4j_util.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.utils.knowledge_graph_utils import (
|
||||
EMPTY_DATA_MESSAGE,
|
||||
format_knowledge_graph_data,
|
||||
)
|
||||
|
||||
|
||||
class MockResult:
    """Stand-in for a Neo4j query result exposing a ``data()`` accessor."""

    def __init__(self, data_list):
        # Keep the rows exactly as supplied; data() returns them untouched.
        self._data = data_list

    def data(self):
        """Return the stored list of record dictionaries."""
        return self._data
class MockSession:
    """Context-manager stand-in for a Neo4j session.

    ``execute_read`` is a MagicMock so tests can configure return values
    and assert on how the code under test invokes it.
    """

    def __init__(self):
        self.execute_read = MagicMock()

    def __enter__(self):
        # The session itself is the value bound by the ``with`` statement.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release for a mock; exceptions propagate (implicit None).
        pass
class MockDriver:
    """Stand-in for a Neo4j driver that always hands out one fixed session."""

    def __init__(self, session):
        self._session = session

    def session(self):
        """Return the session object supplied at construction time."""
        return self._session
@pytest.fixture
def mock_neo4j_driver():
    """Provide a (driver, session) pair wired together from the mock classes."""
    mock_session = MockSession()
    return MockDriver(mock_session), mock_session
def test_format_neo4j_data_single_row():
    """A single record renders as one numbered result with sorted keys."""
    rows = [{"name": "John", "age": 30}]

    assert format_knowledge_graph_data(rows) == "Result 1:\nage: 30\nname: John"
def test_format_neo4j_result_multiple_rows():
    """Records are numbered sequentially and separated by blank lines."""
    rows = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}]

    rendered = format_knowledge_graph_data(rows)

    assert rendered == "Result 1:\nage: 30\nname: John\n\n\nResult 2:\nage: 25\nname: Jane"
def test_format_neo4j_result_empty():
    """An empty record list falls back to the canned empty-data message."""
    assert format_knowledge_graph_data([]) == EMPTY_DATA_MESSAGE
def test_format_neo4j_result_different_keys():
    """Each record is formatted with its own keys; rows need not match."""
    rows = [{"name": "John", "age": 30}, {"city": "New York", "country": "USA"}]

    rendered = format_knowledge_graph_data(rows)

    assert rendered == "Result 1:\nage: 30\nname: John\n\n\nResult 2:\ncity: New York\ncountry: USA"
def test_format_neo4j_result_complex_values():
    """Lists and nested dicts are rendered via their Python repr."""
    rows = [
        {"numbers": [1, 2, 3], "metadata": {"type": "user", "active": True}, "date": "2024-01-01"}
    ]

    rendered = format_knowledge_graph_data(rows)

    assert rendered == "Result 1:\ndate: 2024-01-01\nmetadata: {'type': 'user', 'active': True}\nnumbers: [1, 2, 3]"
120
prometheus/tests/utils/test_patch_util.py
Normal file
120
prometheus/tests/utils/test_patch_util.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from pathlib import Path
|
||||
|
||||
from prometheus.utils.patch_util import get_updated_files
|
||||
|
||||
|
||||
def test_get_updated_files_empty_diff():
    """An empty diff reports no added, modified, or removed files."""
    added, modified, removed = get_updated_files("")

    assert len(added) == len(modified) == len(removed) == 0
def test_get_updated_files_added_only():
    """A diff that only creates a file reports it solely under 'added'."""
    patch_text = """
diff --git a/new_file.txt b/new_file.txt
new file mode 100644
index 0000000..1234567
--- /dev/null
+++ b/new_file.txt
@@ -0,0 +1 @@
+New content
"""
    added, modified, removed = get_updated_files(patch_text)

    assert (len(added), len(modified), len(removed)) == (1, 0, 0)
    assert added[0] == Path("new_file.txt")
def test_get_updated_files_modified_only():
    """A diff that only edits a file reports it solely under 'modified'."""
    patch_text = """
diff --git a/modified_file.txt b/modified_file.txt
index 1234567..89abcdef
--- a/modified_file.txt
+++ b/modified_file.txt
@@ -1 +1 @@
-Old content
+Modified content
"""
    added, modified, removed = get_updated_files(patch_text)

    assert (len(added), len(modified), len(removed)) == (0, 1, 0)
    assert modified[0] == Path("modified_file.txt")
def test_get_updated_files_removed_only():
    """A diff that only deletes a file reports it solely under 'removed'."""
    patch_text = """
diff --git a/removed_file.txt b/removed_file.txt
deleted file mode 100644
index 1234567..0000000
--- a/removed_file.txt
+++ /dev/null
@@ -1 +0,0 @@
-Content to be removed
"""
    added, modified, removed = get_updated_files(patch_text)

    assert (len(added), len(modified), len(removed)) == (0, 0, 1)
    assert removed[0] == Path("removed_file.txt")
def test_get_updated_files_multiple_changes():
    """A mixed diff classifies each file into exactly one category."""
    patch_text = """
diff --git a/new_file.txt b/new_file.txt
new file mode 100644
index 0000000..1234567
--- /dev/null
+++ b/new_file.txt
@@ -0,0 +1 @@
+New content
diff --git a/modified_file.txt b/modified_file.txt
index 1234567..89abcdef
--- a/modified_file.txt
+++ b/modified_file.txt
@@ -1 +1 @@
-Old content
+Modified content
diff --git a/removed_file.txt b/removed_file.txt
deleted file mode 100644
index 1234567..0000000
--- a/removed_file.txt
+++ /dev/null
@@ -1 +0,0 @@
-Content to be removed
"""
    added, modified, removed = get_updated_files(patch_text)

    assert (len(added), len(modified), len(removed)) == (1, 1, 1)
    assert added[0] == Path("new_file.txt")
    assert modified[0] == Path("modified_file.txt")
    assert removed[0] == Path("removed_file.txt")
def test_get_updated_files_with_subfolders():
    """Nested paths are preserved in full, not truncated to basenames."""
    patch_text = """
diff --git a/folder1/new_file.txt b/folder1/new_file.txt
new file mode 100644
index 0000000..1234567
--- /dev/null
+++ b/folder1/new_file.txt
@@ -0,0 +1 @@
+New content
diff --git a/folder2/subfolder/modified_file.txt b/folder2/subfolder/modified_file.txt
index 1234567..89abcdef
--- a/folder2/subfolder/modified_file.txt
+++ b/folder2/subfolder/modified_file.txt
@@ -1 +1 @@
-Old content
+Modified content
"""
    added, modified, removed = get_updated_files(patch_text)

    assert (len(added), len(modified), len(removed)) == (1, 1, 0)
    assert added[0] == Path("folder1/new_file.txt")
    assert modified[0] == Path("folder2/subfolder/modified_file.txt")
21
prometheus/tests/utils/test_str_util.py
Normal file
21
prometheus/tests/utils/test_str_util.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from prometheus.utils.str_util import (
|
||||
pre_append_line_numbers,
|
||||
)
|
||||
|
||||
|
||||
def test_single_line():
    """A one-line string receives a single '1. ' prefix."""
    assert pre_append_line_numbers("Hello world", start_line=1) == "1. Hello world"
def test_multiple_lines():
    """Each line is numbered consecutively starting from start_line."""
    numbered = pre_append_line_numbers("First line\nSecond line\nThird line", start_line=1)

    assert numbered == "1. First line\n2. Second line\n3. Third line"
def test_empty_string():
    """An empty input yields an empty output, with no stray number."""
    assert pre_append_line_numbers("", start_line=1) == ""
Reference in New Issue
Block a user