feat: Add intelligent auto-router and enhanced integrations

- Add intelligent-router.sh hook for automatic agent routing
- Add AUTO-TRIGGER-SUMMARY.md documentation
- Add FINAL-INTEGRATION-SUMMARY.md documentation
- Complete Prometheus integration (6 commands + 4 tools)
- Complete Dexto integration (12 commands + 5 tools)
- Enhanced Ralph with access to all agents
- Fix /clawd command (removed disable-model-invocation)
- Update hooks.json to v5 with intelligent routing
- 291 total skills now available
- All 21 commands with automatic routing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
admin
2026-01-28 00:27:56 +04:00
Unverified
parent 3b128ba3bd
commit b52318eeae
1724 changed files with 351216 additions and 0 deletions

View File

View File

View File

View File

@@ -0,0 +1,57 @@
from unittest import mock
from unittest.mock import AsyncMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from prometheus.app.api.routes import auth
from prometheus.app.exception_handler import register_exception_handlers
app = FastAPI()
register_exception_handlers(app)
app.include_router(auth.router, prefix="/auth", tags=["auth"])
client = TestClient(app)
@pytest.fixture
def mock_service():
service = mock.MagicMock()
app.state.service = service
yield service
def test_login(mock_service):
mock_service["user_service"].login = AsyncMock(return_value="your_access_token")
response = client.post(
"/auth/login",
json={
"username": "testuser",
"email": "test@gmail.com",
"password": "passwordpassword",
},
)
assert response.status_code == 200
assert response.json() == {
"code": 200,
"message": "success",
"data": {"access_token": "your_access_token"},
}
def test_register(mock_service):
mock_service["invitation_code_service"].check_invitation_code = AsyncMock(return_value=True)
mock_service["user_service"].create_user = AsyncMock(return_value=None)
mock_service["invitation_code_service"].mark_code_as_used = AsyncMock(return_value=None)
response = client.post(
"/auth/register",
json={
"username": "testuser",
"email": "test@gmail.com",
"password": "passwordpassword",
"invitation_code": "f23ee204-ff33-401d-8291-1f128d0db08a",
},
)
assert response.status_code == 200
assert response.json() == {"code": 200, "message": "User registered successfully", "data": None}

View File

@@ -0,0 +1,65 @@
from unittest.mock import patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from prometheus.app.api.routes.github import router
# Create test app
app = FastAPI()
app.include_router(router, prefix="/github")
client = TestClient(app)
@pytest.fixture
def mock_issue_data():
"""Fixture for mock issue data."""
return {
"number": 123,
"title": "Test Issue",
"body": "This is a test issue body",
"state": "open",
"html_url": "https://github.com/owner/repo/issues/123",
"comments": [
{"username": "user1", "comment": "First comment"},
{"username": "user2", "comment": "Second comment"},
],
}
def test_get_github_issue_success(mock_issue_data):
"""Test successful retrieval of GitHub issue through the API endpoint."""
with patch("prometheus.app.api.routes.github.get_github_issue") as mock_get_issue:
# Configure the mock
mock_get_issue.return_value = mock_issue_data
# Make the request
response = client.get(
"/github/issue/",
params={"repo": "owner/repo", "issue_number": 123, "github_token": "test_token"},
)
# Assert response status
assert response.status_code == 200
# Parse response
response_data = response.json()
# Assert response structure
assert "data" in response_data
assert "message" in response_data
assert "code" in response_data
# Assert data content
data = response_data["data"]
assert data["number"] == 123
assert data["title"] == "Test Issue"
assert data["body"] == "This is a test issue body"
assert data["state"] == "open"
assert len(data["comments"]) == 2
assert data["comments"][0]["username"] == "user1"
# Verify the function was called with correct parameters
mock_get_issue.assert_called_once_with("owner/repo", 123, "test_token")

View File

@@ -0,0 +1,94 @@
import datetime
from unittest import mock
from unittest.mock import AsyncMock
import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from prometheus.app.api.routes import invitation_code
from prometheus.app.entity.invitation_code import InvitationCode
from prometheus.app.exception_handler import register_exception_handlers
app = FastAPI()
register_exception_handlers(app)
app.include_router(invitation_code.router, prefix="/invitation-code", tags=["invitation_code"])
@app.middleware("mock_jwt_middleware")
async def add_user_id(request: Request, call_next):
request.state.user_id = 1 # Set user_id to 1 for testing purposes
response = await call_next(request)
return response
client = TestClient(app)
@pytest.fixture
def mock_service():
service = mock.MagicMock()
app.state.service = service
yield service
def test_create_invitation_code(mock_service):
# Mock the return value of create_invitation_code
mock_service["invitation_code_service"].create_invitation_code = AsyncMock(
return_value=InvitationCode(
id=1,
code="testcode",
is_used=False,
expiration_time=datetime.datetime(
year=2025, month=1, day=1, hour=0, minute=0, second=0
),
)
)
mock_service["user_service"].is_admin = AsyncMock(return_value=True)
# Test the creation endpoint
response = client.post("invitation-code/create/")
assert response.status_code == 200
assert response.json() == {
"code": 200,
"message": "success",
"data": {
"id": 1,
"code": "testcode",
"is_used": False,
"expiration_time": "2025-01-01T00:00:00",
},
}
def test_list(mock_service):
# Mock user as admin and return a list of invitation codes
mock_service["invitation_code_service"].list_invitation_codes = AsyncMock(
return_value=[
InvitationCode(
id=1,
code="testcode",
is_used=False,
expiration_time=datetime.datetime(
year=2025, month=1, day=1, hour=0, minute=0, second=0
),
)
]
)
mock_service["user_service"].is_admin = AsyncMock(return_value=True)
# Test the list endpoint
response = client.get("invitation-code/list/")
assert response.status_code == 200
assert response.json() == {
"code": 200,
"message": "success",
"data": [
{
"id": 1,
"code": "testcode",
"is_used": False,
"expiration_time": "2025-01-01T00:00:00",
}
],
}

View File

@@ -0,0 +1,174 @@
from unittest import mock
from unittest.mock import AsyncMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from prometheus.app.api.routes import issue
from prometheus.app.entity.repository import Repository
from prometheus.app.exception_handler import register_exception_handlers
from prometheus.lang_graph.graphs.issue_state import IssueType
app = FastAPI()
register_exception_handlers(app)
app.include_router(issue.router, prefix="/issue", tags=["issue"])
client = TestClient(app)
@pytest.fixture
def mock_service():
service = mock.MagicMock()
app.state.service = service
yield service
def test_answer_issue(mock_service):
mock_service["repository_service"].get_repository_by_id = AsyncMock(
return_value=Repository(
id=1,
url="https://github.com/fake/repo.git",
commit_id=None,
playground_path="/path/to/playground",
kg_root_node_id=0,
user_id=None,
kg_max_ast_depth=100,
kg_chunk_size=1000,
kg_chunk_overlap=100,
)
)
mock_service["knowledge_graph_service"].get_knowledge_graph = AsyncMock(
return_value=mock.MagicMock()
)
mock_service["repository_service"].update_repository_status = AsyncMock(return_value=None)
mock_service["issue_service"].answer_issue.return_value = (
"test patch", # patch
True, # passed_reproducing_test
True, # passed_regression_test
True, # passed_existing_test
"Issue fixed", # issue_response
IssueType.BUG, # issue_type
)
response = client.post(
"/issue/answer/",
json={
"repository_id": 1,
"issue_title": "Test Issue",
"issue_body": "Test description",
},
)
assert response.status_code == 200
assert response.json() == {
"code": 200,
"message": "success",
"data": {
"patch": "test patch",
"passed_reproducing_test": True,
"passed_regression_test": True,
"passed_existing_test": True,
"issue_response": "Issue fixed",
"issue_type": "bug",
},
}
def test_answer_issue_no_repository(mock_service):
mock_service["repository_service"].get_repository_by_id = AsyncMock(return_value=None)
response = client.post(
"/issue/answer/",
json={
"repository_id": 1,
"issue_title": "Test Issue",
"issue_body": "Test description",
},
)
assert response.status_code == 404
def test_answer_issue_invalid_container_config(mock_service):
mock_service["repository_service"].get_repository_by_id = AsyncMock(
return_value=Repository(
id=1,
url="https://github.com/fake/repo.git",
commit_id=None,
playground_path="/path/to/playground",
kg_root_node_id=0,
user_id=None,
kg_max_ast_depth=100,
kg_chunk_size=1000,
kg_chunk_overlap=100,
)
)
response = client.post(
"/issue/answer/",
json={
"repository_id": 1,
"issue_title": "Test Issue",
"issue_body": "Test description",
"dockerfile_content": "FROM python:3.11",
"workdir": None,
},
)
assert response.status_code == 400
def test_answer_issue_with_container(mock_service):
mock_service["repository_service"].get_repository_by_id = AsyncMock(
return_value=Repository(
id=1,
url="https://github.com/fake/repo.git",
commit_id=None,
playground_path="/path/to/playground",
kg_root_node_id=0,
user_id=None,
kg_max_ast_depth=100,
kg_chunk_size=1000,
kg_chunk_overlap=100,
)
)
mock_service["issue_service"].answer_issue.return_value = (
"test patch",
True,
True,
True,
"Issue fixed",
IssueType.BUG,
)
mock_service["knowledge_graph_service"].get_knowledge_graph = AsyncMock(
return_value=mock.MagicMock()
)
mock_service["repository_service"].update_repository_status = AsyncMock(return_value=None)
test_payload = {
"repository_id": 1,
"issue_title": "Test Issue",
"issue_body": "Test description",
"dockerfile_content": "FROM python:3.11",
"run_reproduce_test": True,
"workdir": "/app",
"build_commands": ["pip install -r requirements.txt"],
"test_commands": ["pytest ."],
}
response = client.post("/issue/answer/", json=test_payload)
assert response.status_code == 200
assert response.json() == {
"code": 200,
"message": "success",
"data": {
"patch": "test patch",
"passed_reproducing_test": True,
"passed_regression_test": True,
"passed_existing_test": True,
"issue_response": "Issue fixed",
"issue_type": "bug",
},
}

View File

@@ -0,0 +1,168 @@
from unittest import mock
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from prometheus.app.api.routes import repository
from prometheus.app.entity.repository import Repository
from prometheus.app.exception_handler import register_exception_handlers
app = FastAPI()
register_exception_handlers(app)
app.include_router(repository.router, prefix="/repository", tags=["repository"])
client = TestClient(app)
@pytest.fixture
def mock_service():
service = mock.MagicMock()
app.state.service = service
yield service
def test_upload_repository(mock_service):
mock_service["repository_service"].clone_github_repo = AsyncMock(return_value="/mock/path")
mock_service["repository_service"].get_repository_by_url_and_commit_id = AsyncMock(
return_value=None
)
mock_service["repository_service"].create_new_repository = AsyncMock(return_value=1)
mock_service["knowledge_graph_service"].build_and_save_knowledge_graph = AsyncMock(
return_value=0
)
response = client.post(
"/repository/upload",
json={
"github_token": "mock_token",
"https_url": "https://github.com/Pantheon-temple/Prometheus",
},
)
assert response.status_code == 200
assert response.json() == {
"code": 200,
"message": "success",
"data": {"repository_id": 1},
}
def test_upload_repository_at_commit(mock_service):
mock_service["repository_service"].clone_github_repo = AsyncMock(return_value="/mock/path")
mock_service["repository_service"].get_repository_by_url_and_commit_id = AsyncMock(
return_value=None
)
mock_service["repository_service"].create_new_repository = AsyncMock(return_value=1)
mock_service["knowledge_graph_service"].build_and_save_knowledge_graph = AsyncMock(
return_value=0
)
response = client.post(
"/repository/upload/",
json={
"github_token": "mock_token",
"https_url": "https://github.com/Pantheon-temple/Prometheus",
"commit_id": "0c554293648a8705769fa53ec896ae24da75f4fc",
},
)
assert response.status_code == 200
def test_create_branch_and_push(mock_service):
# Mock git_repo
git_repo_mock = MagicMock()
git_repo_mock.create_and_push_branch = AsyncMock(return_value=None)
# Let repository_service.get_repository return the mocked git_repo
mock_service["repository_service"].get_repository.return_value = git_repo_mock
mock_service["repository_service"].get_repository_by_id = AsyncMock(
return_value=Repository(
id=1,
url="https://github.com/fake/repo.git",
commit_id=None,
playground_path="/path/to/playground",
kg_root_node_id=0,
user_id=None,
kg_max_ast_depth=100,
kg_chunk_size=1000,
kg_chunk_overlap=100,
)
)
response = client.post(
"/repository/create-branch-and-push/",
json={
"repository_id": 1,
"branch_name": "new_branch",
"commit_message": "Initial commit on new branch",
"patch": "mock_patch_content",
},
)
assert response.status_code == 200
@mock.patch("prometheus.app.api.routes.repository.delete_repository_memory")
def test_delete(mock_delete_memory, mock_service):
# Mock the delete_repository_memory to return success
mock_delete_memory.return_value = {"code": 200, "message": "success", "data": None}
mock_service["repository_service"].get_repository_by_id = AsyncMock(
return_value=Repository(
id=1,
url="https://github.com/fake/repo.git",
commit_id=None,
playground_path="/path/to/playground",
kg_root_node_id=0,
user_id=None,
kg_max_ast_depth=100,
kg_chunk_size=1000,
kg_chunk_overlap=100,
)
)
mock_service["knowledge_graph_service"].clear_kg = AsyncMock(return_value=None)
mock_service["repository_service"].clean_repository.return_value = None
mock_service["repository_service"].delete_repository = AsyncMock(return_value=None)
response = client.delete(
"repository/delete",
params={
"repository_id": 1,
},
)
assert response.status_code == 200
def test_list(mock_service):
mock_service["repository_service"].get_all_repositories = AsyncMock(
return_value=[
Repository(
id=1,
url="https://github.com/fake/repo.git",
commit_id=None,
playground_path="/path/to/playground",
kg_root_node_id=0,
user_id=None,
kg_max_ast_depth=100,
kg_chunk_size=1000,
kg_chunk_overlap=100,
)
]
)
response = client.get("repository/list/")
assert response.status_code == 200
assert response.json() == {
"code": 200,
"message": "success",
"data": [
{
"id": 1,
"url": "https://github.com/fake/repo.git",
"commit_id": None,
"is_working": False,
"user_id": None,
"kg_max_ast_depth": 100,
"kg_chunk_size": 1000,
"kg_chunk_overlap": 100,
}
],
}

View File

@@ -0,0 +1,82 @@
from unittest import mock
from unittest.mock import AsyncMock
import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from prometheus.app.api.routes import user
from prometheus.app.entity.user import User
from prometheus.app.exception_handler import register_exception_handlers
app = FastAPI()
register_exception_handlers(app)
app.include_router(user.router, prefix="/user", tags=["user"])
@app.middleware("mock_jwt_middleware")
async def add_user_id(request: Request, call_next):
request.state.user_id = 1 # Set user_id to 1 for testing purposes
response = await call_next(request)
return response
client = TestClient(app)
@pytest.fixture
def mock_service():
service = mock.MagicMock()
app.state.service = service
yield service
def test_list(mock_service):
# Mock user as admin and return a list of users
mock_service["user_service"].list_users = AsyncMock(
return_value=[
User(
id=1,
username="testuser",
email="test@gmail.com",
password_hash="hashedpassword",
github_token="ghp_1234567890abcdef1234567890abcdef1234",
issue_credit=10,
is_superuser=False,
)
]
)
mock_service["user_service"].is_admin = AsyncMock(return_value=True)
# Test the list endpoint
response = client.get("user/list/")
assert response.status_code == 200
assert response.json() == {
"code": 200,
"message": "success",
"data": [
{
"id": 1,
"username": "testuser",
"email": "test@gmail.com",
"issue_credit": 10,
"is_superuser": False,
}
],
}
def test_set_github_token(mock_service):
# Mock user as admin and return a list of users
mock_service["user_service"].set_github_token = AsyncMock(return_value=None)
# Test the list endpoint
response = client.put(
"user/set-github-token/", json={"github_token": "ghp_1234567890abcdef1234567890abcdef1234"}
)
assert response.status_code == 200
assert response.json() == {
"code": 200,
"message": "success",
"data": None,
}

View File

@@ -0,0 +1,138 @@
import pytest
from fastapi import FastAPI, Request, Response
from fastapi.testclient import TestClient
from prometheus.app.middlewares.jwt_middleware import JWTMiddleware
from prometheus.exceptions.jwt_exception import JWTException
from prometheus.utils.jwt_utils import JWTUtils
@pytest.fixture
def app():
"""
Create a FastAPI app with JWTMiddleware installed.
We mark certain (method, path) pairs as login-required.
"""
app = FastAPI()
# Only these routes require login:
login_required_routes = {
("GET", "/protected"),
("OPTIONS", "/protected"), # include options here if you want middleware to check it
("GET", "/me"),
}
app.add_middleware(JWTMiddleware, login_required_routes=login_required_routes)
@app.get("/public")
def public():
return {"ok": True, "route": "public"}
@app.get("/protected")
def protected(request: Request):
# Return back the user_id the middleware stores on request.state
return {
"ok": True,
"route": "protected",
"user_id": getattr(request.state, "user_id", None),
}
@app.get("/me")
def me(request: Request):
return {"user_id": getattr(request.state, "user_id", None)}
# Explicit OPTIONS route to ensure 200/204 so we can assert behavior
@app.options("/protected")
def options_protected():
return Response(status_code=204)
return app
@pytest.fixture
def client(app):
return TestClient(app)
def test_non_protected_route_bypasses_auth(client):
"""
Requests to routes not listed in login_required_routes must bypass JWT check.
"""
resp = client.get("/public")
assert resp.status_code == 200
assert resp.json()["route"] == "public"
def test_missing_authorization_returns_401_on_protected(client):
"""
Missing Authorization header on a protected endpoint should return 401.
"""
resp = client.get("/protected")
assert resp.status_code == 401
body = resp.json()
assert body["code"] == 401
assert "Valid JWT Token is missing" in body["message"]
def test_wrong_scheme_returns_401_on_protected(client):
"""
Wrong Authorization scheme (not Bearer) should return 401 on protected endpoint.
"""
resp = client.get("/protected", headers={"Authorization": "Token abc.def.ghi"})
assert resp.status_code == 401
body = resp.json()
assert body["code"] == 401
assert "Valid JWT Token is missing" in body["message"]
def test_invalid_token_raises_and_returns_error(client, monkeypatch):
"""
If JWTUtils.decode_token raises JWTException, middleware should map it to the response.
"""
def fake_decode(_self, _: str):
raise JWTException(code=403, message="Invalid or expired token")
# Patch the method on the class; middleware instantiates JWTUtils() internally
monkeypatch.setattr(JWTUtils, "decode_token", fake_decode, raising=True)
resp = client.get("/protected", headers={"Authorization": "Bearer bad.token"})
assert resp.status_code == 403
body = resp.json()
assert body["code"] == 403
assert body["message"] == "Invalid or expired token"
def test_valid_token_sets_user_id_and_passes(client, monkeypatch):
"""
With a valid token, request should pass and user_id should be present on request.state.
"""
def fake_decode(_self, _: str):
# Return payload with user_id as middleware expects
return {"user_id": 123}
monkeypatch.setattr(JWTUtils, "decode_token", fake_decode, raising=True)
resp = client.get("/protected", headers={"Authorization": "Bearer good.token"})
assert resp.status_code == 200
body = resp.json()
assert body["ok"] is True
assert body["user_id"] == 123
def test_options_request_passes_through(client, monkeypatch):
"""
OPTIONS preflight should be allowed through without requiring a valid token.
The middleware explicitly bypasses OPTIONS before checking Authorization.
"""
# Even if decode_token would fail, OPTIONS should not trigger it.
def boom(_self, _: str):
raise AssertionError("decode_token should not be called for OPTIONS")
monkeypatch.setattr(JWTUtils, "decode_token", boom, raising=True)
resp = client.options("/protected")
# Our route returns 204; any 2xx is acceptable depending on your route
assert resp.status_code in (200, 204)

View File

@@ -0,0 +1,17 @@
import pytest
from prometheus.app.services.database_service import DatabaseService
from tests.test_utils.fixtures import postgres_container_fixture # noqa: F401
@pytest.mark.slow
async def test_database_service(postgres_container_fixture): # noqa: F811
url = postgres_container_fixture.get_connection_url()
database_service = DatabaseService(url)
assert database_service.engine is not None
try:
await database_service.start()
await database_service.close()
except Exception as e:
pytest.fail(f"Connection verification failed: {e}")

View File

@@ -0,0 +1,122 @@
from datetime import datetime, timedelta, timezone
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from prometheus.app.entity.invitation_code import InvitationCode
from prometheus.app.services.database_service import DatabaseService
from prometheus.app.services.invitation_code_service import InvitationCodeService
from tests.test_utils.fixtures import postgres_container_fixture # noqa: F401
@pytest.fixture
async def mock_database_service(postgres_container_fixture): # noqa: F811
"""Fixture: provide a clean DatabaseService using the Postgres test container."""
service = DatabaseService(postgres_container_fixture.get_connection_url())
await service.start()
yield service
await service.close()
@pytest.fixture
def service(mock_database_service):
"""Fixture: construct an InvitationCodeService with the database service."""
return InvitationCodeService(database_service=mock_database_service)
async def _insert_code(
session: AsyncSession, code: str, is_used: bool = False, expires_in_seconds: int = 3600
) -> InvitationCode:
"""Helper: insert a single InvitationCode with given state and expiration."""
obj = InvitationCode(
code=code,
is_used=is_used,
expiration_time=datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds),
)
session.add(obj)
await session.commit()
await session.refresh(obj)
return obj
async def test_create_invitation_code(service):
"""Test that create_invitation_code correctly generates and returns an InvitationCode."""
invitation_code = await service.create_invitation_code()
# Verify the returned object is an InvitationCode instance
assert isinstance(invitation_code, InvitationCode)
assert isinstance(invitation_code.code, str)
assert len(invitation_code.code) == 36 # uuid4 string length
assert invitation_code.id is not None
# Verify the object is persisted in the database
async with AsyncSession(service.engine) as session:
db_obj = await session.get(InvitationCode, invitation_code.id)
assert db_obj is not None
assert db_obj.code == invitation_code.code
async def test_list_invitation_codes(service):
"""Test that list_invitation_codes returns all stored invitation codes."""
# Insert two invitation codes first
code1 = await service.create_invitation_code()
code2 = await service.create_invitation_code()
codes = await service.list_invitation_codes()
# Verify length
assert len(codes) >= 2
# Verify both created codes are included
all_codes = [c.code for c in codes]
assert code1.code in all_codes
assert code2.code in all_codes
async def test_check_invitation_code_returns_false_when_not_exists(service):
"""check_invitation_code should return False if the code does not exist."""
ok = await service.check_invitation_code("non-existent-code")
assert ok is False
async def test_check_invitation_code_returns_false_when_used(service):
"""check_invitation_code should return False if the code is already used."""
async with AsyncSession(service.engine) as session:
await _insert_code(session, "used-code", is_used=True, expires_in_seconds=3600)
ok = await service.check_invitation_code("used-code")
assert ok is False
async def test_check_invitation_code_returns_false_when_expired(service):
"""check_invitation_code should return False if the code is expired."""
async with AsyncSession(service.engine) as session:
# Negative expires_in_seconds makes it expire in the past
await _insert_code(session, "expired-code", is_used=False, expires_in_seconds=-60)
ok = await service.check_invitation_code("expired-code")
assert ok is False
async def test_check_invitation_code_returns_true_when_valid(service):
"""check_invitation_code should return True if the code exists, not used, and not expired."""
async with AsyncSession(service.engine) as session:
await _insert_code(session, "valid-code", is_used=False, expires_in_seconds=3600)
ok = await service.check_invitation_code("valid-code")
assert ok is True
async def test_mark_code_as_used_persists_state(service):
"""mark_code_as_used should set 'used' to True and persist to DB."""
async with AsyncSession(service.engine) as session:
created = await _insert_code(session, "to-use", is_used=False, expires_in_seconds=3600)
created_id = created.id
# Act
await service.mark_code_as_used("to-use")
# Assert persisted state
async with AsyncSession(service.engine) as session:
refreshed = await session.get(InvitationCode, created_id)
assert refreshed is not None
assert refreshed.is_used is True

View File

@@ -0,0 +1,157 @@
from unittest.mock import Mock, create_autospec
import pytest
from prometheus.app.services.issue_service import IssueService
from prometheus.app.services.llm_service import LLMService
from prometheus.app.services.repository_service import RepositoryService
from prometheus.git.git_repository import GitRepository
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.graphs.issue_state import IssueType
@pytest.fixture
def mock_llm_service():
service = create_autospec(LLMService, instance=True)
service.advanced_model = "gpt-4"
service.base_model = "gpt-3.5-turbo"
return service
@pytest.fixture
def mock_repository_service():
service = create_autospec(RepositoryService, instance=True)
return service
@pytest.fixture
def issue_service(mock_llm_service, mock_repository_service):
return IssueService(
llm_service=mock_llm_service,
working_directory="/tmp/working_dir/",
logging_level="DEBUG",
)
async def test_answer_issue_with_general_container(issue_service, monkeypatch):
# Setup
mock_issue_graph = Mock()
mock_issue_graph_class = Mock(return_value=mock_issue_graph)
monkeypatch.setattr("prometheus.app.services.issue_service.IssueGraph", mock_issue_graph_class)
mock_container = Mock()
mock_general_container_class = Mock(return_value=mock_container)
monkeypatch.setattr(
"prometheus.app.services.issue_service.GeneralContainer", mock_general_container_class
)
repository = Mock(spec=GitRepository)
repository.get_working_directory.return_value = "mock/working/directory"
knowledge_graph = Mock(spec=KnowledgeGraph)
# Mock output state for a bug type
mock_output_state = {
"issue_type": IssueType.BUG,
"edit_patch": "test_patch",
"passed_reproducing_test": True,
"passed_regression_test": True,
"passed_existing_test": True,
"issue_response": "test_response",
}
mock_issue_graph.invoke.return_value = mock_output_state
# Exercise
result = issue_service.answer_issue(
repository=repository,
knowledge_graph=knowledge_graph,
repository_id=1,
issue_title="Test Issue",
issue_body="Test Body",
issue_comments=[],
issue_type=IssueType.BUG,
run_build=True,
run_regression_test=True,
run_existing_test=True,
run_reproduce_test=True,
number_of_candidate_patch=1,
build_commands=None,
test_commands=None,
)
# Verify
mock_general_container_class.assert_called_once_with(
project_path=repository.get_working_directory(), build_commands=None, test_commands=None
)
mock_issue_graph_class.assert_called_once_with(
advanced_model=issue_service.llm_service.advanced_model,
base_model=issue_service.llm_service.base_model,
kg=knowledge_graph,
git_repo=repository,
container=mock_container,
repository_id=1,
test_commands=None,
)
assert result == ("test_patch", True, True, True, "test_response", IssueType.BUG)
async def test_answer_issue_with_user_defined_container(issue_service, monkeypatch):
# Setup
mock_issue_graph = Mock()
mock_issue_graph_class = Mock(return_value=mock_issue_graph)
monkeypatch.setattr("prometheus.app.services.issue_service.IssueGraph", mock_issue_graph_class)
mock_container = Mock()
mock_user_container_class = Mock(return_value=mock_container)
monkeypatch.setattr(
"prometheus.app.services.issue_service.UserDefinedContainer", mock_user_container_class
)
repository = Mock(spec=GitRepository)
repository.get_working_directory.return_value = "mock/working/directory"
knowledge_graph = Mock(spec=KnowledgeGraph)
# Mock output state for a question type
mock_output_state = {
"issue_type": IssueType.QUESTION,
"edit_patch": None,
"passed_reproducing_test": False,
"passed_regression_test": False,
"passed_existing_test": False,
"issue_response": "test_response",
}
mock_issue_graph.invoke.return_value = mock_output_state
# Exercise
result = issue_service.answer_issue(
repository=repository,
knowledge_graph=knowledge_graph,
repository_id=1,
issue_title="Test Issue",
issue_body="Test Body",
issue_comments=[],
issue_type=IssueType.QUESTION,
run_build=True,
run_regression_test=True,
run_existing_test=True,
run_reproduce_test=True,
number_of_candidate_patch=1,
dockerfile_content="FROM python:3.8",
image_name="test-image",
workdir="/app",
build_commands=["pip install -r requirements.txt"],
test_commands=["pytest"],
)
# Verify
mock_user_container_class.assert_called_once_with(
project_path=repository.get_working_directory(),
workdir="/app",
build_commands=["pip install -r requirements.txt"],
test_commands=["pytest"],
dockerfile_content="FROM python:3.8",
image_name="test-image",
)
assert result == (None, False, False, False, "test_response", IssueType.QUESTION)

View File

@@ -0,0 +1,104 @@
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from prometheus.app.services.knowledge_graph_service import KnowledgeGraphService
from prometheus.app.services.neo4j_service import Neo4jService
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.neo4j.knowledge_graph_handler import KnowledgeGraphHandler
@pytest.fixture
def mock_neo4j_service():
"""Mock the Neo4jService."""
neo4j_service = MagicMock(Neo4jService)
return neo4j_service
@pytest.fixture
def mock_kg_handler():
"""Mock the KnowledgeGraphHandler."""
kg_handler = MagicMock(KnowledgeGraphHandler)
return kg_handler
@pytest.fixture
def knowledge_graph_service(mock_neo4j_service, mock_kg_handler):
"""Fixture to create KnowledgeGraphService instance."""
mock_neo4j_service.neo4j_driver = MagicMock() # Mocking Neo4j driver
mock_kg_handler.get_new_knowledge_graph_root_node_id = AsyncMock(return_value=123)
mock_kg_handler.write_knowledge_graph = AsyncMock()
knowledge_graph_service = KnowledgeGraphService(
neo4j_service=mock_neo4j_service,
neo4j_batch_size=1000,
max_ast_depth=5,
chunk_size=1000,
chunk_overlap=100,
)
knowledge_graph_service.kg_handler = mock_kg_handler
return knowledge_graph_service
async def test_build_and_save_knowledge_graph(knowledge_graph_service, mock_kg_handler):
"""Test the build_and_save_knowledge_graph method."""
# Given
source_code_path = Path("/mock/path/to/source/code") # Mock path to source code
# Mock KnowledgeGraph and its methods
mock_kg = MagicMock(KnowledgeGraph)
mock_kg.build_graph = AsyncMock(return_value=None) # Mock async method to build graph
mock_kg_handler.get_new_knowledge_graph_root_node_id = AsyncMock(
return_value=123
) # Mock return value
mock_kg_handler.write_knowledge_graph = AsyncMock(return_value=None)
# When
with pytest.raises(Exception):
knowledge_graph_service.kg_handler = mock_kg_handler
result = await knowledge_graph_service.build_and_save_knowledge_graph(source_code_path)
# Then
assert result == 123 # Ensure that the correct root node ID is returned
mock_kg.build_graph.assert_awaited_once_with(
source_code_path
) # Check if build_graph was called correctly
mock_kg_handler.write_knowledge_graph.assert_called_once() # Ensure graph write happened
mock_kg_handler.get_new_knowledge_graph_root_node_id.assert_called_once() # Ensure the root node ID was fetched
async def test_clear_kg(knowledge_graph_service, mock_kg_handler):
"""Test the clear_kg method."""
# Given
root_node_id = 123 # Mock root node ID
# When
await knowledge_graph_service.clear_kg(root_node_id)
# Then
mock_kg_handler.clear_knowledge_graph.assert_called_once_with(root_node_id)
async def test_get_knowledge_graph(knowledge_graph_service, mock_kg_handler):
"""Test the get_knowledge_graph method."""
# Given
root_node_id = 123 # Mock root node ID
max_ast_depth = 5
chunk_size = 1000
chunk_overlap = 100
# Mock KnowledgeGraph
mock_kg = MagicMock(KnowledgeGraph)
mock_kg_handler.read_knowledge_graph = AsyncMock(return_value=mock_kg) # Mock return value
# When
result = await knowledge_graph_service.get_knowledge_graph(
root_node_id, max_ast_depth, chunk_size, chunk_overlap
)
# Then
mock_kg_handler.read_knowledge_graph.assert_called_once_with(
root_node_id, max_ast_depth, chunk_size, chunk_overlap
) # Ensure read_knowledge_graph is called with the correct parameters
assert result == mock_kg # Ensure the correct KnowledgeGraph object is returned

View File

@@ -0,0 +1,125 @@
from unittest.mock import Mock, patch
import pytest
from prometheus.app.services.llm_service import CustomChatOpenAI, LLMService, get_model
@pytest.fixture
def mock_custom_chat_openai():
with patch("prometheus.app.services.llm_service.CustomChatOpenAI") as mock:
yield mock
@pytest.fixture
def mock_chat_anthropic():
with patch("prometheus.app.services.llm_service.ChatAnthropic") as mock:
yield mock
@pytest.fixture
def mock_chat_google():
with patch("prometheus.app.services.llm_service.ChatGoogleGenerativeAI") as mock:
yield mock
def test_llm_service_init(mock_custom_chat_openai, mock_chat_anthropic):
# Setup
mock_gpt_instance = Mock()
mock_claude_instance = Mock()
mock_custom_chat_openai.return_value = mock_gpt_instance
mock_chat_anthropic.return_value = mock_claude_instance
# Exercise
service = LLMService(
advanced_model_name="gpt-4",
base_model_name="claude-2.1",
advanced_model_temperature=0.0,
base_model_temperature=0.0,
openai_format_api_key="openai-key",
openai_format_base_url="https://api.openai.com/v1",
anthropic_api_key="anthropic-key",
)
# Verify
assert service.advanced_model == mock_gpt_instance
assert service.base_model == mock_claude_instance
mock_custom_chat_openai.assert_called_once_with(
model="gpt-4",
api_key="openai-key",
base_url="https://api.openai.com/v1",
temperature=0.0,
max_retries=3,
)
mock_chat_anthropic.assert_called_once_with(
model_name="claude-2.1",
api_key="anthropic-key",
temperature=0.0,
max_retries=3,
)
def test_get_openai_format_model(mock_custom_chat_openai):
# Exercise
get_model(
model_name="openrouter/model",
openai_format_api_key="openrouter-key",
openai_format_base_url="https://openrouter.ai/api/v1",
temperature=0.0,
)
# Verify
mock_custom_chat_openai.assert_called_once_with(
model="openrouter/model",
api_key="openrouter-key",
base_url="https://openrouter.ai/api/v1",
temperature=0.0,
max_retries=3,
)
def test_get_model_claude(mock_chat_anthropic):
# Exercise
get_model(
model_name="claude-2.1",
anthropic_api_key="anthropic-key",
temperature=0.0,
)
# Verify
mock_chat_anthropic.assert_called_once_with(
model_name="claude-2.1",
api_key="anthropic-key",
temperature=0.0,
max_retries=3,
)
def test_get_model_gemini(mock_chat_google):
# Exercise
get_model(
model_name="gemini-pro",
gemini_api_key="gemini-key",
temperature=0.0,
)
# Verify
mock_chat_google.assert_called_once_with(
model="gemini-pro",
api_key="gemini-key",
temperature=0.0,
max_retries=3,
)
def test_custom_chat_openai_bind_tools():
# Setup
model = CustomChatOpenAI(api_key="test-key", max_input_tokens=64000)
mock_tools = [Mock()]
# Exercise
with patch("prometheus.chat_models.custom_chat_openai.ChatOpenAI.bind_tools") as mock_bind:
model.bind_tools(mock_tools)
# Verify
mock_bind.assert_called_once_with(mock_tools, tool_choice=None, parallel_tool_calls=False)

View File

@@ -0,0 +1,19 @@
import pytest
from prometheus.app.services.neo4j_service import Neo4jService
from tests.test_utils.fixtures import neo4j_container_with_kg_fixture # noqa: F401
@pytest.mark.slow
async def test_neo4j_service(neo4j_container_with_kg_fixture): # noqa: F811
neo4j_container, kg = neo4j_container_with_kg_fixture
neo4j_service = Neo4jService(
neo4j_container.get_connection_url(), neo4j_container.username, neo4j_container.password
)
assert neo4j_service.neo4j_driver is not None
neo4j_service.start()
try:
await neo4j_service.neo4j_driver.verify_connectivity()
except Exception as e:
pytest.fail(f"Connection verification failed: {e}")
await neo4j_service.close()

View File

@@ -0,0 +1,197 @@
from pathlib import Path
from unittest.mock import create_autospec, patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from prometheus.app.entity.repository import Repository
from prometheus.app.services.database_service import DatabaseService
from prometheus.app.services.knowledge_graph_service import KnowledgeGraphService
from prometheus.app.services.repository_service import RepositoryService
from prometheus.git.git_repository import GitRepository
from tests.test_utils.fixtures import postgres_container_fixture # noqa: F401
@pytest.fixture
def mock_kg_service():
# Mock KnowledgeGraphService; RepositoryService only reads its attributes in other paths
kg_service = create_autospec(KnowledgeGraphService, instance=True)
kg_service.max_ast_depth = 3
kg_service.chunk_size = 1000
kg_service.chunk_overlap = 100
return kg_service
@pytest.fixture
async def mock_database_service(postgres_container_fixture): # noqa: F811
service = DatabaseService(postgres_container_fixture.get_connection_url())
await service.start()
yield service
await service.close()
@pytest.fixture
def mock_git_repository():
repo = create_autospec(GitRepository, instance=True)
repo.get_working_directory.return_value = Path("/test/working/dir/repositories/repo")
return repo
@pytest.fixture
def service(mock_kg_service, mock_database_service, monkeypatch):
working_dir = "/test/working/dir"
# Avoid touching the real filesystem when creating the base repo folder
monkeypatch.setattr(Path, "mkdir", lambda *args, **kwargs: None)
return RepositoryService(
kg_service=mock_kg_service,
database_service=mock_database_service, # <-- use the correct fixture here
working_dir=working_dir,
)
async def test_clone_new_github_repo(service, mock_git_repository, monkeypatch):
# Arrange
test_url = "https://github.com/test/repo"
test_commit = "abc123"
test_github_token = "test_token"
expected_path = Path("/test/working/dir/repositories/repo")
# Force get_new_playground_path() to return a deterministic path for assertions
monkeypatch.setattr(service, "get_new_playground_path", lambda: expected_path)
# Patch GitRepository so its constructor returns our mock instance
with patch(
"prometheus.app.services.repository_service.GitRepository",
return_value=mock_git_repository,
) as mock_git_class:
# Act
result_path = await service.clone_github_repo(test_github_token, test_url, test_commit)
# Assert
# GitRepository should be instantiated without args (per current implementation)
mock_git_class.assert_called_once_with()
# Ensure the clone method was invoked with correct parameters
mock_git_repository.from_clone_repository.assert_called_once_with(
test_url, test_github_token, expected_path
)
# Verify the requested commit was checked out
mock_git_repository.checkout_commit.assert_called_once_with(test_commit)
# The returned path should be the working directory of the mocked repo
assert result_path == expected_path
def test_get_new_playground_path(service):
expected_uuid = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
expected_path = service.target_directory / expected_uuid
with patch("uuid.uuid4") as mock_uuid:
mock_uuid.return_value.hex = expected_uuid
result = service.get_new_playground_path()
assert result == expected_path
def test_clean_repository_removes_dir_and_parent(service, monkeypatch):
"""
Should call shutil.rmtree on the repository path and remove its parent directory
when the path exists.
"""
repo_path = Path("/tmp/repositories/abc123")
repository = Repository(playground_path=str(repo_path))
# Patch path.exists() to return True
monkeypatch.setattr(Path, "exists", lambda self: self == repo_path)
# Track calls to shutil.rmtree and Path.rmdir
removed = {"rmtree": None, "rmdir": []}
monkeypatch.setattr(
"shutil.rmtree",
lambda target: removed.update(rmtree=target),
)
monkeypatch.setattr(
Path,
"rmdir",
lambda self: removed["rmdir"].append(self),
)
service.clean_repository(repository)
# Assert rmtree called with correct path string
assert removed["rmtree"] == str(repo_path)
# Assert rmdir called on the parent directory
assert repo_path.parent in removed["rmdir"]
def test_clean_repository_skips_when_not_exists(service, monkeypatch):
"""
Should not call rmtree or rmdir when the repository path does not exist.
"""
repo_path = Path("/tmp/repositories/abc123")
repository = Repository(playground_path=str(repo_path))
# Path.exists returns False
monkeypatch.setattr(Path, "exists", lambda self: False)
monkeypatch.setattr("shutil.rmtree", lambda target: pytest.fail("rmtree should not be called"))
monkeypatch.setattr(Path, "rmdir", lambda self: pytest.fail("rmdir should not be called"))
# No exception means pass
service.clean_repository(repository)
def test_get_repository_returns_git_repo_instance(service):
"""
Should create a GitRepository, call from_local_repository with the given path,
and return the GitRepository instance.
"""
test_path = "/some/local/path"
mock_git_repo_instance = create_autospec(GitRepository, instance=True)
# Patch GitRepository() constructor to return our mock instance
with patch(
"prometheus.app.services.repository_service.GitRepository",
return_value=mock_git_repo_instance,
) as mock_git_class:
result = service.get_repository(test_path)
# Verify GitRepository() was called with no args
mock_git_class.assert_called_once_with()
# Verify from_local_repository was called with the correct Path object
mock_git_repo_instance.from_local_repository.assert_called_once_with(Path(test_path))
# Verify the returned object is the same as the mock instance
assert result == mock_git_repo_instance
async def test_create_new_repository(service):
# Exercise
await service.create_new_repository(
url="https://github.com/test/repo",
commit_id="abc123",
playground_path="/tmp/repositories/repo",
user_id=None,
kg_root_node_id=0,
)
# Verify the object is persisted in the database
async with AsyncSession(service.engine) as session:
db_obj = await session.get(Repository, 1)
assert db_obj is not None
assert db_obj.url == "https://github.com/test/repo"
assert db_obj.commit_id == "abc123"
assert db_obj.playground_path == "/tmp/repositories/repo"
assert db_obj.user_id is None
assert db_obj.kg_root_node_id == 0
async def test_get_all_repositories(service):
# Exercise
repos = await service.get_all_repositories()
# Verify
assert len(repos) == 1

View File

@@ -0,0 +1,50 @@
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from prometheus.app.entity.user import User
from prometheus.app.services.database_service import DatabaseService
from prometheus.app.services.user_service import UserService
from tests.test_utils.fixtures import postgres_container_fixture # noqa: F401
@pytest.fixture
async def mock_database_service(postgres_container_fixture): # noqa: F811
service = DatabaseService(postgres_container_fixture.get_connection_url())
await service.start()
yield service
await service.close()
async def test_create_superuser(mock_database_service):
# Exercise
service = UserService(mock_database_service)
await service.create_superuser(
"testuser", "test@gmail.com", "password123", github_token="gh_token"
)
# Verify
async with AsyncSession(service.engine) as session:
user = await session.get(User, 1)
assert user is not None
assert user.username == "testuser"
assert user.email == "test@gmail.com"
assert user.github_token == "gh_token"
async def test_login(mock_database_service):
# Exercise
service = UserService(mock_database_service)
access_token = await service.login("testuser", "test@gmail.com", "password123")
# Verify
assert access_token is not None
async def test_set_github_token(mock_database_service):
# Exercise
service = UserService(mock_database_service)
await service.set_github_token(1, "new_gh_token")
# Verify
async with AsyncSession(service.engine) as session:
user = await session.get(User, 1)
assert user is not None
assert user.github_token == "new_gh_token"

View File

@@ -0,0 +1,28 @@
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
@pytest.fixture
def mock_dependencies():
"""Mock the service dependencies"""
mock_service = MagicMock()
with patch("prometheus.app.dependencies.initialize_services", return_value=mock_service):
yield mock_service
@pytest.fixture
def test_client(mock_dependencies):
"""Create a TestClient instance with mocked settings and dependencies"""
# Import app here to ensure settings are properly mocked
from prometheus.app.main import app
with TestClient(app) as client:
yield client
def test_app_initialization(test_client, mock_dependencies):
"""Test that the app initializes correctly with mocked dependencies"""
assert test_client.app.state.service is not None
assert test_client.app.state.service == mock_dependencies

View File

View File

@@ -0,0 +1,309 @@
import shutil
import tempfile
from pathlib import Path
from unittest.mock import Mock, call, patch
import pytest
from prometheus.docker.base_container import BaseContainer
class TestContainer(BaseContainer):
"""Concrete implementation of BaseContainer for testing."""
def get_dockerfile_content(self) -> str:
return "FROM python:3.9\nWORKDIR /app\nCOPY . /app/"
@pytest.fixture
def temp_project_dir():
# Create a temporary directory with some test files
temp_dir = Path(tempfile.mkdtemp())
test_file = temp_dir / "test.txt"
test_file.write_text("test content")
yield temp_dir
# Cleanup
shutil.rmtree(temp_dir)
@pytest.fixture
def mock_docker_client():
with patch.object(BaseContainer, "client", new_callable=Mock) as mock_client:
yield mock_client
@pytest.fixture
def container(temp_project_dir, mock_docker_client):
container = TestContainer(
project_path=temp_project_dir,
workdir="/app",
build_commands=["pip install -r requirements.txt", "python setup.py build"],
test_commands=["pytest tests/"],
)
container.tag_name = "test_container_tag"
return container
def test_get_dockerfile_content(container):
"""Test that get_dockerfile_content returns expected content"""
dockerfile_content = container.get_dockerfile_content()
assert "FROM python:3.9" in dockerfile_content
assert "WORKDIR /app" in dockerfile_content
assert "COPY . /app/" in dockerfile_content
def test_build_docker_image(container, mock_docker_client):
"""Test building Docker image"""
# Setup mock for api.build to return an iterable of log entries
mock_build_logs = [
{"stream": "Step 1/3 : FROM python:3.9"},
{"stream": "Step 2/3 : WORKDIR /app"},
{"stream": "Step 3/3 : COPY . /app/"},
{"stream": "Successfully built abc123"},
]
mock_docker_client.api.build.return_value = iter(mock_build_logs)
# Execute
container.build_docker_image()
# Verify
assert (container.project_path / "prometheus.Dockerfile").exists()
mock_docker_client.api.build.assert_called_once_with(
path=str(container.project_path),
dockerfile="prometheus.Dockerfile",
tag=container.tag_name,
rm=True,
decode=True,
)
@patch("prometheus.docker.base_container.pexpect.spawn")
def test_start_container(mock_spawn, container, mock_docker_client):
"""Test starting Docker container"""
# Setup mock for pexpect shell
mock_shell = Mock()
mock_spawn.return_value = mock_shell
mock_shell.expect.return_value = 0 # Simulate successful prompt match
# Setup mock for docker client
mock_containers = Mock()
mock_docker_client.containers = mock_containers
mock_container = Mock()
mock_container.id = "test_container_id"
mock_containers.run.return_value = mock_container
# Execute
container.start_container()
# Verify docker container run was called
mock_containers.run.assert_called_once_with(
container.tag_name,
detach=True,
tty=True,
network_mode="host",
environment={"PYTHONPATH": f"{container.workdir}:$PYTHONPATH"},
volumes={"/var/run/docker.sock": {"bind": "/var/run/docker.sock", "mode": "rw"}},
)
# Verify pexpect shell was started
mock_spawn.assert_called_once_with(
f"docker exec -it {mock_container.id} /bin/bash",
encoding="utf-8",
timeout=container.timeout,
)
mock_shell.expect.assert_called()
def test_is_running(container):
"""Test is_running status check"""
# Test when container is None
assert not container.is_running()
# Test when container exists
container.container = Mock()
assert container.is_running()
def test_update_files(container, temp_project_dir):
"""Test updating files in container"""
# Setup
container.container = Mock()
container.execute_command = Mock()
# Create test files
test_file1 = temp_project_dir / "dir1" / "test1.txt"
test_file2 = temp_project_dir / "dir2" / "test2.txt"
test_file1.parent.mkdir(parents=True)
test_file2.parent.mkdir(parents=True)
test_file1.write_text("test1")
test_file2.write_text("test2")
updated_files = [Path("dir1/test1.txt"), Path("dir2/test2.txt")]
removed_files = [Path("dir3/old.txt")]
# Execute
container.update_files(temp_project_dir, updated_files, removed_files)
# Verify
container.execute_command.assert_has_calls(
[call("rm dir3/old.txt"), call("mkdir -p dir1"), call("mkdir -p dir2")]
)
assert container.container.put_archive.called
@patch("prometheus.docker.base_container.pexpect.spawn")
def test_execute_command(mock_spawn, container):
"""Test executing command in container using persistent shell"""
# Setup mock shell
mock_shell = Mock()
mock_spawn.return_value = mock_shell
# Setup container and shell
container.container = Mock()
container.container.id = "test_container_id"
container.shell = mock_shell
mock_shell.isalive.return_value = True
# Mock the shell interactions
mock_shell.match = Mock()
mock_shell.match.group.return_value = "0" # Exit code 0
mock_shell.before = "test command\ncommand output"
# Execute
result = container.execute_command("test command")
# Verify shell interactions
assert mock_shell.sendline.call_count == 2 # Command + marker command
mock_shell.expect.assert_called()
# The result should contain the cleaned output
assert "command output" in result
def test_execute_command_with_mock(container):
"""Test executing command with direct mocking"""
# Setup - directly mock the execute_command method
container.execute_command = Mock(return_value="mocked output")
container.container = Mock()
# Execute
result = container.execute_command("test command")
# Verify
container.execute_command.assert_called_once_with("test command")
assert result == "mocked output"
def test_reset_repository(container):
"""Test container reset repository"""
# Setup - Mock the execute_command method
container.execute_command = Mock(return_value="Command output")
container.container = Mock()
# Execute
container.reset_repository()
# Verify - Check that execute_command was called twice with the correct commands
assert container.execute_command.call_count == 2
expected_calls = [call("git reset --hard"), call("git clean -fd")]
container.execute_command.assert_has_calls(expected_calls, any_order=False)
@patch("prometheus.docker.base_container.pexpect.spawn")
def test_cleanup(mock_spawn, container, mock_docker_client):
"""Test cleanup of container resources"""
# Setup
mock_container = Mock()
container.container = mock_container
# Setup mock shell
mock_shell = Mock()
mock_shell.isalive.return_value = True
container.shell = mock_shell
# Execute
container.cleanup()
# Verify shell cleanup
mock_shell.close.assert_called_once_with(force=True)
# Verify container cleanup
mock_container.stop.assert_called_once_with(timeout=10)
mock_container.remove.assert_called_once_with(force=True)
mock_docker_client.images.remove.assert_called_once_with(container.tag_name, force=True)
assert not container.project_path.exists()
def test_run_build(container):
"""Test that build commands are executed correctly"""
container.execute_command = Mock()
container.execute_command.side_effect = ["Output 1", "Output 2"]
build_output = container.run_build()
# Verify execute_command was called for each build command
assert container.execute_command.call_count == 2
container.execute_command.assert_any_call("pip install -r requirements.txt")
container.execute_command.assert_any_call("python setup.py build")
# Verify output format
expected_output = (
"$ pip install -r requirements.txt\nOutput 1\n$ python setup.py build\nOutput 2\n"
)
assert build_output == expected_output
def test_run_test(container):
"""Test that test commands are executed correctly"""
container.execute_command = Mock()
container.execute_command.return_value = "Test passed"
test_output = container.run_test()
# Verify execute_command was called for the test command
container.execute_command.assert_called_once_with("pytest tests/")
# Verify output format
expected_output = "$ pytest tests/\nTest passed\n"
assert test_output == expected_output
def test_run_build_no_commands(container):
"""Test run_build when no build commands are defined"""
container.build_commands = None
result = container.run_build()
assert result == ""
def test_run_test_no_commands(container):
"""Test run_test when no test commands are defined"""
container.test_commands = None
result = container.run_test()
assert result == ""
@patch("prometheus.docker.base_container.pexpect.spawn")
def test_restart_shell_if_needed(mock_spawn, container):
"""Test shell restart functionality"""
# Setup
mock_shell_dead = Mock()
mock_shell_dead.isalive.return_value = False
mock_shell_new = Mock()
mock_shell_new.expect.return_value = 0
mock_spawn.return_value = mock_shell_new
container.container = Mock()
container.container.id = "test_container_id"
container.shell = mock_shell_dead
# Execute
container._restart_shell_if_needed()
# Verify old shell was closed and new one started
mock_shell_dead.close.assert_called_once_with(force=True)
mock_spawn.assert_called_once()
assert container.shell == mock_shell_new

View File

@@ -0,0 +1,44 @@
import shutil
import tempfile
from pathlib import Path
import pytest
from prometheus.docker.general_container import GeneralContainer
@pytest.fixture
def temp_project_dir():
# Create a temporary directory with some test files
temp_dir = Path(tempfile.mkdtemp())
test_file = temp_dir / "test.txt"
test_file.write_text("test content")
yield temp_dir
# Cleanup
shutil.rmtree(temp_dir)
@pytest.fixture
def container(temp_project_dir):
return GeneralContainer(temp_project_dir)
def test_initialization(container, temp_project_dir):
"""Test that the container is initialized correctly"""
assert isinstance(container.tag_name, str)
assert container.tag_name.startswith("prometheus_general_container_")
assert container.project_path != temp_project_dir
assert (container.project_path / "test.txt").exists()
def test_get_dockerfile_content(container):
dockerfile_content = container.get_dockerfile_content()
assert dockerfile_content
assert "FROM ubuntu:24.04" in dockerfile_content
assert "WORKDIR /app" in dockerfile_content
assert "RUN apt-get update" in dockerfile_content
assert "COPY . /app/" in dockerfile_content

View File

@@ -0,0 +1,51 @@
import shutil
import tempfile
from pathlib import Path
import pytest
from prometheus.docker.user_defined_container import UserDefinedContainer
@pytest.fixture
def temp_project_dir():
# Create a temporary directory with some test files
temp_dir = Path(tempfile.mkdtemp())
test_file = temp_dir / "test.txt"
test_file.write_text("test content")
yield temp_dir
# Cleanup
shutil.rmtree(temp_dir)
@pytest.fixture
def container(temp_project_dir):
return UserDefinedContainer(
temp_project_dir,
"/app",
"FROM python:3.9\nWORKDIR /app\nCOPY . /app/",
None,
["pip install -r requirements.txt", "python setup.py build"],
["pytest tests/"],
)
def test_initialization(container, temp_project_dir):
"""Test that the container is initialized correctly"""
assert isinstance(container.tag_name, str)
assert container.tag_name.startswith("prometheus_user_defined_container_")
assert container.project_path != temp_project_dir
assert (container.project_path / "test.txt").exists()
def test_get_dockerfile_content(container):
dockerfile_content = container.get_dockerfile_content()
assert dockerfile_content
# Check for key elements in the Dockerfile
assert "FROM python:3.9" in dockerfile_content
assert "WORKDIR /app" in dockerfile_content
assert "COPY . /app/" in dockerfile_content

View File

View File

@@ -0,0 +1,164 @@
import sys
from pathlib import Path
from unittest import mock
import pytest
from prometheus.git.git_repository import GitRepository
from tests.test_utils import test_project_paths
from tests.test_utils.fixtures import git_repo_fixture # noqa: F401
@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Test fails on Windows because of cptree in git_repo_fixture",
)
@pytest.mark.git
def test_init_with_https_url(git_repo_fixture): # noqa: F811
with mock.patch("git.Repo.clone_from") as mock_clone_from, mock.patch("shutil.rmtree"):
repo = git_repo_fixture
mock_clone_from.return_value = repo
access_token = "access_token"
https_url = "https://github.com/foo/bar.git"
target_directory = test_project_paths.TEST_PROJECT_PATH
git_repo = GitRepository()
git_repo.from_clone_repository(
https_url=https_url, target_directory=target_directory, github_access_token=access_token
)
mock_clone_from.assert_called_once_with(
f"https://{access_token}@github.com/foo/bar.git", target_directory / "bar"
)
@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Test fails on Windows because of cptree in git_repo_fixture",
)
@pytest.mark.git
def test_checkout_commit(git_repo_fixture): # noqa: F811
git_repo = GitRepository()
git_repo.from_local_repository(git_repo_fixture.working_dir)
commit_sha = "293551b7bd9572b63018c9ed2bccea0f37726805"
assert git_repo.repo.head.commit.hexsha != commit_sha
git_repo.checkout_commit(commit_sha)
assert git_repo.repo.head.commit.hexsha == commit_sha
@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Test fails on Windows because of cptree in git_repo_fixture",
)
@pytest.mark.git
def test_switch_branch(git_repo_fixture): # noqa: F811
git_repo = GitRepository()
git_repo.from_local_repository(git_repo_fixture.working_dir)
branch_name = "dev"
assert git_repo.repo.active_branch.name != branch_name
git_repo.switch_branch(branch_name)
assert git_repo.repo.active_branch.name == branch_name
@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Test fails on Windows because of cptree in git_repo_fixture",
)
@pytest.mark.git
def test_get_diff(git_repo_fixture): # noqa: F811
local_path = Path(git_repo_fixture.working_dir).absolute()
test_file = local_path / "test.c"
# Initialize repository
git_repo = GitRepository()
git_repo.from_local_repository(local_path)
# Create a change by modifying test.c
original_content = test_file.read_text()
new_content = "int main() { return 0; }\n"
test_file.write_text(new_content)
# Get diff without exclusions
diff = git_repo.get_diff()
assert diff is not None
expected_diff = """\
diff --git a/test.c b/test.c
index 79a1160..76e8197 100644
--- a/test.c
+++ b/test.c
@@ -1,6 +1 @@
-#include <stdio.h>
-
-int main() {
- printf("Hello world!");
- return 0;
-}
\ No newline at end of file
+int main() { return 0; }
"""
assert diff == expected_diff
# Test with excluded files
diff_with_exclusion = git_repo.get_diff(excluded_files=["test.c"])
assert diff_with_exclusion == ""
# Cleanup - restore original content
test_file.write_text(original_content)
@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Test fails on Windows because of cptree in git_repo_fixture",
)
@pytest.mark.git
def test_apply_patch(git_repo_fixture): # noqa: F811
local_path = Path(git_repo_fixture.working_dir).absolute()
# Initialize repository
git_repo = GitRepository()
git_repo.from_local_repository(local_path)
# Apply a patch that modifies test.c
patch = """\
diff --git a/test.c b/test.c
--- a/test.c
+++ b/test.c
@@ -1,6 +1,1 @@
-#include <stdio.h>
-
-int main() {
- printf("Hello world!");
- return 0;
-}
\ No newline at end of file
+int main() { return 0; }
\ No newline at end of file
"""
git_repo.apply_patch(patch)
# Verify the change
test_file = local_path / "test.c"
assert test_file.exists()
assert test_file.read_text() == "int main() { return 0; }"
@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Test fails on Windows because of cptree in git_repo_fixture",
)
@pytest.mark.git
def test_remove_repository(git_repo_fixture): # noqa: F811
with mock.patch("shutil.rmtree") as mock_rmtree:
local_path = git_repo_fixture.working_dir
git_repo = GitRepository()
git_repo.from_local_repository(local_path)
assert git_repo.repo is not None
git_repo.remove_repository()
mock_rmtree.assert_called_once_with(local_path)
assert git_repo.repo is None

View File

View File

@@ -0,0 +1,105 @@
from prometheus.graph.file_graph_builder import FileGraphBuilder
from prometheus.graph.graph_types import (
ASTNode,
KnowledgeGraphEdgeType,
KnowledgeGraphNode,
TextNode,
)
from tests.test_utils import test_project_paths
def test_supports_file():
file_graph_builder = FileGraphBuilder(0, 0, 0)
assert file_graph_builder.supports_file(test_project_paths.C_FILE)
assert file_graph_builder.supports_file(test_project_paths.JAVA_FILE)
assert file_graph_builder.supports_file(test_project_paths.MD_FILE)
assert file_graph_builder.supports_file(test_project_paths.PYTHON_FILE)
assert file_graph_builder.supports_file(test_project_paths.DUMMY_FILE) is False
def test_build_python_file_graph():
file_graph_builder = FileGraphBuilder(1000, 1000, 100)
parent_kg_node = KnowledgeGraphNode(0, None)
next_node_id, kg_nodes, kg_edges = file_graph_builder.build_file_graph(
parent_kg_node, test_project_paths.PYTHON_FILE, 0
)
assert next_node_id == 11
assert len(kg_nodes) == 11
assert len(kg_edges) == 11
# Test if some of the nodes exists
argument_list_ast_node = ASTNode(
type="argument_list", start_line=1, end_line=1, text='("Hello world!")'
)
string_ast_node = ASTNode(type="string", start_line=1, end_line=1, text='"Hello world!"')
found_argument_list_ast_node = False
for kg_node in kg_nodes:
if kg_node.node == argument_list_ast_node:
found_argument_list_ast_node = True
assert found_argument_list_ast_node
found_string_ast_node = False
for kg_node in kg_nodes:
if kg_node.node == string_ast_node:
found_string_ast_node = True
assert found_string_ast_node
# Test if some of the edges exists
found_edge = False
for kg_edge in kg_edges:
if (
kg_edge.source.node == argument_list_ast_node
and kg_edge.target.node == string_ast_node
and kg_edge.type == KnowledgeGraphEdgeType.parent_of
):
found_edge = True
assert found_edge
def test_build_text_file_graph():
file_graph_builder = FileGraphBuilder(1000, 100, 10)
parent_kg_node = KnowledgeGraphNode(0, None)
next_node_id, kg_nodes, kg_edges = file_graph_builder.build_file_graph(
parent_kg_node, test_project_paths.MD_FILE, 0
)
assert next_node_id == 2
assert len(kg_nodes) == 2
assert len(kg_edges) == 3
# Test if some of the nodes exists
text_node_1 = TextNode(
text="# A\n\nText under header A.\n\n## B\n\nText under header B.\n\n## C\n\nText under header C.\n\n### D",
start_line=1,
end_line=13,
)
text_node_2 = TextNode(text="### D\n\nText under header D.", start_line=13, end_line=15)
found_text_node_1 = False
for kg_node in kg_nodes:
if kg_node.node == text_node_1:
found_text_node_1 = True
assert found_text_node_1
found_text_node_2 = False
for kg_node in kg_nodes:
if kg_node.node == text_node_2:
found_text_node_2 = True
assert found_text_node_2
# Test if some of the edges exists
found_edge = False
for kg_edge in kg_edges:
if (
kg_edge.source.node == text_node_1
and kg_edge.target.node == text_node_2
and kg_edge.type == KnowledgeGraphEdgeType.next_chunk
):
found_edge = True
assert found_edge

View File

@@ -0,0 +1,272 @@
from prometheus.graph.graph_types import (
ASTNode,
FileNode,
KnowledgeGraphEdge,
KnowledgeGraphEdgeType,
KnowledgeGraphNode,
Neo4jASTNode,
Neo4jFileNode,
Neo4jTextNode,
TextNode,
)
def test_to_neo4j_file_node():
basename = "foo"
relative_path = "foo/bar/baz.py"
node_id = 1
file_node = FileNode(basename, relative_path)
knowldege_graph_node = KnowledgeGraphNode(node_id, file_node)
neo4j_file_node = knowldege_graph_node.to_neo4j_node()
assert isinstance(neo4j_file_node, dict)
assert "node_id" in neo4j_file_node
assert "basename" in neo4j_file_node
assert "relative_path" in neo4j_file_node
assert neo4j_file_node["node_id"] == node_id
assert neo4j_file_node["basename"] == basename
assert neo4j_file_node["relative_path"] == relative_path
def test_to_neo4j_ast_node():
type = "method_declaration"
start_line = 1
end_line = 5
text = "print('Hello world')"
node_id = 1
ast_node = ASTNode(type, start_line, end_line, text)
knowldege_graph_node = KnowledgeGraphNode(node_id, ast_node)
neo4j_ast_node = knowldege_graph_node.to_neo4j_node()
assert isinstance(neo4j_ast_node, dict)
assert "node_id" in neo4j_ast_node
assert "type" in neo4j_ast_node
assert "start_line" in neo4j_ast_node
assert "end_line" in neo4j_ast_node
assert "text" in neo4j_ast_node
assert neo4j_ast_node["node_id"] == node_id
assert neo4j_ast_node["type"] == type
assert neo4j_ast_node["start_line"] == start_line
assert neo4j_ast_node["end_line"] == end_line
assert neo4j_ast_node["text"] == text
def test_to_neo4j_text_node():
text = "Hello world"
node_id = 1
start_line = 1
end_line = 1
text_node = TextNode(text, start_line, end_line)
knowldege_graph_node = KnowledgeGraphNode(node_id, text_node)
neo4j_text_node = knowldege_graph_node.to_neo4j_node()
assert isinstance(neo4j_text_node, dict)
assert "node_id" in neo4j_text_node
assert "text" in neo4j_text_node
assert "start_line" in neo4j_text_node
assert "end_line" in neo4j_text_node
assert neo4j_text_node["node_id"] == node_id
assert neo4j_text_node["text"] == text
assert neo4j_text_node["start_line"] == start_line
assert neo4j_text_node["end_line"] == end_line
def test_to_neo4j_has_file_edge():
source_basename = "source"
source_relative_path = "foo/bar/source.py"
source_node_id = 1
target_basename = "target"
target_relative_path = "foo/bar/target.py"
target_node_id = 10
source_file_node = FileNode(source_basename, source_relative_path)
source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_file_node)
target_file_node = FileNode(target_basename, target_relative_path)
target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_file_node)
knowledge_graph_edge = KnowledgeGraphEdge(
source_knowledge_graph_node,
target_knowledge_graph_node,
KnowledgeGraphEdgeType.has_file,
)
neo4j_has_file_edge = knowledge_graph_edge.to_neo4j_edge()
assert isinstance(neo4j_has_file_edge, dict)
assert "source" in neo4j_has_file_edge
assert "target" in neo4j_has_file_edge
assert neo4j_has_file_edge["source"] == source_knowledge_graph_node.to_neo4j_node()
assert neo4j_has_file_edge["target"] == target_knowledge_graph_node.to_neo4j_node()
def test_to_neo4j_has_ast_edge():
source_basename = "source"
source_relative_path = "foo/bar/source.py"
source_node_id = 1
target_type = "return_statement"
target_start_line = 7
target_end_line = 9
target_text = "return True"
target_node_id = 10
source_file_node = FileNode(source_basename, source_relative_path)
source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_file_node)
target_ast_node = ASTNode(target_type, target_start_line, target_end_line, target_text)
target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_ast_node)
knowledge_graph_edge = KnowledgeGraphEdge(
source_knowledge_graph_node,
target_knowledge_graph_node,
KnowledgeGraphEdgeType.has_ast,
)
neo4j_has_ast_edge = knowledge_graph_edge.to_neo4j_edge()
assert isinstance(neo4j_has_ast_edge, dict)
assert "source" in neo4j_has_ast_edge
assert "target" in neo4j_has_ast_edge
assert neo4j_has_ast_edge["source"] == source_knowledge_graph_node.to_neo4j_node()
assert neo4j_has_ast_edge["target"] == target_knowledge_graph_node.to_neo4j_node()
def test_to_neo4j_parent_of_edge():
source_type = "method_declaration"
source_start_line = 1
source_end_line = 5
source_text = "print('Hello world')"
source_node_id = 1
target_type = "return_statement"
target_start_line = 7
target_end_line = 9
target_text = "return True"
target_node_id = 10
source_ast_node = ASTNode(source_type, source_start_line, source_end_line, source_text)
source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_ast_node)
target_ast_node = ASTNode(target_type, target_start_line, target_end_line, target_text)
target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_ast_node)
knowledge_graph_edge = KnowledgeGraphEdge(
source_knowledge_graph_node,
target_knowledge_graph_node,
KnowledgeGraphEdgeType.parent_of,
)
neo4j_parent_of_edge = knowledge_graph_edge.to_neo4j_edge()
assert isinstance(neo4j_parent_of_edge, dict)
assert "source" in neo4j_parent_of_edge
assert "target" in neo4j_parent_of_edge
assert neo4j_parent_of_edge["source"] == source_knowledge_graph_node.to_neo4j_node()
assert neo4j_parent_of_edge["target"] == target_knowledge_graph_node.to_neo4j_node()
def test_to_neo4j_has_text_edge():
source_basename = "source"
source_relative_path = "foo/bar/source.py"
source_node_id = 1
target_text = "Hello world"
target_start_line = 1
target_end_line = 1
target_node_id = 10
source_file_node = FileNode(source_basename, source_relative_path)
source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_file_node)
target_text_node = TextNode(target_text, target_start_line, target_end_line)
target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_text_node)
knowledge_graph_edge = KnowledgeGraphEdge(
source_knowledge_graph_node,
target_knowledge_graph_node,
KnowledgeGraphEdgeType.has_text,
)
neo4j_has_text_edge = knowledge_graph_edge.to_neo4j_edge()
assert isinstance(neo4j_has_text_edge, dict)
assert "source" in neo4j_has_text_edge
assert "target" in neo4j_has_text_edge
assert neo4j_has_text_edge["source"] == source_knowledge_graph_node.to_neo4j_node()
assert neo4j_has_text_edge["target"] == target_knowledge_graph_node.to_neo4j_node()
def test_to_neo4j_next_chunk_edge():
source_text = "Hello"
start_line = 1
end_line = 1
source_node_id = 1
target_text = "world"
target_start_line = 1
target_end_line = 1
target_node_id = 10
source_text_node = TextNode(source_text, start_line, end_line)
source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_text_node)
target_text_node = TextNode(target_text, target_start_line, target_end_line)
target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_text_node)
knowledge_graph_edge = KnowledgeGraphEdge(
source_knowledge_graph_node,
target_knowledge_graph_node,
KnowledgeGraphEdgeType.has_text,
)
neo4j_next_chunk_edge = knowledge_graph_edge.to_neo4j_edge()
assert isinstance(neo4j_next_chunk_edge, dict)
assert "source" in neo4j_next_chunk_edge
assert "target" in neo4j_next_chunk_edge
assert neo4j_next_chunk_edge["source"] == source_knowledge_graph_node.to_neo4j_node()
assert neo4j_next_chunk_edge["target"] == target_knowledge_graph_node.to_neo4j_node()
def test_from_neo4j_file_node():
node_id = 10
basename = "foo.py"
relative_path = "bar/baz/foo.py"
neo4j_file_node = Neo4jFileNode(node_id=node_id, basename=basename, relative_path=relative_path)
knowledge_graph_node = KnowledgeGraphNode.from_neo4j_file_node(neo4j_file_node)
expected_file_node = FileNode(basename, relative_path)
expected_knowledge_graph_node = KnowledgeGraphNode(node_id, expected_file_node)
assert knowledge_graph_node == expected_knowledge_graph_node
def test_from_neo4j_ast_node():
node_id = 15
type = "string_literal"
start_line = 5
end_line = 6
text = '"hello world"'
neo4j_ast_node = Neo4jASTNode(
node_id=node_id, type=type, start_line=start_line, end_line=end_line, text=text
)
knowledge_graph_node = KnowledgeGraphNode.from_neo4j_ast_node(neo4j_ast_node)
expected_ast_node = ASTNode(type, start_line, end_line, text)
expected_knowledge_graph_node = KnowledgeGraphNode(node_id, expected_ast_node)
assert knowledge_graph_node == expected_knowledge_graph_node
def test_from_neo4j_text_node():
node_id = 20
text = "hello world"
start_line = 1
end_line = 1
neo4j_text_node = Neo4jTextNode(
node_id=node_id, text=text, start_line=start_line, end_line=end_line
)
knowledge_graph_node = KnowledgeGraphNode.from_neo4j_text_node(neo4j_text_node)
expected_text_node = TextNode(text, start_line, end_line)
expected_knowledge_graph_node = KnowledgeGraphNode(node_id, expected_text_node)
assert knowledge_graph_node == expected_knowledge_graph_node

View File

@@ -0,0 +1,92 @@
import pytest
from prometheus.app.services.neo4j_service import Neo4jService
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.neo4j.knowledge_graph_handler import KnowledgeGraphHandler
from tests.test_utils import test_project_paths
from tests.test_utils.fixtures import (
NEO4J_PASSWORD,
NEO4J_USERNAME,
neo4j_container_with_kg_fixture, # noqa: F401
)
@pytest.fixture
async def mock_neo4j_service(neo4j_container_with_kg_fixture): # noqa: F811
"""Fixture: provide a clean DatabaseService using the Postgres test container."""
neo4j_container, kg = neo4j_container_with_kg_fixture
service = Neo4jService(neo4j_container.get_connection_url(), NEO4J_USERNAME, NEO4J_PASSWORD)
service.start()
yield service, kg
await service.close()
async def test_build_graph():
knowledge_graph = KnowledgeGraph(1, 1000, 100, 0)
await knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH)
assert knowledge_graph._next_node_id == 15
# 7 FileNode
# 84 ASTnode
# 2 TextNode
assert len(knowledge_graph._knowledge_graph_nodes) == 15
assert len(knowledge_graph._knowledge_graph_edges) == 14
assert len(knowledge_graph.get_file_nodes()) == 7
assert len(knowledge_graph.get_ast_nodes()) == 7
assert len(knowledge_graph.get_text_nodes()) == 1
assert len(knowledge_graph.get_parent_of_edges()) == 4
assert len(knowledge_graph.get_has_file_edges()) == 6
assert len(knowledge_graph.get_has_ast_edges()) == 3
assert len(knowledge_graph.get_has_text_edges()) == 1
assert len(knowledge_graph.get_next_chunk_edges()) == 0
async def test_get_file_tree():
knowledge_graph = KnowledgeGraph(1000, 1000, 100, 0)
await knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH)
file_tree = knowledge_graph.get_file_tree()
expected_file_tree = """\
test_project
├── bar
| ├── test.java
| └── test.py
├── foo
| └── test.md
└── test.c"""
assert file_tree == expected_file_tree
async def test_get_file_tree_depth_one():
knowledge_graph = KnowledgeGraph(1000, 1000, 100, 0)
await knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH)
file_tree = knowledge_graph.get_file_tree(max_depth=1)
expected_file_tree = """\
test_project
├── bar
├── foo
└── test.c"""
assert file_tree == expected_file_tree
async def test_get_file_tree_depth_two_max_seven_lines():
knowledge_graph = KnowledgeGraph(1000, 1000, 100, 0)
await knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH)
file_tree = knowledge_graph.get_file_tree(max_depth=2, max_lines=6)
expected_file_tree = """\
test_project
├── bar
| ├── test.java
| └── test.py
├── foo
| └── test.md"""
assert file_tree == expected_file_tree
@pytest.mark.slow
async def test_from_neo4j(mock_neo4j_service):
service, kg = mock_neo4j_service
handler = KnowledgeGraphHandler(service.neo4j_driver, 100)
read_kg = await handler.read_knowledge_graph(0, 1000, 100, 10)
assert read_kg == kg

View File

View File

@@ -0,0 +1,60 @@
from unittest.mock import Mock
import pytest
from langchain_core.language_models.chat_models import BaseChatModel
from prometheus.docker.base_container import BaseContainer
from prometheus.git.git_repository import GitRepository
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.graphs.issue_graph import IssueGraph
@pytest.fixture
def mock_advanced_model():
return Mock(spec=BaseChatModel)
@pytest.fixture
def mock_base_model():
return Mock(spec=BaseChatModel)
@pytest.fixture
def mock_kg():
kg = Mock(spec=KnowledgeGraph)
kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"]
kg.root_node_id = 0
return kg
@pytest.fixture
def mock_git_repo():
git_repo = Mock(spec=GitRepository)
git_repo.playground_path = "mock/playground/path"
return git_repo
@pytest.fixture
def mock_container():
return Mock(spec=BaseContainer)
def test_issue_graph_basic_initialization(
mock_advanced_model,
mock_base_model,
mock_kg,
mock_git_repo,
mock_container,
):
"""Test that IssueGraph initializes correctly with basic components."""
graph = IssueGraph(
advanced_model=mock_advanced_model,
base_model=mock_base_model,
kg=mock_kg,
git_repo=mock_git_repo,
container=mock_container,
repository_id=1,
)
assert graph.graph is not None
assert graph.git_repo == mock_git_repo

View File

@@ -0,0 +1,137 @@
from langchain_core.messages import HumanMessage
from prometheus.lang_graph.nodes.add_context_refined_query_message_node import (
AddContextRefinedQueryMessageNode,
)
from prometheus.models.query import Query
def test_add_context_refined_query_message_node_with_all_fields():
"""Test node with all fields populated in refined_query."""
node = AddContextRefinedQueryMessageNode()
refined_query = Query(
essential_query="Find all authentication logic",
extra_requirements="Include error handling and validation",
purpose="Security audit",
)
state = {"refined_query": refined_query}
result = node(state)
assert "context_provider_messages" in result
assert len(result["context_provider_messages"]) == 1
assert isinstance(result["context_provider_messages"][0], HumanMessage)
message_content = result["context_provider_messages"][0].content
assert "Essential query: Find all authentication logic" in message_content
assert "Extra requirements: Include error handling and validation" in message_content
assert "Purpose: Security audit" in message_content
def test_add_context_refined_query_message_node_essential_query_only():
"""Test node with only essential_query populated."""
node = AddContextRefinedQueryMessageNode()
refined_query = Query(
essential_query="Locate the main entry point",
extra_requirements="",
purpose="",
)
state = {"refined_query": refined_query}
result = node(state)
assert "context_provider_messages" in result
assert len(result["context_provider_messages"]) == 1
message_content = result["context_provider_messages"][0].content
assert "Essential query: Locate the main entry point" in message_content
assert "Extra requirements:" not in message_content
assert "Purpose:" not in message_content
def test_add_context_refined_query_message_node_with_extra_requirements_only():
"""Test node with essential_query and extra_requirements only."""
node = AddContextRefinedQueryMessageNode()
refined_query = Query(
essential_query="Find database queries",
extra_requirements="Focus on SQL injection vulnerabilities",
purpose="",
)
state = {"refined_query": refined_query}
result = node(state)
assert "context_provider_messages" in result
message_content = result["context_provider_messages"][0].content
assert "Essential query: Find database queries" in message_content
assert "Extra requirements: Focus on SQL injection vulnerabilities" in message_content
assert "Purpose:" not in message_content
def test_add_context_refined_query_message_node_with_purpose_only():
"""Test node with essential_query and purpose only."""
node = AddContextRefinedQueryMessageNode()
refined_query = Query(
essential_query="Identify all API endpoints",
extra_requirements="",
purpose="Documentation generation",
)
state = {"refined_query": refined_query}
result = node(state)
assert "context_provider_messages" in result
message_content = result["context_provider_messages"][0].content
assert "Essential query: Identify all API endpoints" in message_content
assert "Extra requirements:" not in message_content
assert "Purpose: Documentation generation" in message_content
def test_add_context_refined_query_message_node_returns_list():
"""Test that node returns a list with exactly one HumanMessage."""
node = AddContextRefinedQueryMessageNode()
refined_query = Query(
essential_query="Test query",
extra_requirements="Test requirements",
purpose="Test purpose",
)
state = {"refined_query": refined_query}
result = node(state)
assert isinstance(result["context_provider_messages"], list)
assert len(result["context_provider_messages"]) == 1
assert isinstance(result["context_provider_messages"][0], HumanMessage)
def test_add_context_refined_query_message_node_message_format():
"""Test the exact format of the constructed message."""
node = AddContextRefinedQueryMessageNode()
refined_query = Query(
essential_query="Query text",
extra_requirements="Requirements text",
purpose="Purpose text",
)
state = {"refined_query": refined_query}
result = node(state)
expected_content = (
"Essential query: Query text\nExtra requirements: Requirements text\nPurpose: Purpose text"
)
assert result["context_provider_messages"][0].content == expected_content

View File

@@ -0,0 +1,52 @@
from unittest.mock import Mock
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from prometheus.docker.base_container import BaseContainer
from prometheus.lang_graph.nodes.bug_fix_verify_node import BugFixVerifyNode
from prometheus.lang_graph.subgraphs.bug_fix_verification_state import BugFixVerificationState
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_container():
return Mock(spec=BaseContainer)
@pytest.fixture
def test_state():
return BugFixVerificationState(
{
"reproduced_bug_file": "test_bug.py",
"reproduced_bug_commands": ["python test_bug.py", "./run_test.sh"],
"bug_fix_verify_messages": [AIMessage(content="Previous verification result")],
}
)
@pytest.fixture
def fake_llm():
return FakeListChatWithToolsModel(responses=["Test execution completed"])
def test_format_human_message(mock_container, fake_llm, test_state):
"""Test human message formatting."""
node = BugFixVerifyNode(fake_llm, mock_container)
message = node.format_human_message(test_state)
assert isinstance(message, HumanMessage)
assert "test_bug.py" in message.content
assert "python test_bug.py" in message.content
assert "./run_test.sh" in message.content
def test_call_method(mock_container, fake_llm, test_state):
"""Test the __call__ method execution."""
node = BugFixVerifyNode(fake_llm, mock_container)
result = node(test_state)
assert "bug_fix_verify_messages" in result
assert len(result["bug_fix_verify_messages"]) == 1
assert result["bug_fix_verify_messages"][0].content == "Test execution completed"

View File

@@ -0,0 +1,74 @@
from unittest.mock import Mock
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from prometheus.docker.base_container import BaseContainer
from prometheus.lang_graph.nodes.bug_reproducing_execute_node import BugReproducingExecuteNode
from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_container():
return Mock(spec=BaseContainer)
@pytest.fixture
def test_state():
return BugReproductionState(
{
"issue_title": "Test Bug",
"issue_body": "Bug description",
"issue_comments": ["Comment 1", "Comment 2"],
"bug_context": "Context of the bug",
"bug_reproducing_write_messages": [AIMessage(content="patch")],
"bug_reproducing_file_messages": [AIMessage(content="path")],
"bug_reproducing_execute_messages": [],
"bug_reproducing_patch": "--- /dev/null\n+++ b/newfile\n@@ -0,0 +1 @@\n+content",
}
)
def test_format_human_message_with_test_commands(mock_container, test_state):
"""Test message formatting with provided test commands."""
fake_llm = FakeListChatWithToolsModel(responses=["test"])
test_commands = ["pytest", "python -m unittest"]
node = BugReproducingExecuteNode(fake_llm, mock_container, test_commands)
message = node.format_human_message(test_state, "/foo/bar/test.py")
assert isinstance(message, HumanMessage)
assert "Test Bug" in message.content
assert "Bug description" in message.content
assert "Comment 1" in message.content
assert "Comment 2" in message.content
assert "pytest" in message.content
assert "/foo/bar/test.py" in message.content
def test_format_human_message_without_test_commands(mock_container, test_state):
"""Test message formatting without test commands."""
fake_llm = FakeListChatWithToolsModel(responses=["test"])
node = BugReproducingExecuteNode(fake_llm, mock_container)
message = node.format_human_message(test_state, "/foo/bar/test.py")
assert isinstance(message, HumanMessage)
assert "Test Bug" in message.content
assert "Bug description" in message.content
assert "User provided test commands:\n" in message.content
assert "/foo/bar/test.py" in message.content
def test_call_method(mock_container, test_state):
"""Test the __call__ method execution."""
fake_response = "Test execution completed"
fake_llm = FakeListChatWithToolsModel(responses=[fake_response])
node = BugReproducingExecuteNode(fake_llm, mock_container)
result = node(test_state)
assert "bug_reproducing_execute_messages" in result
assert len(result["bug_reproducing_execute_messages"]) == 1
assert result["bug_reproducing_execute_messages"][0].content == fake_response

View File

@@ -0,0 +1,68 @@
from unittest.mock import Mock
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.nodes.bug_reproducing_file_node import BugReproducingFileNode
from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_kg():
kg = Mock(spec=KnowledgeGraph)
kg.get_file_tree.return_value = "test_dir/\n test_file.py"
return kg
@pytest.fixture
def fake_llm():
return FakeListChatWithToolsModel(responses=["test_output.py"])
@pytest.fixture
def basic_state():
return BugReproductionState(
issue_title="mock issue title",
issue_body="mock issue body",
issue_comments=[],
max_refined_query_loop=3,
bug_reproducing_query="mock query",
bug_reproducing_context=[],
bug_reproducing_patch="",
bug_reproducing_write_messages=[AIMessage(content="def test_bug():\n assert 1 == 2")],
bug_reproducing_file_messages=[],
bug_reproducing_execute_messages=[],
reproduced_bug=False,
reproduced_bug_failure_log="",
reproduced_bug_file="",
reproduced_bug_commands=[],
)
def test_initialization(mock_kg, fake_llm):
"""Test basic initialization of BugReproducingFileNode."""
node = BugReproducingFileNode(fake_llm, mock_kg, "test/path")
assert isinstance(node.system_prompt, SystemMessage)
assert len(node.tools) == 2 # read_file, create_file
def test_format_human_message(mock_kg, fake_llm, basic_state):
"""Test human message formatting with bug file."""
node = BugReproducingFileNode(fake_llm, mock_kg, "test/path")
message = node.format_human_message(basic_state)
assert isinstance(message, HumanMessage)
assert "def test_bug():" in message.content
def test_call_method(mock_kg, fake_llm, basic_state):
"""Test the __call__ method execution."""
node = BugReproducingFileNode(fake_llm, mock_kg, "test/path")
result = node(basic_state)
assert "bug_reproducing_file_messages" in result
assert len(result["bug_reproducing_file_messages"]) == 1
assert result["bug_reproducing_file_messages"][0].content == "test_output.py"

View File

@@ -0,0 +1,49 @@
from unittest.mock import Mock
import pytest
from langchain_core.messages import HumanMessage
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.nodes.bug_reproducing_write_node import BugReproducingWriteNode
from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState
from tests.test_utils.fixtures import temp_test_dir # noqa: F401
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_kg():
kg = Mock(spec=KnowledgeGraph)
return kg
@pytest.fixture
def test_state():
return BugReproductionState(
issue_title="Test Bug",
issue_body="Bug description",
issue_comments=[{"user1": "Comment 1"}, {"user2": "Comment 2"}],
max_refined_query_loop=3,
bug_reproducing_query="mock query",
bug_reproducing_context=[],
bug_reproducing_write_messages=[HumanMessage("assert x == 10")],
bug_reproducing_file_messages=[],
bug_reproducing_execute_messages=[],
bug_reproducing_patch="",
reproduced_bug=False,
reproduced_bug_failure_log="Test failure log",
reproduced_bug_file="test/file.py",
reproduced_bug_commands=[],
)
def test_call_method(mock_kg, test_state, temp_test_dir): # noqa: F811
"""Test the __call__ method execution."""
fake_response = "Created test file"
fake_llm = FakeListChatWithToolsModel(responses=[fake_response])
node = BugReproducingWriteNode(fake_llm, temp_test_dir, mock_kg)
result = node(test_state)
assert "bug_reproducing_write_messages" in result
assert len(result["bug_reproducing_write_messages"]) == 1
assert result["bug_reproducing_write_messages"][0].content == fake_response

View File

@@ -0,0 +1,41 @@
import pytest
from langchain_core.messages import AIMessage, ToolMessage
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.nodes.context_provider_node import ContextProviderNode
from tests.test_utils import test_project_paths
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture(scope="function")
async def knowledge_graph_fixture():
kg = KnowledgeGraph(1000, 100, 10, 0)
await kg.build_graph(test_project_paths.TEST_PROJECT_PATH)
return kg
@pytest.mark.slow
async def test_context_provider_node_basic_query(knowledge_graph_fixture):
"""Test basic query handling with the ContextProviderNode."""
fake_response = "Fake response"
fake_llm = FakeListChatWithToolsModel(responses=[fake_response])
node = ContextProviderNode(
model=fake_llm,
kg=knowledge_graph_fixture,
local_path=test_project_paths.TEST_PROJECT_PATH,
)
test_messages = [
AIMessage(content="This code handles file processing"),
ToolMessage(content="Found implementation in utils.py", tool_call_id="test_tool_call_1"),
]
test_state = {
"original_query": "How does the error handling work?",
"context_provider_messages": test_messages,
}
result = node(test_state)
assert "context_provider_messages" in result
assert len(result["context_provider_messages"]) == 1
assert result["context_provider_messages"][0].content == fake_response

View File

@@ -0,0 +1,127 @@
from unittest.mock import patch
import pytest
from langchain_core.messages import HumanMessage
from prometheus.lang_graph.nodes.edit_message_node import EditMessageNode
from prometheus.models.context import Context
@pytest.fixture
def edit_node():
return EditMessageNode(
context_key="bug_fix_context", analyzer_message_key="issue_bug_analyzer_messages"
)
@pytest.fixture
def base_state():
return {
"issue_title": "Test Bug",
"issue_body": "This is a test bug description",
"issue_comments": [
{"username": "user1", "comment": "Comment 1"},
{"username": "user2", "comment": "Comment 2"},
],
"bug_fix_context": [
Context(
relative_path="foobar.py",
content="# Context 1",
start_line_number=1,
end_line_number=1,
)
],
"issue_bug_analyzer_messages": ["Analysis message"],
}
def test_first_message_formatting(edit_node, base_state):
# Using context managers for patching
with patch(
"prometheus.lang_graph.nodes.edit_message_node.format_issue_info"
) as mock_format_issue:
with patch(
"prometheus.lang_graph.nodes.edit_message_node.get_last_message_content"
) as mock_last_message:
mock_format_issue.return_value = "Formatted Issue Info"
mock_last_message.return_value = "Last Analysis Message"
result = edit_node(base_state)
assert isinstance(result, dict)
assert "edit_messages" in result
assert len(result["edit_messages"]) == 1
assert isinstance(result["edit_messages"][0], HumanMessage)
message_content = result["edit_messages"][0].content
assert "Formatted Issue Info" in message_content
assert "# Context 1" in message_content
assert "Last Analysis Message" in message_content
def test_followup_message_with_build_fail(edit_node, base_state):
# Add build failure to state
base_state["build_fail_log"] = "Build failed: error in compilation"
with patch(
"prometheus.lang_graph.nodes.edit_message_node.get_last_message_content"
) as mock_last_message:
mock_last_message.return_value = "Last Analysis Message"
result = edit_node(base_state)
message_content = result["edit_messages"][0].content
assert "Build failed: error in compilation" in message_content
assert "Please implement these revised changes carefully" in message_content
def test_followup_message_with_test_fail(edit_node, base_state):
# Add test failure to state
base_state["reproducing_test_fail_log"] = "Test failed: assertion error"
with patch(
"prometheus.lang_graph.nodes.edit_message_node.get_last_message_content"
) as mock_last_message:
mock_last_message.return_value = "Last Analysis Message"
result = edit_node(base_state)
message_content = result["edit_messages"][0].content
assert "Test failed: assertion error" in message_content
assert "Please implement these revised changes carefully" in message_content
def test_followup_message_with_existing_test_fail(edit_node, base_state):
# Add existing test failure to state
base_state["existing_test_fail_log"] = "Existing test failed"
with patch(
"prometheus.lang_graph.nodes.edit_message_node.get_last_message_content"
) as mock_last_message:
mock_last_message.return_value = "Last Analysis Message"
result = edit_node(base_state)
message_content = result["edit_messages"][0].content
assert "Existing test failed" in message_content
assert "Please implement these revised changes carefully" in message_content
def test_error_priority(edit_node, base_state):
# Add multiple error types to test priority handling
base_state["reproducing_test_fail_log"] = "Test failed"
base_state["build_fail_log"] = "Build failed"
base_state["existing_test_fail_log"] = "Existing test failed"
with patch(
"prometheus.lang_graph.nodes.edit_message_node.get_last_message_content"
) as mock_last_message:
mock_last_message.return_value = "Last Analysis Message"
result = edit_node(base_state)
message_content = result["edit_messages"][0].content
# Should prioritize reproducing test failure
assert "Test failed" in message_content
assert "Build failed" not in message_content
assert "Existing test failed" not in message_content

View File

@@ -0,0 +1,41 @@
from unittest.mock import Mock
import pytest
from langchain_core.messages import HumanMessage, SystemMessage
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.nodes.edit_node import EditNode
from tests.test_utils.fixtures import temp_test_dir # noqa: F401
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_kg():
kg = Mock(spec=KnowledgeGraph)
return kg
@pytest.fixture
def fake_llm():
return FakeListChatWithToolsModel(responses=["File edit completed successfully"])
def test_init_edit_node(mock_kg, fake_llm, temp_test_dir): # noqa: F811
"""Test EditNode initialization."""
node = EditNode(fake_llm, temp_test_dir, mock_kg)
assert isinstance(node.system_prompt, SystemMessage)
assert len(node.tools) == 5 # Should have 5 file operation tools
assert node.model_with_tools is not None
def test_call_method_basic(mock_kg, fake_llm, temp_test_dir): # noqa: F811
"""Test basic call functionality without tool execution."""
node = EditNode(fake_llm, temp_test_dir, mock_kg)
state = {"edit_messages": [HumanMessage(content="Make the following changes: ...")]}
result = node(state)
assert "edit_messages" in result
assert len(result["edit_messages"]) == 1
assert result["edit_messages"][0].content == "File edit completed successfully"

View File

@@ -0,0 +1,64 @@
from unittest.mock import Mock
import pytest
from langchain_core.messages import HumanMessage
from prometheus.docker.base_container import BaseContainer
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.nodes.general_build_node import GeneralBuildNode
from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_container():
return Mock(spec=BaseContainer)
@pytest.fixture
def mock_kg():
kg = Mock(spec=KnowledgeGraph)
kg.get_file_tree.return_value = ".\n├── src\n│ └── main.py\n└── build.gradle"
return kg
@pytest.fixture
def fake_llm():
return FakeListChatWithToolsModel(responses=["Build command executed successfully"])
def test_format_human_message_basic(mock_container, mock_kg, fake_llm):
"""Test basic human message formatting."""
node = GeneralBuildNode(fake_llm, mock_container, mock_kg)
state = BuildAndTestState({})
message = node.format_human_message(state)
assert isinstance(message, HumanMessage)
assert "project structure is:" in message.content
assert mock_kg.get_file_tree() in message.content
def test_format_human_message_with_build_summary(mock_container, mock_kg, fake_llm):
"""Test message formatting with build command summary."""
node = GeneralBuildNode(fake_llm, mock_container, mock_kg)
state = BuildAndTestState({"build_command_summary": "Previous build used gradle"})
message = node.format_human_message(state)
assert "Previous build used gradle" in message.content
assert "The previous build summary is:" in message.content
def test_call_method_with_no_build(mock_container, mock_kg, fake_llm):
"""Test __call__ method when exist_build is False."""
node = GeneralBuildNode(fake_llm, mock_container, mock_kg)
state = BuildAndTestState({"exist_build": False})
result = node(state)
assert "build_messages" in result
assert len(result["build_messages"]) == 1
assert (
"Previous agent determined there is no build system" in result["build_messages"][0].content
)

View File

@@ -0,0 +1,75 @@
from unittest.mock import Mock
import pytest
from langchain_core.messages import HumanMessage
from prometheus.docker.base_container import BaseContainer
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.nodes.general_test_node import GeneralTestNode
from prometheus.lang_graph.subgraphs.build_and_test_state import BuildAndTestState
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_container():
return Mock(spec=BaseContainer)
@pytest.fixture
def mock_kg():
kg = Mock(spec=KnowledgeGraph)
kg.get_file_tree.return_value = "./\n├── tests/\n│ └── test_main.py"
return kg
@pytest.fixture
def fake_llm():
return FakeListChatWithToolsModel(responses=["Tests executed successfully"])
@pytest.fixture
def basic_state():
return BuildAndTestState(
{"exist_test": True, "test_messages": [], "test_command_summary": None}
)
def test_format_human_message(mock_container, mock_kg, fake_llm, basic_state):
"""Test basic message formatting."""
node = GeneralTestNode(fake_llm, mock_container, mock_kg)
message = node.format_human_message(basic_state)
assert isinstance(message, HumanMessage)
assert mock_kg.get_file_tree() in message.content
def test_call_with_no_tests(mock_container, mock_kg, fake_llm):
"""Test behavior when no tests exist."""
node = GeneralTestNode(fake_llm, mock_container, mock_kg)
state = BuildAndTestState({"exist_test": False})
result = node(state)
assert "build_messages" in result
assert "no test framework" in result["build_messages"][0].content
def test_call_normal_execution(mock_container, mock_kg, fake_llm, basic_state):
"""Test normal execution flow."""
node = GeneralTestNode(fake_llm, mock_container, mock_kg)
result = node(basic_state)
assert "test_messages" in result
assert len(result["test_messages"]) == 1
assert result["test_messages"][0].content == "Tests executed successfully"
def test_format_human_message_with_summary(mock_container, mock_kg, fake_llm):
"""Test message formatting with test summary."""
node = GeneralTestNode(fake_llm, mock_container, mock_kg)
state = BuildAndTestState({"test_command_summary": "Previous test used pytest"})
message = node.format_human_message(state)
assert "Previous test used pytest" in message.content

View File

@@ -0,0 +1,35 @@
from unittest.mock import Mock
import pytest
from prometheus.git.git_repository import GitRepository
from prometheus.lang_graph.nodes.git_diff_node import GitDiffNode
@pytest.fixture
def mock_git_repo():
git_repo = Mock(spec=GitRepository)
git_repo.get_diff.return_value = "sample diff content"
return git_repo
def test_git_diff_node(mock_git_repo):
node = GitDiffNode(mock_git_repo, "patch")
# Execute
result = node({})
# Assert
assert result == {"patch": "sample diff content"}
mock_git_repo.get_diff.assert_called_with(None)
def test_git_diff_node_with_excluded_files(mock_git_repo):
node = GitDiffNode(mock_git_repo, "patch", "excluded_file")
# Execute
result = node({"excluded_file": "/foo/bar.py"})
# Assert
assert result == {"patch": "sample diff content"}
mock_git_repo.get_diff.assert_called_with(["/foo/bar.py"])

View File

@@ -0,0 +1,135 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.messages.tool import ToolCall
from prometheus.lang_graph.nodes.issue_bug_analyzer_node import IssueBugAnalyzerNode
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def fake_llm():
return FakeListChatWithToolsModel(responses=["Bug analysis completed successfully"])
@pytest.fixture
def fake_llm_with_tool_call():
"""LLM that simulates making a web_search tool call."""
return FakeListChatWithToolsModel(
responses=["I need to search for information about this error."]
)
def test_init_issue_bug_analyzer_node(fake_llm):
"""Test IssueBugAnalyzerNode initialization."""
node = IssueBugAnalyzerNode(fake_llm)
assert node.system_prompt is not None
assert len(node.tools) == 1 # Should have web_search tool
assert node.tools[0].name == "web_search"
assert node.model_with_tools is not None
def test_call_method_basic(fake_llm):
"""Test basic call functionality."""
node = IssueBugAnalyzerNode(fake_llm)
state = {"issue_bug_analyzer_messages": [HumanMessage(content="Please analyze this bug: ...")]}
result = node(state)
assert "issue_bug_analyzer_messages" in result
assert len(result["issue_bug_analyzer_messages"]) == 1
assert result["issue_bug_analyzer_messages"][0].content == "Bug analysis completed successfully"
def test_web_search_tool_integration(fake_llm_with_tool_call):
"""Test that the web_search tool is properly integrated and can be called."""
node = IssueBugAnalyzerNode(fake_llm_with_tool_call)
state = {
"issue_bug_analyzer_messages": [
HumanMessage(
content="I'm getting a ValueError in my Python code. Can you help analyze it?"
)
]
}
result = node(state)
# Verify the result contains the response message
assert "issue_bug_analyzer_messages" in result
assert len(result["issue_bug_analyzer_messages"]) == 1
assert (
result["issue_bug_analyzer_messages"][0].content
== "I need to search for information about this error."
)
def test_web_search_tool_call_with_correct_parameters(fake_llm):
"""Test that web_search tool has correct configuration and can be called."""
node = IssueBugAnalyzerNode(fake_llm)
# Test that the tool exists and has correct configuration
web_search_tool = node.tools[0]
assert web_search_tool.name == "web_search"
assert "technical information" in web_search_tool.description.lower()
# Test that the tool has the correct args schema
assert hasattr(web_search_tool, "args_schema")
assert web_search_tool.args_schema is not None
def test_system_prompt_contains_web_search_info(fake_llm):
"""Test that the system prompt mentions web_search tool."""
node = IssueBugAnalyzerNode(fake_llm)
system_prompt_content = node.system_prompt.content.lower()
assert "web_search" in system_prompt_content
assert "technical information" in system_prompt_content
def test_web_search_tool_schema_validation(fake_llm):
"""Test that the web_search tool has proper input validation."""
node = IssueBugAnalyzerNode(fake_llm)
web_search_tool = node.tools[0]
# Check that the tool has an args_schema
assert hasattr(web_search_tool, "args_schema")
assert web_search_tool.args_schema is not None
# Test with valid input
valid_input = {"query": "Python debugging techniques"}
# This should not raise an exception
validated_input = web_search_tool.args_schema(**valid_input)
assert validated_input.query == "Python debugging techniques"
def test_multiple_tool_calls_in_conversation(fake_llm):
"""Test handling multiple web_search calls in a conversation."""
node = IssueBugAnalyzerNode(fake_llm)
# Simulate a conversation with tool calls
state = {
"issue_bug_analyzer_messages": [
HumanMessage(content="Analyze this bug: ImportError in my application"),
AIMessage(
content="Let me search for information about this error.",
tool_calls=[
ToolCall(
name="web_search",
args={"query": "Python ImportError debugging"},
id="call_1",
)
],
),
ToolMessage(
content="Search results: ImportError occurs when...", tool_call_id="call_1"
),
HumanMessage(content="The error still persists after trying the suggested fixes"),
]
}
result = node(state)
assert "issue_bug_analyzer_messages" in result
assert len(result["issue_bug_analyzer_messages"]) == 1
# The new response should be added to the conversation
assert result["issue_bug_analyzer_messages"][0].content == "Bug analysis completed successfully"

View File

@@ -0,0 +1,113 @@
import pytest
from prometheus.lang_graph.nodes.issue_bug_responder_node import IssueBugResponderNode
from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def fake_llm():
return FakeListChatWithToolsModel(
responses=["Thank you for reporting this issue. The fix has been implemented and verified."]
)
@pytest.fixture
def basic_state():
return IssueBugState(
issue_title="Test Bug",
issue_body="Found a bug in the code",
issue_comments=[
{"username": "user1", "comment": "This affects my workflow"},
{"username": "user2", "comment": "Same issue here"},
],
edit_patch="Fixed array index calculation",
passed_reproducing_test=True,
passed_regression_test=True,
passed_existing_test=True,
run_build=True,
run_existing_test=True,
run_regression_test=True,
run_reproduce_test=True,
number_of_candidate_patch=6,
reproduced_bug=True,
reproduced_bug_file="mock.py",
reproduced_bug_patch="mock patch to reproduce the bug",
reproduced_bug_commands="pytest test_bug.py",
selected_regression_tests=["tests:tests"],
issue_response="Mock Response",
)
def test_format_human_message_basic(fake_llm, basic_state):
"""Test basic human message formatting."""
node = IssueBugResponderNode(fake_llm)
message = node.format_human_message(basic_state)
assert "Test Bug" in message.content
assert "Found a bug in the code" in message.content
assert "user1" in message.content
assert "user2" in message.content
assert "Fixed array index" in message.content
def test_format_human_message_verification(fake_llm, basic_state):
"""Test verification message formatting."""
node = IssueBugResponderNode(fake_llm)
message = node.format_human_message(basic_state)
assert "✓ The bug reproducing test passed" in message.content
assert "✓ All selected regression tests passes successfully" in message.content
assert "✓ All existing tests pass successfully" in message.content
def test_format_human_message_no_verification(fake_llm):
"""Test message formatting without verifications."""
state = IssueBugState(
issue_title="Test Bug",
issue_body="Bug description",
issue_comments=[],
edit_patch="Fixed array index calculation",
passed_reproducing_test=False,
passed_existing_test=False,
passed_regression_test=False,
)
node = IssueBugResponderNode(fake_llm)
message = node.format_human_message(state)
assert "✓ The bug reproducing test passed" not in message.content
assert "✓ All selected regression tests passes successfully" not in message.content
assert "✓ All existing tests pass successfully" not in message.content
def test_format_human_message_partial_verification(fake_llm):
"""Test message formatting with partial verifications."""
state = IssueBugState(
issue_title="Test Bug",
issue_body="Bug description",
issue_comments=[],
edit_patch="Fixed array index calculation",
passed_reproducing_test=True,
passed_existing_test=True,
passed_regression_test=True,
)
node = IssueBugResponderNode(fake_llm)
message = node.format_human_message(state)
assert "✓ The bug reproducing test passed" in message.content
assert "✓ Build passes successfully" not in message.content
assert "✓ All existing tests pass successfully" in message.content
def test_call_method(fake_llm, basic_state):
"""Test the call method execution."""
node = IssueBugResponderNode(fake_llm)
result = node(basic_state)
assert "issue_response" in result
assert (
result["issue_response"]
== "Thank you for reporting this issue. The fix has been implemented and verified."
)

View File

@@ -0,0 +1,98 @@
import pytest
from langchain_core.messages import HumanMessage
from prometheus.lang_graph.nodes.issue_documentation_analyzer_message_node import (
IssueDocumentationAnalyzerMessageNode,
)
from prometheus.lang_graph.subgraphs.issue_documentation_state import IssueDocumentationState
from prometheus.models.context import Context
@pytest.fixture
def basic_state():
return IssueDocumentationState(
issue_title="Update API documentation",
issue_body="The API documentation needs to be updated",
issue_comments=[
{"username": "user1", "comment": "Please add examples"},
],
max_refined_query_loop=3,
documentation_query="Find API docs",
documentation_context=[
Context(
relative_path="/docs/api.md",
content="# API Documentation\n\nAPI Documentation content",
)
],
issue_documentation_analyzer_messages=[],
edit_messages=[],
edit_patch="",
issue_response="",
)
def test_init_issue_documentation_analyzer_message_node():
"""Test IssueDocumentationAnalyzerMessageNode initialization."""
node = IssueDocumentationAnalyzerMessageNode()
assert node is not None
def test_call_method_creates_message(basic_state):
"""Test that the node creates a human message."""
node = IssueDocumentationAnalyzerMessageNode()
result = node(basic_state)
assert "issue_documentation_analyzer_messages" in result
assert len(result["issue_documentation_analyzer_messages"]) == 1
assert isinstance(result["issue_documentation_analyzer_messages"][0], HumanMessage)
def test_message_contains_issue_info(basic_state):
"""Test that the message contains issue information."""
node = IssueDocumentationAnalyzerMessageNode()
result = node(basic_state)
message_content = result["issue_documentation_analyzer_messages"][0].content
assert "Update API documentation" in message_content
def test_message_contains_context(basic_state):
"""Test that the message includes documentation context."""
node = IssueDocumentationAnalyzerMessageNode()
result = node(basic_state)
message_content = result["issue_documentation_analyzer_messages"][0].content
# Should include context or reference to it
assert "context" in message_content.lower() or "API Documentation" in message_content
def test_message_includes_analysis_instructions(basic_state):
"""Test that the message includes analysis instructions."""
node = IssueDocumentationAnalyzerMessageNode()
result = node(basic_state)
message_content = result["issue_documentation_analyzer_messages"][0].content
# Should include instructions for analysis
assert "plan" in message_content.lower() or "analyze" in message_content.lower()
def test_call_with_empty_context():
"""Test the node with empty documentation context."""
state = IssueDocumentationState(
issue_title="Create new docs",
issue_body="Create documentation for new feature",
issue_comments=[],
max_refined_query_loop=3,
documentation_query="",
documentation_context=[],
issue_documentation_analyzer_messages=[],
edit_messages=[],
edit_patch="",
issue_response="",
)
node = IssueDocumentationAnalyzerMessageNode()
result = node(state)
assert "issue_documentation_analyzer_messages" in result
assert len(result["issue_documentation_analyzer_messages"]) == 1

View File

@@ -0,0 +1,81 @@
import pytest
from langchain_core.messages import HumanMessage
from prometheus.lang_graph.nodes.issue_documentation_analyzer_node import (
IssueDocumentationAnalyzerNode,
)
from prometheus.lang_graph.subgraphs.issue_documentation_state import IssueDocumentationState
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def fake_llm():
return FakeListChatWithToolsModel(
responses=[
"Documentation Plan:\n1. Update README.md with new API documentation\n"
"2. Add code examples\n3. Update table of contents"
]
)
@pytest.fixture
def basic_state():
return IssueDocumentationState(
issue_title="Update API documentation",
issue_body="The API documentation needs to be updated with the new endpoints",
issue_comments=[
{"username": "user1", "comment": "Please include examples"},
{"username": "user2", "comment": "Add authentication details"},
],
max_refined_query_loop=3,
documentation_query="Find API documentation files",
documentation_context=[],
issue_documentation_analyzer_messages=[
HumanMessage(content="Please analyze the documentation request and provide a plan")
],
edit_messages=[],
edit_patch="",
issue_response="",
)
def test_init_issue_documentation_analyzer_node(fake_llm):
"""Test IssueDocumentationAnalyzerNode initialization."""
node = IssueDocumentationAnalyzerNode(fake_llm)
assert node.system_prompt is not None
assert node.web_search_tool is not None
assert len(node.tools) == 1 # Should have web search tool
assert node.model_with_tools is not None
def test_call_method_basic(fake_llm, basic_state):
"""Test basic call functionality."""
node = IssueDocumentationAnalyzerNode(fake_llm)
result = node(basic_state)
assert "issue_documentation_analyzer_messages" in result
assert len(result["issue_documentation_analyzer_messages"]) == 1
assert "Documentation Plan" in result["issue_documentation_analyzer_messages"][0].content
def test_call_method_with_empty_messages(fake_llm):
"""Test call method with empty message history."""
state = IssueDocumentationState(
issue_title="Test",
issue_body="Test body",
issue_comments=[],
max_refined_query_loop=3,
documentation_query="",
documentation_context=[],
issue_documentation_analyzer_messages=[],
edit_messages=[],
edit_patch="",
issue_response="",
)
node = IssueDocumentationAnalyzerNode(fake_llm)
result = node(state)
assert "issue_documentation_analyzer_messages" in result
assert len(result["issue_documentation_analyzer_messages"]) == 1

View File

@@ -0,0 +1,80 @@
import pytest
from prometheus.lang_graph.nodes.issue_documentation_context_message_node import (
IssueDocumentationContextMessageNode,
)
from prometheus.lang_graph.subgraphs.issue_documentation_state import IssueDocumentationState
@pytest.fixture
def basic_state():
return IssueDocumentationState(
issue_title="Update API documentation",
issue_body="The API documentation needs to be updated with new endpoints",
issue_comments=[
{"username": "user1", "comment": "Please include examples"},
],
max_refined_query_loop=3,
documentation_query="",
documentation_context=[],
issue_documentation_analyzer_messages=[],
edit_messages=[],
edit_patch="",
issue_response="",
)
def test_init_issue_documentation_context_message_node():
"""Test IssueDocumentationContextMessageNode initialization."""
node = IssueDocumentationContextMessageNode()
assert node is not None
def test_call_method_generates_query(basic_state):
"""Test that the node generates a documentation query."""
node = IssueDocumentationContextMessageNode()
result = node(basic_state)
assert "documentation_query" in result
assert len(result["documentation_query"]) > 0
def test_query_contains_issue_info(basic_state):
"""Test that the query contains issue information."""
node = IssueDocumentationContextMessageNode()
result = node(basic_state)
query = result["documentation_query"]
assert "Update API documentation" in query or "API documentation" in query
def test_query_includes_instructions(basic_state):
"""Test that the query includes documentation finding instructions."""
node = IssueDocumentationContextMessageNode()
result = node(basic_state)
query = result["documentation_query"]
# Should include instructions about finding documentation
assert "documentation" in query.lower() or "find" in query.lower()
def test_call_with_empty_comments():
"""Test the node with empty comments."""
state = IssueDocumentationState(
issue_title="Test title",
issue_body="Test body",
issue_comments=[],
max_refined_query_loop=3,
documentation_query="",
documentation_context=[],
issue_documentation_analyzer_messages=[],
edit_messages=[],
edit_patch="",
issue_response="",
)
node = IssueDocumentationContextMessageNode()
result = node(state)
assert "documentation_query" in result
assert len(result["documentation_query"]) > 0

View File

@@ -0,0 +1,117 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from prometheus.lang_graph.nodes.issue_documentation_edit_message_node import (
IssueDocumentationEditMessageNode,
)
from prometheus.lang_graph.subgraphs.issue_documentation_state import IssueDocumentationState
from prometheus.models.context import Context
@pytest.fixture
def basic_state():
return IssueDocumentationState(
issue_title="Update API documentation",
issue_body="The API documentation needs to be updated",
issue_comments=[],
max_refined_query_loop=3,
documentation_query="Find API docs",
documentation_context=[
Context(
relative_path="/docs/api.md",
content="# API Documentation\n\nAPI Documentation content",
)
],
issue_documentation_analyzer_messages=[
AIMessage(content="Plan:\n1. Update README.md\n2. Add new examples\n3. Fix typos")
],
edit_messages=[],
edit_patch="",
issue_response="",
)
def test_init_issue_documentation_edit_message_node():
"""Test IssueDocumentationEditMessageNode initialization."""
node = IssueDocumentationEditMessageNode()
assert node is not None
def test_call_method_creates_message(basic_state):
"""Test that the node creates a human message."""
node = IssueDocumentationEditMessageNode()
result = node(basic_state)
assert "edit_messages" in result
assert len(result["edit_messages"]) == 1
assert isinstance(result["edit_messages"][0], HumanMessage)
def test_message_contains_plan(basic_state):
"""Test that the message contains the documentation plan."""
node = IssueDocumentationEditMessageNode()
result = node(basic_state)
message_content = result["edit_messages"][0].content
assert "Plan:" in message_content or "Update README.md" in message_content
def test_message_contains_context(basic_state):
"""Test that the message includes documentation context."""
node = IssueDocumentationEditMessageNode()
result = node(basic_state)
message_content = result["edit_messages"][0].content
# Should include context
assert "context" in message_content.lower() or "API Documentation" in message_content
def test_message_includes_edit_instructions(basic_state):
"""Test that the message includes editing instructions."""
node = IssueDocumentationEditMessageNode()
result = node(basic_state)
message_content = result["edit_messages"][0].content
# Should include instructions for implementing changes
assert any(
keyword in message_content.lower() for keyword in ["implement", "edit", "changes", "file"]
)
def test_call_with_empty_context():
"""Test the node with empty documentation context."""
state = IssueDocumentationState(
issue_title="Create docs",
issue_body="Create new documentation",
issue_comments=[],
max_refined_query_loop=3,
documentation_query="",
documentation_context=[],
issue_documentation_analyzer_messages=[AIMessage(content="Create new documentation files")],
edit_messages=[],
edit_patch="",
issue_response="",
)
node = IssueDocumentationEditMessageNode()
result = node(state)
assert "edit_messages" in result
assert len(result["edit_messages"]) == 1
def test_extracts_last_analyzer_message(basic_state):
"""Test that the node extracts the last message from analyzer history."""
# Add multiple messages to analyzer history
basic_state["issue_documentation_analyzer_messages"] = [
AIMessage(content="First message"),
AIMessage(content="Second message"),
AIMessage(content="Final plan: Update docs"),
]
node = IssueDocumentationEditMessageNode()
result = node(basic_state)
message_content = result["edit_messages"][0].content
# Should contain the final plan
assert "Final plan" in message_content

View File

@@ -0,0 +1,96 @@
import pytest
from langchain_core.messages import AIMessage
from prometheus.lang_graph.nodes.issue_documentation_responder_node import (
IssueDocumentationResponderNode,
)
from prometheus.lang_graph.subgraphs.issue_documentation_state import IssueDocumentationState
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def fake_llm():
return FakeListChatWithToolsModel(
responses=[
"The documentation has been successfully updated. "
"I've added new API endpoint documentation and included examples as requested."
]
)
@pytest.fixture
def basic_state():
return IssueDocumentationState(
issue_title="Update API documentation",
issue_body="The API documentation needs to be updated with the new endpoints",
issue_comments=[
{"username": "user1", "comment": "Please include examples"},
],
max_refined_query_loop=3,
documentation_query="Find API documentation",
documentation_context=[],
issue_documentation_analyzer_messages=[
AIMessage(content="Plan: Update README.md with new API endpoints and add examples")
],
edit_messages=[],
edit_patch="diff --git a/README.md b/README.md\n+New API documentation",
issue_response="",
)
def test_init_issue_documentation_responder_node(fake_llm):
"""Test IssueDocumentationResponderNode initialization."""
node = IssueDocumentationResponderNode(fake_llm)
assert node.model is not None
def test_call_method_basic(fake_llm, basic_state):
"""Test basic call functionality."""
node = IssueDocumentationResponderNode(fake_llm)
result = node(basic_state)
assert "issue_response" in result
assert "successfully updated" in result["issue_response"]
assert len(result["issue_response"]) > 0
def test_call_method_with_patch(fake_llm, basic_state):
"""Test response generation with patch."""
node = IssueDocumentationResponderNode(fake_llm)
result = node(basic_state)
assert "issue_response" in result
assert isinstance(result["issue_response"], str)
def test_call_method_without_patch(fake_llm):
"""Test response generation without patch."""
state = IssueDocumentationState(
issue_title="Update docs",
issue_body="Please update the documentation",
issue_comments=[],
max_refined_query_loop=3,
documentation_query="",
documentation_context=[],
issue_documentation_analyzer_messages=[AIMessage(content="Documentation plan created")],
edit_messages=[],
edit_patch="",
issue_response="",
)
node = IssueDocumentationResponderNode(fake_llm)
result = node(state)
assert "issue_response" in result
assert len(result["issue_response"]) > 0
def test_response_includes_issue_details(fake_llm, basic_state):
"""Test that the generated response is relevant to the issue."""
node = IssueDocumentationResponderNode(fake_llm)
result = node(basic_state)
assert "issue_response" in result
# The response should be a string with meaningful content
assert len(result["issue_response"]) > 10

View File

@@ -0,0 +1,89 @@
import pytest
from prometheus.lang_graph.nodes.issue_feature_responder_node import IssueFeatureResponderNode
from prometheus.lang_graph.subgraphs.issue_feature_state import IssueFeatureState
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def fake_llm():
return FakeListChatWithToolsModel(
responses=[
"Thank you for requesting this feature. The implementation has been completed and is ready for review."
]
)
@pytest.fixture
def basic_state():
return IssueFeatureState(
issue_title="Add dark mode support",
issue_body="Please add dark mode to the application",
issue_comments=[
{"username": "user1", "comment": "This would be great!"},
{"username": "user2", "comment": "I need this feature"},
],
final_patch="Added dark mode theme switching functionality",
run_regression_test=True,
number_of_candidate_patch=3,
selected_regression_tests=["tests:tests"],
issue_response="Mock Response",
)
def test_format_human_message_basic(fake_llm, basic_state):
"""Test basic human message formatting."""
node = IssueFeatureResponderNode(fake_llm)
message = node.format_human_message(basic_state)
assert "Add dark mode support" in message.content
assert "Please add dark mode to the application" in message.content
assert "user1" in message.content
assert "user2" in message.content
assert "Added dark mode theme switching functionality" in message.content
def test_format_human_message_with_regression_tests(fake_llm, basic_state):
"""Test message formatting with regression tests."""
# Add tested_patch_result to simulate passed tests
from prometheus.models.test_patch_result import TestedPatchResult
basic_state["tested_patch_result"] = [
TestedPatchResult(patch="test patch", passed=True, regression_test_failure_log="")
]
node = IssueFeatureResponderNode(fake_llm)
message = node.format_human_message(basic_state)
assert "✓ All selected regression tests passed successfully" in message.content
def test_format_human_message_no_tests(fake_llm):
"""Test message formatting without tests."""
state = IssueFeatureState(
issue_title="Add feature",
issue_body="Feature description",
issue_comments=[],
final_patch="Implementation patch",
run_regression_test=False,
number_of_candidate_patch=1,
selected_regression_tests=[],
issue_response="",
)
node = IssueFeatureResponderNode(fake_llm)
message = node.format_human_message(state)
assert "No automated tests were run for this feature implementation." in message.content
def test_call_method(fake_llm, basic_state):
"""Test the call method execution."""
node = IssueFeatureResponderNode(fake_llm)
result = node(basic_state)
assert "issue_response" in result
assert (
result["issue_response"]
== "Thank you for requesting this feature. The implementation has been completed and is ready for review."
)

View File

@@ -0,0 +1,20 @@
from langchain_core.messages import HumanMessage
from prometheus.lang_graph.nodes.reset_messages_node import ResetMessagesNode
def test_reset_messages_node():
reset_build_messages_node = ResetMessagesNode("build_messages")
state = {
"build_messages": [HumanMessage(content="message 1"), HumanMessage(content="message 2")],
"test_messages": [HumanMessage(content="message 3"), HumanMessage(content="message 4")],
}
reset_build_messages_node(state)
assert "build_messages" in state
assert len(state["build_messages"]) == 0
assert "test_messages" in state
assert len(state["test_messages"]) == 2

View File

@@ -0,0 +1,24 @@
from pathlib import Path
from unittest.mock import Mock
from prometheus.docker.general_container import GeneralContainer
from prometheus.git.git_repository import GitRepository
from prometheus.lang_graph.nodes.update_container_node import UpdateContainerNode
def test_update_container_node():
mocked_container = Mock(spec=GeneralContainer)
mocked_container.is_running.return_value = True
mocked_git_repo = Mock(spec=GitRepository)
mocked_git_repo.get_diff.return_value = "--- /dev/null\n+++ b/newfile\n@@ -0,0 +1 @@\n+content"
mocked_git_repo.get_working_directory.return_value = Path("/test/working/dir/repositories/repo")
update_container_node = UpdateContainerNode(mocked_container, mocked_git_repo)
update_container_node(None)
assert mocked_git_repo.get_diff.call_count == 1
assert mocked_container.is_running.call_count == 1
assert mocked_container.update_files.call_count == 1
mocked_container.update_files.assert_called_with(
Path("/test/working/dir/repositories/repo"), [Path("newfile")], []
)

View File

@@ -0,0 +1,32 @@
from unittest.mock import Mock
import pytest
from langchain_core.messages import ToolMessage
from prometheus.docker.base_container import BaseContainer
from prometheus.lang_graph.nodes.user_defined_build_node import UserDefinedBuildNode
@pytest.fixture
def mock_container():
container = Mock(spec=BaseContainer)
container.run_build.return_value = "Build successful"
return container
@pytest.fixture
def build_node(mock_container):
return UserDefinedBuildNode(container=mock_container)
def test_successful_build(build_node, mock_container):
expected_output = "Build successful"
result = build_node(None)
assert isinstance(result, dict)
assert "build_messages" in result
assert len(result["build_messages"]) == 1
assert isinstance(result["build_messages"][0], ToolMessage)
assert result["build_messages"][0].content == expected_output
mock_container.run_build.assert_called_once()

View File

@@ -0,0 +1,32 @@
from unittest.mock import Mock
import pytest
from langchain_core.messages import ToolMessage
from prometheus.docker.base_container import BaseContainer
from prometheus.lang_graph.nodes.user_defined_test_node import UserDefinedTestNode
@pytest.fixture
def mock_container():
container = Mock(spec=BaseContainer)
container.run_test.return_value = "Test successful"
return container
@pytest.fixture
def test_node(mock_container):
return UserDefinedTestNode(container=mock_container)
def test_successful_test(test_node, mock_container):
expected_output = "Test successful"
result = test_node(None)
assert isinstance(result, dict)
assert "test_messages" in result
assert len(result["test_messages"]) == 1
assert isinstance(result["test_messages"][0], ToolMessage)
assert result["test_messages"][0].content == expected_output
mock_container.run_test.assert_called_once()

View File

@@ -0,0 +1,38 @@
from unittest.mock import Mock
import pytest
from langgraph.checkpoint.base import BaseCheckpointSaver
from prometheus.docker.base_container import BaseContainer
from prometheus.git.git_repository import GitRepository
from prometheus.lang_graph.subgraphs.bug_fix_verification_subgraph import BugFixVerificationSubgraph
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_container():
return Mock(spec=BaseContainer)
@pytest.fixture
def mock_checkpointer():
return Mock(spec=BaseCheckpointSaver)
@pytest.fixture
def mock_git_repo():
git_repo = Mock(spec=GitRepository)
git_repo.playground_path = "mock/playground/path"
return git_repo
def test_bug_fix_verification_subgraph_basic_initialization(
mock_container,
mock_git_repo,
):
"""Test that BugFixVerificationSubgraph initializes correctly with basic components."""
fake_model = FakeListChatWithToolsModel(responses=[])
subgraph = BugFixVerificationSubgraph(fake_model, mock_container, mock_git_repo)
assert subgraph.subgraph is not None

View File

@@ -0,0 +1,49 @@
from unittest.mock import Mock
import pytest
from prometheus.docker.base_container import BaseContainer
from prometheus.git.git_repository import GitRepository
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.subgraphs.bug_reproduction_subgraph import BugReproductionSubgraph
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_container():
return Mock(spec=BaseContainer)
@pytest.fixture
def mock_kg():
kg = Mock(spec=KnowledgeGraph)
kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"]
kg.root_node_id = 0
return kg
@pytest.fixture
def mock_git_repo():
git_repo = Mock(spec=GitRepository)
git_repo.playground_path = "mock/playground/path"
return git_repo
def test_bug_reproduction_subgraph_basic_initialization(mock_container, mock_kg, mock_git_repo):
"""Test that BugReproductionSubgraph initializes correctly with basic components."""
# Initialize fake model with empty responses
fake_advanced_model = FakeListChatWithToolsModel(responses=[])
fake_base_model = FakeListChatWithToolsModel(responses=[])
# Initialize the subgraph
subgraph = BugReproductionSubgraph(
fake_advanced_model,
fake_base_model,
mock_container,
mock_kg,
mock_git_repo,
None,
)
# Verify the subgraph was created
assert subgraph.subgraph is not None

View File

@@ -0,0 +1,47 @@
from unittest.mock import Mock
import pytest
from prometheus.docker.base_container import BaseContainer
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.subgraphs.build_and_test_subgraph import BuildAndTestSubgraph
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_container():
return Mock(spec=BaseContainer)
@pytest.fixture
def mock_kg():
return Mock(spec=KnowledgeGraph)
def test_build_and_test_subgraph_basic_initialization(mock_container, mock_kg):
"""Test that BuildAndTestSubgraph initializes correctly with basic components."""
# Initialize fake model with empty responses
fake_model = FakeListChatWithToolsModel(responses=[])
# Initialize the subgraph
subgraph = BuildAndTestSubgraph(container=mock_container, model=fake_model, kg=mock_kg)
# Verify the subgraph was created
assert subgraph.subgraph is not None
def test_build_and_test_subgraph_with_commands(mock_container, mock_kg):
"""Test that BuildAndTestSubgraph initializes correctly with build and test commands."""
fake_model = FakeListChatWithToolsModel(responses=[])
build_commands = ["make build"]
test_commands = ["make test"]
subgraph = BuildAndTestSubgraph(
container=mock_container,
model=fake_model,
kg=mock_kg,
build_commands=build_commands,
test_commands=test_commands,
)
assert subgraph.subgraph is not None

View File

@@ -0,0 +1,69 @@
from unittest.mock import Mock
import pytest
from prometheus.docker.base_container import BaseContainer
from prometheus.git.git_repository import GitRepository
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.subgraphs.issue_bug_subgraph import IssueBugSubgraph
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_container():
return Mock(spec=BaseContainer)
@pytest.fixture
def mock_kg():
kg = Mock(spec=KnowledgeGraph)
# Configure the mock to return a list of AST node types
kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"]
kg.root_node_id = 0
return kg
@pytest.fixture
def mock_git_repo():
git_repo = Mock(spec=GitRepository)
git_repo.playground_path = "mock/playground/path"
return git_repo
def test_issue_bug_subgraph_basic_initialization(mock_container, mock_kg, mock_git_repo):
"""Test that IssueBugSubgraph initializes correctly with basic components."""
# Initialize fake model with empty responses
fake_advanced_model = FakeListChatWithToolsModel(responses=[])
fake_base_model = FakeListChatWithToolsModel(responses=[])
# Initialize the subgraph with required parameters
subgraph = IssueBugSubgraph(
advanced_model=fake_advanced_model,
base_model=fake_base_model,
container=mock_container,
kg=mock_kg,
git_repo=mock_git_repo,
repository_id=1,
)
# Verify the subgraph was created
assert subgraph.subgraph is not None
def test_issue_bug_subgraph_with_commands(mock_container, mock_kg, mock_git_repo):
"""Test that IssueBugSubgraph initializes correctly with build and test commands."""
fake_advanced_model = FakeListChatWithToolsModel(responses=[])
fake_base_model = FakeListChatWithToolsModel(responses=[])
test_commands = ["make test"]
subgraph = IssueBugSubgraph(
advanced_model=fake_advanced_model,
base_model=fake_base_model,
container=mock_container,
kg=mock_kg,
git_repo=mock_git_repo,
repository_id=1,
test_commands=test_commands,
)
assert subgraph.subgraph is not None

View File

@@ -0,0 +1,45 @@
from unittest.mock import Mock
import pytest
from prometheus.git.git_repository import GitRepository
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.subgraphs.issue_classification_subgraph import (
IssueClassificationSubgraph,
)
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_kg():
kg = Mock(spec=KnowledgeGraph)
# Configure the mock to return a list of AST node types
kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"]
kg.root_node_id = 0
return kg
@pytest.fixture
def mock_git_repo():
git_repo = Mock(spec=GitRepository)
git_repo.playground_path = "mock/playground/path"
return git_repo
def test_issue_classification_subgraph_basic_initialization(mock_kg, mock_git_repo):
"""Test that IssueClassificationSubgraph initializes correctly with basic components."""
# Initialize fake model with empty responses
fake_model = FakeListChatWithToolsModel(responses=[])
fake_advanced_model = FakeListChatWithToolsModel(responses=[])
# Initialize the subgraph with required parameters
subgraph = IssueClassificationSubgraph(
advanced_model=fake_advanced_model,
model=fake_model,
kg=mock_kg,
local_path=mock_git_repo.playground_path,
repository_id=1,
)
# Verify the subgraph was created
assert subgraph.subgraph is not None

View File

@@ -0,0 +1,51 @@
from unittest.mock import Mock
import pytest
from prometheus.git.git_repository import GitRepository
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.subgraphs.issue_documentation_subgraph import (
IssueDocumentationSubgraph,
)
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_kg():
kg = Mock(spec=KnowledgeGraph)
# Configure the mock to return a list of AST node types
kg.get_all_ast_node_types.return_value = [
"FunctionDef",
"ClassDef",
"Module",
"Import",
"Call",
]
kg.root_node_id = 0
return kg
@pytest.fixture
def mock_git_repo():
git_repo = Mock(spec=GitRepository)
git_repo.playground_path = "mock/playground/path"
return git_repo
def test_issue_documentation_subgraph_basic_initialization(mock_kg, mock_git_repo):
"""Test that IssueDocumentationSubgraph initializes correctly with basic components."""
# Initialize fake model with empty responses
fake_advanced_model = FakeListChatWithToolsModel(responses=[])
fake_base_model = FakeListChatWithToolsModel(responses=[])
# Initialize the subgraph with required parameters
subgraph = IssueDocumentationSubgraph(
advanced_model=fake_advanced_model,
base_model=fake_base_model,
kg=mock_kg,
git_repo=mock_git_repo,
repository_id=1,
)
# Verify the subgraph was created
assert subgraph.subgraph is not None

View File

@@ -0,0 +1,49 @@
from unittest.mock import Mock
import pytest
from prometheus.docker.base_container import BaseContainer
from prometheus.git.git_repository import GitRepository
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.lang_graph.subgraphs.issue_question_subgraph import IssueQuestionSubgraph
from tests.test_utils.util import FakeListChatWithToolsModel
@pytest.fixture
def mock_container():
return Mock(spec=BaseContainer)
@pytest.fixture
def mock_kg():
kg = Mock(spec=KnowledgeGraph)
# Configure the mock to return a list of AST node types
kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"]
kg.root_node_id = 0
return kg
@pytest.fixture
def mock_git_repo():
git_repo = Mock(spec=GitRepository)
git_repo.playground_path = "mock/playground/path"
return git_repo
def test_issue_question_subgraph_basic_initialization(mock_container, mock_kg, mock_git_repo):
"""Test that IssueQuestionSubgraph initializes correctly with basic components."""
# Initialize fake model with empty responses
fake_advanced_model = FakeListChatWithToolsModel(responses=[])
fake_base_model = FakeListChatWithToolsModel(responses=[])
# Initialize the subgraph with required parameters
subgraph = IssueQuestionSubgraph(
advanced_model=fake_advanced_model,
base_model=fake_base_model,
kg=mock_kg,
git_repo=mock_git_repo,
repository_id=1,
)
# Verify the subgraph was created
assert subgraph.subgraph is not None

View File

View File

@@ -0,0 +1,131 @@
import pytest
from prometheus.app.services.neo4j_service import Neo4jService
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.neo4j.knowledge_graph_handler import KnowledgeGraphHandler
from tests.test_utils import test_project_paths
from tests.test_utils.fixtures import ( # noqa: F401
NEO4J_PASSWORD,
NEO4J_USERNAME,
empty_neo4j_container_fixture,
neo4j_container_with_kg_fixture,
)
@pytest.fixture
async def mock_neo4j_service(neo4j_container_with_kg_fixture): # noqa: F811
"""Fixture: provide a clean DatabaseService using the Postgres test container."""
neo4j_container, kg = neo4j_container_with_kg_fixture
service = Neo4jService(neo4j_container.get_connection_url(), NEO4J_USERNAME, NEO4J_PASSWORD)
service.start()
yield service
await service.close()
@pytest.fixture
async def mock_empty_neo4j_service(empty_neo4j_container_fixture): # noqa: F811
"""Fixture: provide a clean DatabaseService using the Postgres test container."""
neo4j_container = empty_neo4j_container_fixture
service = Neo4jService(neo4j_container.get_connection_url(), NEO4J_USERNAME, NEO4J_PASSWORD)
service.start()
yield service
await service.close()
@pytest.mark.slow
async def test_num_ast_nodes(mock_neo4j_service):
handler = KnowledgeGraphHandler(mock_neo4j_service.neo4j_driver, 100)
async with mock_neo4j_service.neo4j_driver.session() as session:
read_ast_nodes = await session.execute_read(handler._read_ast_nodes, root_node_id=0)
assert len(read_ast_nodes) == 7
@pytest.mark.slow
async def test_num_file_nodes(mock_neo4j_service):
handler = KnowledgeGraphHandler(mock_neo4j_service.neo4j_driver, 100)
async with mock_neo4j_service.neo4j_driver.session() as session:
read_file_nodes = await session.execute_read(handler._read_file_nodes, root_node_id=0)
assert len(read_file_nodes) == 7
@pytest.mark.slow
async def test_num_text_nodes(mock_neo4j_service):
handler = KnowledgeGraphHandler(mock_neo4j_service.neo4j_driver, 100)
async with mock_neo4j_service.neo4j_driver.session() as session:
read_text_nodes = await session.execute_read(handler._read_text_nodes, root_node_id=0)
assert len(read_text_nodes) == 1
@pytest.mark.slow
async def test_num_parent_of_edges(mock_neo4j_service):
handler = KnowledgeGraphHandler(mock_neo4j_service.neo4j_driver, 100)
async with mock_neo4j_service.neo4j_driver.session() as session:
read_parent_of_edges = await session.execute_read(
handler._read_parent_of_edges, root_node_id=0
)
assert len(read_parent_of_edges) == 4
@pytest.mark.slow
async def test_num_has_file_edges(mock_neo4j_service):
handler = KnowledgeGraphHandler(mock_neo4j_service.neo4j_driver, 100)
async with mock_neo4j_service.neo4j_driver.session() as session:
read_has_file_edges = await session.execute_read(
handler._read_has_file_edges, root_node_id=0
)
assert len(read_has_file_edges) == 6
@pytest.mark.slow
async def test_num_has_ast_edges(mock_neo4j_service):
handler = KnowledgeGraphHandler(mock_neo4j_service.neo4j_driver, 100)
async with mock_neo4j_service.neo4j_driver.session() as session:
read_has_ast_edges = await session.execute_read(handler._read_has_ast_edges, root_node_id=0)
assert len(read_has_ast_edges) == 3
@pytest.mark.slow
async def test_num_has_text_edges(mock_neo4j_service):
handler = KnowledgeGraphHandler(mock_neo4j_service.neo4j_driver, 100)
async with mock_neo4j_service.neo4j_driver.session() as session:
read_has_text_edges = await session.execute_read(
handler._read_has_text_edges, root_node_id=0
)
assert len(read_has_text_edges) == 1
@pytest.mark.slow
async def test_num_next_chunk_edges(mock_neo4j_service):
handler = KnowledgeGraphHandler(mock_neo4j_service.neo4j_driver, 100)
async with mock_neo4j_service.neo4j_driver.session() as session:
read_next_chunk_edges = await session.execute_read(
handler._read_next_chunk_edges, root_node_id=0
)
assert len(read_next_chunk_edges) == 0
@pytest.mark.slow
async def test_clear_knowledge_graph(mock_empty_neo4j_service):
kg = KnowledgeGraph(1, 1000, 100, 0)
await kg.build_graph(test_project_paths.TEST_PROJECT_PATH)
driver = mock_empty_neo4j_service.neo4j_driver
handler = KnowledgeGraphHandler(driver, 100)
await handler.write_knowledge_graph(kg)
await handler.clear_knowledge_graph(0)
# Verify that the graph is cleared
async with driver.session() as session:
result = await session.run("MATCH (n) RETURN COUNT(n) AS node_count")
record = await result.single()
node_count = record["node_count"]
assert node_count == 0

View File

View File

@@ -0,0 +1,34 @@
from pathlib import Path
import pytest
from prometheus.parser.file_types import FileType
@pytest.mark.parametrize(
"file_path,expected_type",
[
("script.sh", FileType.BASH),
("program.c", FileType.C),
("app.cs", FileType.CSHARP),
("class.cpp", FileType.CPP),
("header.cc", FileType.CPP),
("file.cxx", FileType.CPP),
("main.go", FileType.GO),
("Class.java", FileType.JAVA),
("app.js", FileType.JAVASCRIPT),
("Service.kt", FileType.KOTLIN),
("index.php", FileType.PHP),
("script.py", FileType.PYTHON),
("query.sql", FileType.SQL),
("config.yaml", FileType.YAML),
("docker-compose.yml", FileType.YAML),
("readme.md", FileType.UNKNOWN),
("Makefile", FileType.UNKNOWN),
("", FileType.UNKNOWN),
],
)
def test_file_type_from_path(file_path: str, expected_type: FileType):
"""Test that file extensions are correctly mapped to FileTypes"""
path = Path(file_path)
assert FileType.from_path(path) == expected_type

View File

@@ -0,0 +1,68 @@
from pathlib import Path
from unittest.mock import MagicMock, create_autospec, mock_open, patch
import pytest
from tree_sitter._binding import Tree
from prometheus.parser.file_types import FileType
from prometheus.parser.tree_sitter_parser import (
FILE_TYPE_TO_LANG,
parse,
supports_file,
)
# Test fixtures
@pytest.fixture
def mock_python_file():
return Path("test.py")
@pytest.fixture
def mock_tree():
tree = create_autospec(Tree, instance=True)
return tree
# Test supports_file function
def test_supports_file_with_supported_type(mock_python_file):
with patch("prometheus.parser.file_types.FileType.from_path") as mock_from_path:
mock_from_path.return_value = FileType.PYTHON
assert supports_file(mock_python_file) is True
mock_from_path.assert_called_once_with(mock_python_file)
def test_parse_python_file_successfully(mock_python_file, mock_tree):
mock_content = b'print("hello")'
m = mock_open(read_data=mock_content)
with (
patch("prometheus.parser.file_types.FileType.from_path") as mock_from_path,
patch("prometheus.parser.tree_sitter_parser.get_parser") as mock_get_parser,
patch.object(Path, "open", m),
):
# Setup mocks
mock_from_path.return_value = FileType.PYTHON
mock_parser = mock_get_parser.return_value
mock_parser.parse.return_value = mock_tree
# Test parse function
result = parse(mock_python_file)
# Verify results and interactions
assert result == mock_tree
mock_from_path.assert_called_once_with(mock_python_file)
mock_get_parser.assert_called_once_with("python")
mock_parser.parse.assert_called_once_with(mock_content)
def test_parse_all_supported_languages():
"""Test that we can get parsers for all supported languages."""
with patch("tree_sitter_language_pack.get_parser") as mock_get_parser:
mock_parser = MagicMock()
mock_get_parser.return_value = mock_parser
for lang in FILE_TYPE_TO_LANG.values():
mock_get_parser(lang)
mock_get_parser.assert_called_with(lang)
mock_get_parser.reset_mock()

View File

@@ -0,0 +1 @@
.dummy_git/

View File

@@ -0,0 +1,5 @@
public class test {
public static void main(String[] args) {
System.out.println("Hello world!");
}
}

View File

@@ -0,0 +1 @@
print("Hello world!")

View File

@@ -0,0 +1,11 @@
@program Start
use library console_output
start {
declare text message = "Hello world!";
print_to_console(message);
halt;
}
end @program

View File

@@ -0,0 +1,15 @@
# A
Text under header A.
## B
Text under header B.
## C
Text under header C.
### D
Text under header D.

View File

@@ -0,0 +1,6 @@
#include <stdio.h>
int main() {
printf("Hello world!");
return 0;
}

View File

View File

@@ -0,0 +1,90 @@
import shutil
import tempfile
import uuid
from pathlib import Path
import pytest
from git import Repo
from testcontainers.neo4j import Neo4jContainer
from testcontainers.postgres import PostgresContainer
from prometheus.app.services.neo4j_service import Neo4jService
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.neo4j.knowledge_graph_handler import KnowledgeGraphHandler
from tests.test_utils import test_project_paths
NEO4J_IMAGE = "neo4j:5.20.0"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "password"
POSTGRES_IMAGE = "postgres"
POSTGRES_USERNAME = "postgres"
POSTGRES_PASSWORD = "password"
POSTGRES_DB = "postgres"
@pytest.fixture(scope="session")
async def neo4j_container_with_kg_fixture():
kg = KnowledgeGraph(1, 1000, 100, 0)
await kg.build_graph(test_project_paths.TEST_PROJECT_PATH)
container = (
Neo4jContainer(image=NEO4J_IMAGE, username=NEO4J_USERNAME, password=NEO4J_PASSWORD)
.with_env("NEO4J_PLUGINS", '["apoc"]')
.with_name(f"neo4j_container_with_kg_{uuid.uuid4().hex[:12]}")
)
with container as neo4j_container:
neo4j_service = Neo4jService(container.get_connection_url(), NEO4J_USERNAME, NEO4J_PASSWORD)
handler = KnowledgeGraphHandler(neo4j_service.neo4j_driver, 100)
await handler.write_knowledge_graph(kg)
yield neo4j_container, kg
await neo4j_service.close()
@pytest.fixture(scope="function")
def empty_neo4j_container_fixture():
container = (
Neo4jContainer(image=NEO4J_IMAGE, username=NEO4J_USERNAME, password=NEO4J_PASSWORD)
.with_env("NEO4J_PLUGINS", '["apoc"]')
.with_name(f"empty_neo4j_container_{uuid.uuid4().hex[:12]}")
)
with container as neo4j_container:
yield neo4j_container
@pytest.fixture(scope="session")
def postgres_container_fixture():
container = PostgresContainer(
image=POSTGRES_IMAGE,
username=POSTGRES_USERNAME,
password=POSTGRES_PASSWORD,
dbname=POSTGRES_DB,
port=5432,
driver="asyncpg",
).with_name(f"postgres_container_{uuid.uuid4().hex[:12]}")
with container as postgres_container:
yield postgres_container
@pytest.fixture(scope="function")
def git_repo_fixture():
temp_dir = Path(tempfile.mkdtemp())
temp_project_dir = temp_dir / "test_project"
original_project_path = test_project_paths.TEST_PROJECT_PATH
try:
shutil.copytree(original_project_path, temp_project_dir)
shutil.move(temp_project_dir / test_project_paths.GIT_DIR.name, temp_project_dir / ".git")
repo = Repo(temp_project_dir)
yield repo
finally:
shutil.rmtree(temp_project_dir)
@pytest.fixture
def temp_test_dir(tmp_path):
"""Create a temporary test directory."""
test_dir = tmp_path / "test_files"
test_dir.mkdir()
yield test_dir
# Cleanup happens automatically after tests due to tmp_path fixture

View File

@@ -0,0 +1,11 @@
from pathlib import Path
TEST_PROJECT_PATH = Path(__file__).parent.parent / "test_project"
GIT_DIR = TEST_PROJECT_PATH / ".dummy_git"
C_FILE = TEST_PROJECT_PATH / "test.c"
BAR_DIR = TEST_PROJECT_PATH / "bar"
JAVA_FILE = BAR_DIR / "test.java"
PYTHON_FILE = BAR_DIR / "test.py"
FOO_DIR = TEST_PROJECT_PATH / "foo"
MD_FILE = FOO_DIR / "test.md"
DUMMY_FILE = FOO_DIR / "test.dummy"

View File

@@ -0,0 +1,13 @@
from langchain_core.language_models.fake_chat_models import FakeListChatModel
from testcontainers.neo4j import Neo4jContainer
class FakeListChatWithToolsModel(FakeListChatModel):
def bind_tools(self, tools=None, tool_choice=None, **kwargs):
return self
def clean_neo4j_container(neo4j_container: Neo4jContainer):
with neo4j_container.get_driver() as driver:
with driver.session() as session:
session.run("MATCH (n) DETACH DELETE n")

View File

View File

@@ -0,0 +1,141 @@
import shutil
import pytest
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.tools.file_operation import FileOperationTool
from tests.test_utils import test_project_paths
from tests.test_utils.fixtures import temp_test_dir # noqa: F401
@pytest.fixture(scope="function")
async def knowledge_graph_fixture(temp_test_dir): # noqa: F811
if temp_test_dir.exists():
shutil.rmtree(temp_test_dir)
shutil.copytree(test_project_paths.TEST_PROJECT_PATH, temp_test_dir)
kg = KnowledgeGraph(1, 1000, 100, 0)
await kg.build_graph(temp_test_dir)
return kg
@pytest.fixture(scope="function")
def file_operation_tool(temp_test_dir, knowledge_graph_fixture): # noqa: F811
file_operation = FileOperationTool(temp_test_dir, knowledge_graph_fixture)
return file_operation
def test_read_file_with_knowledge_graph_data(file_operation_tool):
relative_path = str(
test_project_paths.PYTHON_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix()
)
result = file_operation_tool.read_file_with_knowledge_graph_data(relative_path)
result_data = result[1]
assert len(result_data) > 0
for result_row in result_data:
assert "preview" in result_row
assert 'print("Hello world!")' in result_row["preview"].get("text", "")
assert "FileNode" in result_row
assert result_row["FileNode"].get("relative_path", "") == relative_path
def test_create_and_read_file(temp_test_dir, file_operation_tool): # noqa: F811
"""Test creating a file and reading its contents."""
test_file = temp_test_dir / "test.txt"
content = "line 1\nline 2\nline 3"
# Test create_file
result = file_operation_tool.create_file("test.txt", content)
assert test_file.exists()
assert test_file.read_text() == content
assert result == "The file test.txt has been created."
# Test read_file
result = file_operation_tool.read_file("test.txt")
expected = "1. line 1\n2. line 2\n3. line 3"
assert result == expected
def test_read_file_nonexistent(file_operation_tool):
"""Test reading a nonexistent file."""
result = file_operation_tool.read_file("nonexistent_file.txt")
assert result == "The file nonexistent_file.txt does not exist."
def test_read_file_with_line_numbers(file_operation_tool):
"""Test reading specific line ranges from a file."""
content = "line 1\nline 2\nline 3\nline 4\nline 5"
file_operation_tool.create_file("test_lines.txt", content)
# Test reading specific lines
result = file_operation_tool.read_file_with_line_numbers("test_lines.txt", 2, 4)
expected = "2. line 2\n3. line 3"
assert result == expected
# Test invalid range
result = file_operation_tool.read_file_with_line_numbers("test_lines.txt", 4, 2)
assert result == "The end line number 2 must be greater than the start line number 4."
def test_delete(file_operation_tool, temp_test_dir): # noqa: F811
"""Test file and directory deletion."""
# Test file deletion
test_file = temp_test_dir / "to_delete.txt"
file_operation_tool.create_file("to_delete.txt", "content")
assert test_file.exists()
result = file_operation_tool.delete("to_delete.txt")
assert result == "The file to_delete.txt has been deleted."
assert not test_file.exists()
# Test directory deletion
test_subdir = temp_test_dir / "subdir"
test_subdir.mkdir()
file_operation_tool.create_file("subdir/file.txt", "content")
result = file_operation_tool.delete("subdir")
assert result == "The directory subdir has been deleted."
assert not test_subdir.exists()
def test_delete_nonexistent(file_operation_tool):
"""Test deleting a nonexistent path."""
result = file_operation_tool.delete("nonexistent_path")
assert result == "The file nonexistent_path does not exist."
def test_edit_file(file_operation_tool):
"""Test editing specific lines in a file."""
# Test case 1: Successfully edit a single occurrence
initial_content = "line 1\nline 2\nline 3\nline 4\nline 5"
file_operation_tool.create_file("edit_test.txt", initial_content)
result = file_operation_tool.edit_file("edit_test.txt", "line 2", "new line 2")
assert result == "Successfully edited edit_test.txt."
# Test case 2: Absolute path error
result = file_operation_tool.edit_file("/edit_test.txt", "line 2", "new line 2")
assert result == "relative_path: /edit_test.txt is a absolute path, not relative path."
# Test case 3: File doesn't exist
result = file_operation_tool.edit_file("nonexistent.txt", "line 2", "new line 2")
assert result == "The file nonexistent.txt does not exist."
# Test case 4: No matches found
result = file_operation_tool.edit_file("edit_test.txt", "nonexistent line", "new content")
assert (
result
== "No match found for the specified content in edit_test.txt. Please verify the content to replace."
)
# Test case 5: Multiple occurrences
duplicate_content = "line 1\nline 2\nline 2\nline 3"
file_operation_tool.create_file("duplicate_test.txt", duplicate_content)
result = file_operation_tool.edit_file("duplicate_test.txt", "line 2", "new line 2")
assert (
result
== "Found 2 occurrences of the specified content in duplicate_test.txt. Please provide more context to ensure a unique match."
)
def test_create_file_already_exists(file_operation_tool):
"""Test creating a file that already exists."""
file_operation_tool.create_file("existing.txt", "content")
result = file_operation_tool.create_file("existing.txt", "new content")
assert result == "The file existing.txt already exists."

View File

@@ -0,0 +1,189 @@
import pytest
from prometheus.graph.knowledge_graph import KnowledgeGraph
from prometheus.tools.graph_traversal import GraphTraversalTool
from tests.test_utils import test_project_paths
@pytest.fixture(scope="function")
async def knowledge_graph_fixture():
kg = KnowledgeGraph(1000, 100, 10, 0)
await kg.build_graph(test_project_paths.TEST_PROJECT_PATH)
return kg
@pytest.fixture(scope="function")
def graph_traversal_tool(knowledge_graph_fixture):
graph_traversal_tool = GraphTraversalTool(knowledge_graph_fixture)
return graph_traversal_tool
@pytest.mark.slow
async def test_find_file_node_with_basename(graph_traversal_tool):
result = graph_traversal_tool.find_file_node_with_basename(test_project_paths.PYTHON_FILE.name)
basename = test_project_paths.PYTHON_FILE.name
relative_path = str(
test_project_paths.PYTHON_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix()
)
result_data = result[1]
assert len(result_data) == 1
assert "FileNode" in result_data[0]
assert result_data[0]["FileNode"].get("basename", "") == basename
assert result_data[0]["FileNode"].get("relative_path", "") == relative_path
@pytest.mark.slow
async def test_find_file_node_with_relative_path(graph_traversal_tool):
relative_path = str(
test_project_paths.MD_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix()
)
result = graph_traversal_tool.find_file_node_with_relative_path(relative_path)
basename = test_project_paths.MD_FILE.name
result_data = result[1]
assert len(result_data) == 1
assert "FileNode" in result_data[0]
assert result_data[0]["FileNode"].get("basename", "") == basename
assert result_data[0]["FileNode"].get("relative_path", "") == relative_path
@pytest.mark.slow
async def test_find_ast_node_with_text_in_file_with_basename(graph_traversal_tool):
basename = test_project_paths.PYTHON_FILE.name
result = graph_traversal_tool.find_ast_node_with_text_in_file_with_basename(
"Hello world!", basename
)
result_data = result[1]
assert len(result_data) > 0
for result_row in result_data:
assert "ASTNode" in result_row
assert "Hello world!" in result_row["ASTNode"].get("text", "")
assert "FileNode" in result_row
assert result_row["FileNode"].get("basename", "") == basename
@pytest.mark.slow
async def test_find_ast_node_with_text_in_file_with_relative_path(graph_traversal_tool):
relative_path = str(
test_project_paths.C_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix()
)
result = graph_traversal_tool.find_ast_node_with_text_in_file_with_relative_path(
"Hello world!", relative_path
)
result_data = result[1]
assert len(result_data) > 0
for result_row in result_data:
assert "ASTNode" in result_row
assert "Hello world!" in result_row["ASTNode"].get("text", "")
assert "FileNode" in result_row
assert result_row["FileNode"].get("relative_path", "") == relative_path
@pytest.mark.slow
async def test_find_ast_node_with_type_in_file_with_basename(graph_traversal_tool):
basename = test_project_paths.C_FILE.name
node_type = "function_definition"
result = graph_traversal_tool.find_ast_node_with_type_in_file_with_basename(node_type, basename)
result_data = result[1]
assert len(result_data) > 0
for result_row in result_data:
assert "ASTNode" in result_row
assert result_row["ASTNode"].get("type", "") == node_type
assert "FileNode" in result_row
assert result_row["FileNode"].get("basename", "") == basename
@pytest.mark.slow
async def test_find_ast_node_with_type_in_file_with_relative_path(graph_traversal_tool):
relative_path = str(
test_project_paths.JAVA_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix()
)
node_type = "string_literal"
result = graph_traversal_tool.find_ast_node_with_type_in_file_with_relative_path(
node_type, relative_path
)
result_data = result[1]
assert len(result_data) > 0
for result_row in result_data:
assert "ASTNode" in result_row
assert result_row["ASTNode"].get("type", "") == node_type
assert "FileNode" in result_row
assert result_row["FileNode"].get("relative_path", "") == relative_path
@pytest.mark.slow
async def test_find_text_node_with_text(graph_traversal_tool):
text = "Text under header C"
result = graph_traversal_tool.find_text_node_with_text(text)
result_data = result[1]
assert len(result_data) > 0
for result_row in result_data:
assert "TextNode" in result_row
assert text in result_row["TextNode"].get("text", "")
assert "start_line" in result_row["TextNode"]
assert result_row["TextNode"]["start_line"] == 1
assert "end_line" in result_row["TextNode"]
assert result_row["TextNode"]["end_line"] == 13
assert "FileNode" in result_row
assert result_row["FileNode"].get("relative_path", "") == "foo/test.md"
@pytest.mark.slow
async def test_find_text_node_with_text_in_file(graph_traversal_tool):
basename = test_project_paths.MD_FILE.name
text = "Text under header B"
result = graph_traversal_tool.find_text_node_with_text_in_file(text, basename)
result_data = result[1]
assert len(result_data) > 0
for result_row in result_data:
assert "TextNode" in result_row
assert text in result_row["TextNode"].get("text", "")
assert "start_line" in result_row["TextNode"]
assert result_row["TextNode"]["start_line"] == 1
assert "end_line" in result_row["TextNode"]
assert result_row["TextNode"]["end_line"] == 13
assert "FileNode" in result_row
assert result_row["FileNode"].get("basename", "") == basename
@pytest.mark.slow
async def test_get_next_text_node_with_node_id(graph_traversal_tool):
node_id = 34
result = graph_traversal_tool.get_next_text_node_with_node_id(node_id)
result_data = result[1]
assert len(result_data) > 0
for result_row in result_data:
assert "TextNode" in result_row
assert "Text under header D" in result_row["TextNode"].get("text", "")
assert "start_line" in result_row["TextNode"]
assert result_row["TextNode"]["start_line"] == 13
assert "end_line" in result_row["TextNode"]
assert result_row["TextNode"]["end_line"] == 15
assert "FileNode" in result_row
assert result_row["FileNode"].get("relative_path", "") == "foo/test.md"
@pytest.mark.slow
async def test_read_code_with_relative_path(graph_traversal_tool):
relative_path = str(
test_project_paths.C_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix()
)
result = graph_traversal_tool.read_code_with_relative_path(relative_path, 5, 6)
result_data = result[1]
assert len(result_data) > 0
for result_row in result_data:
assert "SelectedLines" in result_row
assert "return 0;" in result_row["SelectedLines"].get("text", "")
assert "FileNode" in result_row
assert result_row["FileNode"].get("relative_path", "") == relative_path

View File

@@ -0,0 +1,551 @@
from unittest.mock import MagicMock, patch
import pytest
from tavily import InvalidAPIKeyError, UsageLimitExceededError
from prometheus.exceptions.web_search_tool_exception import WebSearchToolException
from prometheus.tools.web_search import WebSearchInput, WebSearchTool, format_results
class TestFormatResults:
"""Test suite for format_results function."""
def test_format_results_basic(self):
"""Test basic formatting of search results without answer."""
response = {
"results": [
{
"title": "How to fix Python import error",
"url": "https://stackoverflow.com/questions/12345",
"content": "This is how you fix import errors in Python...",
},
{
"title": "Python Documentation",
"url": "https://docs.python.org/3/",
"content": "Official Python documentation...",
},
]
}
result = format_results(response)
assert "Detailed Results:" in result
assert "How to fix Python import error" in result
assert "https://stackoverflow.com/questions/12345" in result
assert "This is how you fix import errors in Python..." in result
assert "Python Documentation" in result
assert "https://docs.python.org/3/" in result
def test_format_results_with_answer(self):
"""Test formatting with answer included."""
response = {
"answer": "To fix import errors, check your PYTHONPATH and ensure modules are installed.",
"results": [
{
"title": "Python Import Guide",
"url": "https://docs.python.org/import",
"content": "Guide to Python imports...",
},
],
}
result = format_results(response)
assert "Answer:" in result
assert "To fix import errors" in result
assert "Sources:" in result
assert "Python Import Guide" in result
assert "https://docs.python.org/import" in result
assert "Detailed Results:" in result
def test_format_results_with_published_date(self):
"""Test formatting with published date."""
response = {
"results": [
{
"title": "Article Title",
"url": "https://example.com/article",
"content": "Article content...",
"published_date": "2024-01-15",
},
]
}
result = format_results(response)
assert "Published: 2024-01-15" in result
def test_format_results_with_included_domains(self):
"""Test formatting with included domains filter."""
response = {
"included_domains": ["stackoverflow.com", "github.com"],
"results": [
{
"title": "Test Result",
"url": "https://stackoverflow.com/test",
"content": "Test content",
},
],
}
result = format_results(response)
assert "Search Filters:" in result
assert "Including domains: stackoverflow.com, github.com" in result
def test_format_results_with_excluded_domains(self):
"""Test formatting with excluded domains filter."""
response = {
"excluded_domains": ["pinterest.com", "reddit.com"],
"results": [
{
"title": "Test Result",
"url": "https://example.com/test",
"content": "Test content",
},
],
}
result = format_results(response)
assert "Search Filters:" in result
assert "Excluding domains: pinterest.com, reddit.com" in result
def test_format_results_with_both_domain_filters(self):
"""Test formatting with both included and excluded domains."""
response = {
"included_domains": ["stackoverflow.com"],
"excluded_domains": ["pinterest.com"],
"results": [
{
"title": "Test Result",
"url": "https://stackoverflow.com/test",
"content": "Test content",
},
],
}
result = format_results(response)
assert "Including domains: stackoverflow.com" in result
assert "Excluding domains: pinterest.com" in result
def test_format_results_empty(self):
"""Test formatting with no results."""
response = {"results": []}
result = format_results(response)
assert "Detailed Results:" in result
# Should not contain any result-specific content
assert "Title:" not in result
assert "URL:" not in result
def test_format_results_multiple_results(self):
"""Test formatting with multiple results."""
response = {
"results": [
{
"title": f"Result {i}",
"url": f"https://example.com/{i}",
"content": f"Content {i}",
}
for i in range(5)
]
}
result = format_results(response)
for i in range(5):
assert f"Result {i}" in result
assert f"https://example.com/{i}" in result
assert f"Content {i}" in result
def test_format_results_special_characters(self):
"""Test formatting with special characters in content."""
response = {
"results": [
{
"title": "Special chars: @#$%^&*()",
"url": "https://example.com/special",
"content": "Content with 中文 and emojis 🚀",
},
]
}
result = format_results(response)
assert "@#$%^&*()" in result
assert "中文" in result
assert "🚀" in result
class TestWebSearchInput:
"""Test suite for WebSearchInput model."""
def test_web_search_input_valid(self):
"""Test valid WebSearchInput creation."""
input_data = WebSearchInput(query="Python import error")
assert input_data.query == "Python import error"
def test_web_search_input_empty_query(self):
"""Test WebSearchInput with empty query."""
input_data = WebSearchInput(query="")
assert input_data.query == ""
def test_web_search_input_long_query(self):
"""Test WebSearchInput with long query."""
long_query = "A" * 1000
input_data = WebSearchInput(query=long_query)
assert input_data.query == long_query
class TestWebSearchTool:
"""Test suite for WebSearchTool class."""
@patch("prometheus.tools.web_search.settings")
@patch("prometheus.tools.web_search.TavilyClient")
def test_init_with_api_key(self, mock_tavily_client, mock_settings):
"""Test WebSearchTool initialization with API key."""
mock_settings.TAVILY_API_KEY = "test_api_key"
tool = WebSearchTool()
mock_tavily_client.assert_called_once_with(api_key="test_api_key")
assert tool.tavily_client is not None
@patch("prometheus.tools.web_search.settings")
@patch("prometheus.tools.web_search.TavilyClient")
def test_init_without_api_key(self, mock_tavily_client, mock_settings):
"""Test WebSearchTool initialization without API key."""
mock_settings.TAVILY_API_KEY = None
tool = WebSearchTool()
mock_tavily_client.assert_not_called()
assert tool.tavily_client is None
@patch("prometheus.tools.web_search.settings")
@patch("prometheus.tools.web_search.TavilyClient")
def test_web_search_success(self, mock_tavily_client, mock_settings):
"""Test successful web search."""
mock_settings.TAVILY_API_KEY = "test_api_key"
# Setup mock client
mock_client_instance = MagicMock()
mock_tavily_client.return_value = mock_client_instance
mock_response = {
"answer": "Test answer",
"results": [
{
"title": "Test Result",
"url": "https://stackoverflow.com/test",
"content": "Test content",
}
],
}
mock_client_instance.search.return_value = mock_response
# Create tool and search
tool = WebSearchTool()
result = tool.web_search(query="Python import error")
# Verify search was called with correct parameters
mock_client_instance.search.assert_called_once()
call_kwargs = mock_client_instance.search.call_args[1]
assert call_kwargs["query"] == "Python import error"
assert call_kwargs["max_results"] == 5
assert call_kwargs["search_depth"] == "advanced"
assert call_kwargs["include_answer"] is True
assert "stackoverflow.com" in call_kwargs["include_domains"]
assert "github.com" in call_kwargs["include_domains"]
# Verify result formatting
assert "Test answer" in result
assert "Test Result" in result
@patch("prometheus.tools.web_search.settings")
@patch("prometheus.tools.web_search.TavilyClient")
def test_web_search_with_custom_params(self, mock_tavily_client, mock_settings):
"""Test web search with custom parameters."""
mock_settings.TAVILY_API_KEY = "test_api_key"
mock_client_instance = MagicMock()
mock_tavily_client.return_value = mock_client_instance
mock_response = {
"results": [
{
"title": "Custom Result",
"url": "https://example.com/test",
"content": "Custom content",
}
]
}
mock_client_instance.search.return_value = mock_response
tool = WebSearchTool()
tool.web_search(
query="test query",
max_results=10,
include_domains=["custom-domain.com"],
exclude_domains=["excluded.com"],
)
# Verify custom parameters were used
call_kwargs = mock_client_instance.search.call_args[1]
assert call_kwargs["max_results"] == 10
assert call_kwargs["include_domains"] == ["custom-domain.com"]
assert call_kwargs["exclude_domains"] == ["excluded.com"]
@patch("prometheus.tools.web_search.settings")
@patch("prometheus.tools.web_search.TavilyClient")
def test_web_search_default_domains(self, mock_tavily_client, mock_settings):
"""Test web search uses default domains when not specified."""
mock_settings.TAVILY_API_KEY = "test_api_key"
mock_client_instance = MagicMock()
mock_tavily_client.return_value = mock_client_instance
mock_response = {"results": []}
mock_client_instance.search.return_value = mock_response
tool = WebSearchTool()
tool.web_search(query="test")
call_kwargs = mock_client_instance.search.call_args[1]
default_domains = call_kwargs["include_domains"]
# Verify default domains are present
assert "stackoverflow.com" in default_domains
assert "github.com" in default_domains
assert "developer.mozilla.org" in default_domains
assert "learn.microsoft.com" in default_domains
assert "docs.python.org" in default_domains
assert "pydantic.dev" in default_domains
assert "pypi.org" in default_domains
assert "readthedocs.org" in default_domains
@patch("prometheus.tools.web_search.settings")
@patch("prometheus.tools.web_search.TavilyClient")
def test_web_search_invalid_api_key(self, mock_tavily_client, mock_settings):
"""Test web search with invalid API key."""
mock_settings.TAVILY_API_KEY = "invalid_key"
mock_client_instance = MagicMock()
mock_tavily_client.return_value = mock_client_instance
mock_client_instance.search.side_effect = InvalidAPIKeyError("Invalid API key")
tool = WebSearchTool()
with pytest.raises(WebSearchToolException) as exc_info:
tool.web_search(query="test")
assert "Invalid Tavily API key" in str(exc_info.value)
@patch("prometheus.tools.web_search.settings")
@patch("prometheus.tools.web_search.TavilyClient")
def test_web_search_usage_limit_exceeded(self, mock_tavily_client, mock_settings):
"""Test web search when usage limit is exceeded."""
mock_settings.TAVILY_API_KEY = "test_api_key"
mock_client_instance = MagicMock()
mock_tavily_client.return_value = mock_client_instance
mock_client_instance.search.side_effect = UsageLimitExceededError("Limit exceeded")
tool = WebSearchTool()
with pytest.raises(WebSearchToolException) as exc_info:
tool.web_search(query="test")
assert "Usage limit exceeded" in str(exc_info.value)
@patch("prometheus.tools.web_search.settings")
@patch("prometheus.tools.web_search.TavilyClient")
def test_web_search_generic_exception(self, mock_tavily_client, mock_settings):
"""Test web search with generic exception."""
mock_settings.TAVILY_API_KEY = "test_api_key"
mock_client_instance = MagicMock()
mock_tavily_client.return_value = mock_client_instance
mock_client_instance.search.side_effect = Exception("Network error")
tool = WebSearchTool()
with pytest.raises(WebSearchToolException) as exc_info:
tool.web_search(query="test")
assert "An error occurred: Network error" in str(exc_info.value)
@patch("prometheus.tools.web_search.settings")
@patch("prometheus.tools.web_search.TavilyClient")
def test_web_search_empty_query(self, mock_tavily_client, mock_settings):
"""Test web search with empty query."""
mock_settings.TAVILY_API_KEY = "test_api_key"
mock_client_instance = MagicMock()
mock_tavily_client.return_value = mock_client_instance
mock_response = {"results": []}
mock_client_instance.search.return_value = mock_response
tool = WebSearchTool()
result = tool.web_search(query="")
# Should still call search even with empty query
mock_client_instance.search.assert_called_once()
assert "Detailed Results:" in result
@patch("prometheus.tools.web_search.settings")
@patch("prometheus.tools.web_search.TavilyClient")
def test_web_search_with_none_exclude_domains(self, mock_tavily_client, mock_settings):
"""Test web search with None exclude_domains."""
mock_settings.TAVILY_API_KEY = "test_api_key"
mock_client_instance = MagicMock()
mock_tavily_client.return_value = mock_client_instance
mock_response = {"results": []}
mock_client_instance.search.return_value = mock_response
tool = WebSearchTool()
tool.web_search(query="test", exclude_domains=None)
# Verify None is converted to empty list
call_kwargs = mock_client_instance.search.call_args[1]
assert call_kwargs["exclude_domains"] == []
@patch("prometheus.tools.web_search.settings")
@patch("prometheus.tools.web_search.TavilyClient")
def test_web_search_complex_query(self, mock_tavily_client, mock_settings):
"""Test web search with complex query containing special characters."""
mock_settings.TAVILY_API_KEY = "test_api_key"
mock_client_instance = MagicMock()
mock_tavily_client.return_value = mock_client_instance
mock_response = {
"results": [
{
"title": "Result",
"url": "https://example.com",
"content": "Content",
}
]
}
mock_client_instance.search.return_value = mock_response
tool = WebSearchTool()
complex_query = 'Python "ModuleNotFoundError" 中文 @#$%'
tool.web_search(query=complex_query)
# Verify complex query is passed correctly
call_kwargs = mock_client_instance.search.call_args[1]
assert call_kwargs["query"] == complex_query
class TestToolSpec:
"""Test suite for ToolSpec and tool configuration."""
def test_web_search_tool_spec_exists(self):
"""Test that web_search_spec is properly defined."""
assert hasattr(WebSearchTool, "web_search_spec")
spec = WebSearchTool.web_search_spec
assert spec.description is not None
assert len(spec.description) > 0
assert spec.input_schema == WebSearchInput
def test_web_search_tool_spec_description(self):
"""Test that web_search_spec description contains expected keywords."""
spec = WebSearchTool.web_search_spec
# Verify description mentions key use cases
assert "bug analysis" in spec.description
assert "error messages" in spec.description
assert "documentation" in spec.description
assert "library" in spec.description
class TestEdgeCases:
"""Test edge cases and boundary conditions."""
@patch("prometheus.tools.web_search.settings")
@patch("prometheus.tools.web_search.TavilyClient")
def test_web_search_max_results_boundary(self, mock_tavily_client, mock_settings):
"""Test web search with boundary values for max_results."""
mock_settings.TAVILY_API_KEY = "test_api_key"
mock_client_instance = MagicMock()
mock_tavily_client.return_value = mock_client_instance
mock_response = {"results": []}
mock_client_instance.search.return_value = mock_response
tool = WebSearchTool()
# Test with 0
tool.web_search(query="test", max_results=0)
assert mock_client_instance.search.call_args[1]["max_results"] == 0
# Test with large number
tool.web_search(query="test", max_results=100)
assert mock_client_instance.search.call_args[1]["max_results"] == 100
def test_format_results_missing_optional_fields(self):
"""Test formatting when optional fields are missing."""
response = {
"results": [
{
"title": "Title",
"url": "https://example.com",
"content": "Content",
# No published_date
}
]
# No answer, included_domains, excluded_domains
}
result = format_results(response)
# Should not fail and should not include missing fields
assert "Published:" not in result
assert "Answer:" not in result
assert "Search Filters:" not in result
@patch("prometheus.tools.web_search.settings")
@patch("prometheus.tools.web_search.TavilyClient")
def test_web_search_with_empty_domains_lists(self, mock_tavily_client, mock_settings):
"""Test web search with empty domain lists."""
mock_settings.TAVILY_API_KEY = "test_api_key"
mock_client_instance = MagicMock()
mock_tavily_client.return_value = mock_client_instance
mock_response = {"results": []}
mock_client_instance.search.return_value = mock_response
tool = WebSearchTool()
tool.web_search(query="test", include_domains=[], exclude_domains=[])
call_kwargs = mock_client_instance.search.call_args[1]
assert call_kwargs["include_domains"] == []
assert call_kwargs["exclude_domains"] == []
def test_format_results_with_long_content(self):
"""Test formatting with very long content."""
long_content = "A" * 10000
response = {
"results": [
{
"title": "Long Content Result",
"url": "https://example.com",
"content": long_content,
}
]
}
result = format_results(response)
# Should handle long content without errors
assert long_content in result
assert "Long Content Result" in result

View File

View File

@@ -0,0 +1,35 @@
import pytest
from prometheus.exceptions.file_operation_exception import FileOperationException
from prometheus.utils.file_utils import (
read_file_with_line_numbers,
)
from tests.test_utils.test_project_paths import TEST_PROJECT_PATH
def test_read_file_with_line_numbers():
"""Test reading specific line ranges from a file."""
# Test reading specific lines
result = read_file_with_line_numbers("foo/test.md", TEST_PROJECT_PATH, 1, 15)
expected = """1. # A
2.
3. Text under header A.
4.
5. ## B
6.
7. Text under header B.
8.
9. ## C
10.
11. Text under header C.
12.
13. ### D
14.
15. Text under header D."""
assert result == expected
# Test invalid range should raise exception
with pytest.raises(FileOperationException) as exc_info:
read_file_with_line_numbers("foo/test.md", TEST_PROJECT_PATH, 4, 2)
assert str(exc_info.value) == "The end line number must be greater than the start line number."

View File

@@ -0,0 +1,50 @@
from prometheus.utils.issue_util import (
format_issue_comments,
format_issue_info,
format_test_commands,
)
def test_format_issue_comments():
comments = [
{"username": "alice", "comment": "This looks good!"},
{"username": "bob", "comment": "Can we add tests?"},
]
result = format_issue_comments(comments)
expected = "alice: This looks good!\n\nbob: Can we add tests?"
assert result == expected
def test_format_issue_info():
title = "Bug in login flow"
body = "Users can't login on mobile devices"
comments = [
{"username": "alice", "comment": "I can reproduce this"},
{"username": "bob", "comment": "Working on a fix"},
]
result = format_issue_info(title, body, comments)
expected = """\
Issue title:
Bug in login flow
Issue description:
Users can't login on mobile devices
Issue comments:
alice: I can reproduce this
bob: Working on a fix"""
assert result == expected
def test_format_test_commands():
commands = ["pytest test_login.py", "pytest test_auth.py -v"]
result = format_test_commands(commands)
expected = "$ pytest test_login.py\n$ pytest test_auth.py -v"
assert result == expected

View File

@@ -0,0 +1,260 @@
from prometheus.models.context import Context
from prometheus.utils.knowledge_graph_utils import (
knowledge_graph_data_for_context_generator,
sort_contexts,
)
def test_empty_data():
"""Test with None or empty data"""
assert knowledge_graph_data_for_context_generator(None) == []
assert knowledge_graph_data_for_context_generator([]) == []
def test_skip_empty_content():
"""Test that empty or whitespace-only content is skipped"""
data = [
{
"FileNode": {"relative_path": "test.py"},
"ASTNode": {"text": "", "start_line": 1, "end_line": 1},
},
{
"FileNode": {"relative_path": "test.py"},
"TextNode": {"text": " \n\t ", "start_line": 2, "end_line": 2},
},
{
"FileNode": {"relative_path": "test.py"},
"ASTNode": {"text": "valid content", "start_line": 3, "end_line": 3},
},
]
result = knowledge_graph_data_for_context_generator(data)
assert len(result) == 1
assert result[0].content == "valid content"
def test_deduplication_identical_content():
"""Test deduplication of identical content"""
data = [
{
"FileNode": {"relative_path": "test.py"},
"ASTNode": {"text": "def hello():", "start_line": 1, "end_line": 1},
},
{
"FileNode": {"relative_path": "test.py"},
"TextNode": {"text": "def hello():", "start_line": 1, "end_line": 1},
},
]
result = knowledge_graph_data_for_context_generator(data)
assert len(result) == 1
assert result[0].content == "def hello():"
def test_deduplication_content_containment():
"""Test deduplication when one content contains another"""
data = [
{
"FileNode": {"relative_path": "test.py"},
"ASTNode": {"text": "def hello():\n print('world')", "start_line": 1, "end_line": 2},
},
{
"FileNode": {"relative_path": "test.py"},
"TextNode": {"text": "print('world')", "start_line": 2, "end_line": 2},
},
]
result = knowledge_graph_data_for_context_generator(data)
assert len(result) == 1
assert result[0].content == "def hello():\n print('world')"
def test_deduplication_line_containment():
"""Test deduplication based on line number containment"""
data = [
{
"FileNode": {"relative_path": "test.py"},
"ASTNode": {"text": "function body", "start_line": 1, "end_line": 5},
},
{
"FileNode": {"relative_path": "test.py"},
"TextNode": {"text": "inner content", "start_line": 2, "end_line": 3},
},
]
result = knowledge_graph_data_for_context_generator(data)
assert len(result) == 1
assert result[0].content == "function body"
assert result[0].start_line_number == 1
assert result[0].end_line_number == 5
def test_different_files_no_deduplication():
"""Test that identical content in different files is not deduplicated"""
data = [
{
"FileNode": {"relative_path": "file1.py"},
"ASTNode": {"text": "def hello():", "start_line": 1, "end_line": 1},
},
{
"FileNode": {"relative_path": "file2.py"},
"ASTNode": {"text": "def hello():", "start_line": 1, "end_line": 1},
},
]
result = knowledge_graph_data_for_context_generator(data)
assert len(result) == 2
assert result[0].relative_path == "file1.py"
assert result[1].relative_path == "file2.py"
def test_content_stripping():
"""Test that content is properly stripped of whitespace"""
data = [
{
"FileNode": {"relative_path": "test.py"},
"ASTNode": {"text": " \n def hello(): \n ", "start_line": 1, "end_line": 1},
}
]
result = knowledge_graph_data_for_context_generator(data)
assert len(result) == 1
assert result[0].content == "def hello():"
def test_complex_deduplication_scenario():
"""Test complex scenario with multiple overlapping contexts"""
data = [
# Large context containing everything
{
"FileNode": {"relative_path": "test.py"},
"ASTNode": {
"text": "class MyClass:\n def method1(self):\n pass\n def method2(self):\n pass",
"start_line": 1,
"end_line": 5,
},
},
# Smaller context contained within the large one
{
"FileNode": {"relative_path": "test.py"},
"TextNode": {
"text": "def method1(self):\n pass",
"start_line": 2,
"end_line": 3,
},
},
# Separate context in same file
{
"FileNode": {"relative_path": "test.py"},
"SelectedLines": {"text": "# Comment at end", "start_line": 10, "end_line": 10},
},
]
result = knowledge_graph_data_for_context_generator(data)
assert len(result) == 2 # Large context + separate comment
assert "class MyClass:" in result[0].content
assert result[1].content == "# Comment at end"
def test_sort_contexts_empty_list():
"""Test sorting an empty list"""
result = sort_contexts([])
assert result == []
def test_sort_contexts_by_relative_path():
"""Test sorting by relative path"""
contexts = [
Context(
relative_path="src/z.py", content="content z", start_line_number=1, end_line_number=5
),
Context(
relative_path="src/a.py", content="content a", start_line_number=1, end_line_number=5
),
Context(
relative_path="src/m.py", content="content m", start_line_number=1, end_line_number=5
),
]
result = sort_contexts(contexts)
assert result[0].relative_path == "src/a.py"
assert result[1].relative_path == "src/m.py"
assert result[2].relative_path == "src/z.py"
def test_sort_contexts_by_line_numbers():
"""Test sorting by line numbers within same file"""
contexts = [
Context(
relative_path="test.py", content="content 3", start_line_number=20, end_line_number=25
),
Context(
relative_path="test.py", content="content 1", start_line_number=1, end_line_number=5
),
Context(
relative_path="test.py", content="content 2", start_line_number=10, end_line_number=15
),
]
result = sort_contexts(contexts)
assert result[0].start_line_number == 1
assert result[1].start_line_number == 10
assert result[2].start_line_number == 20
def test_sort_contexts_none_line_numbers():
"""Test sorting when line numbers are None (should appear last)"""
contexts = [
Context(
relative_path="test.py",
content="content with lines",
start_line_number=10,
end_line_number=15,
),
Context(
relative_path="test.py",
content="content no lines",
start_line_number=None,
end_line_number=None,
),
Context(
relative_path="test.py", content="content first", start_line_number=1, end_line_number=5
),
]
result = sort_contexts(contexts)
assert result[0].start_line_number == 1
assert result[1].start_line_number == 10
assert result[2].start_line_number is None
def test_sort_contexts_mixed_files_and_lines():
"""Test sorting with multiple files and different line numbers"""
contexts = [
Context(
relative_path="b.py", content="b content 2", start_line_number=20, end_line_number=25
),
Context(
relative_path="a.py", content="a content 2", start_line_number=10, end_line_number=15
),
Context(
relative_path="b.py", content="b content 1", start_line_number=5, end_line_number=10
),
Context(
relative_path="a.py", content="a content 1", start_line_number=1, end_line_number=5
),
]
result = sort_contexts(contexts)
assert result[0].relative_path == "a.py" and result[0].start_line_number == 1
assert result[1].relative_path == "a.py" and result[1].start_line_number == 10
assert result[2].relative_path == "b.py" and result[2].start_line_number == 5
assert result[3].relative_path == "b.py" and result[3].start_line_number == 20
def test_sort_contexts_end_line_number_tiebreaker():
"""Test sorting uses end_line_number as tiebreaker when start_line_number is same"""
contexts = [
Context(
relative_path="test.py", content="content 3", start_line_number=10, end_line_number=30
),
Context(
relative_path="test.py", content="content 1", start_line_number=10, end_line_number=15
),
Context(
relative_path="test.py", content="content 2", start_line_number=10, end_line_number=20
),
]
result = sort_contexts(contexts)
assert result[0].end_line_number == 15
assert result[1].end_line_number == 20
assert result[2].end_line_number == 30

View File

@@ -0,0 +1,108 @@
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from prometheus.utils.lang_graph_util import (
check_remaining_steps,
extract_ai_responses,
extract_human_queries,
extract_last_tool_messages,
format_agent_tool_message_history,
get_last_message_content,
)
# Test check_remaining_steps
def test_check_remaining_steps():
def mock_router(state):
return "next_step"
state_enough_steps = {"remaining_steps": 5}
state_low_steps = {"remaining_steps": 2}
assert check_remaining_steps(state_enough_steps, mock_router, 3) == "next_step"
assert check_remaining_steps(state_low_steps, mock_router, 3) == "low_remaining_steps"
# Test extract_ai_responses
def test_extract_ai_responses():
messages = [
HumanMessage(content="Human 1"),
AIMessage(content="AI 1"),
HumanMessage(content="Human 2"),
AIMessage(content="AI 2"),
]
responses = extract_ai_responses(messages)
assert len(responses) == 2
assert "AI 1" in responses
assert "AI 2" in responses
# Test extract_human_queries
def test_extract_human_queries():
messages = [
SystemMessage(content="System"),
HumanMessage(content="Human 1"),
AIMessage(content="AI 1"),
HumanMessage(content="Human 2"),
]
queries = extract_human_queries(messages)
assert len(queries) == 2
assert "Human 1" in queries
assert "Human 2" in queries
# Test extract_last_tool_messages
def test_extract_last_tool_messages():
messages = [
HumanMessage(content="Human 1"),
ToolMessage(content="Tool 1", tool_call_id="call_1"),
AIMessage(content="AI 1"),
HumanMessage(content="Human 2"),
ToolMessage(content="Tool 2", tool_call_id="call_2"),
ToolMessage(content="Tool 3", tool_call_id="call_3"),
]
tool_messages = extract_last_tool_messages(messages)
assert len(tool_messages) == 2
assert all(isinstance(msg, ToolMessage) for msg in tool_messages)
assert tool_messages[-1].content == "Tool 3"
# Test get_last_message_content
def test_get_last_message_content():
messages = [
HumanMessage(content="Human"),
AIMessage(content="AI"),
ToolMessage(content="Last message", tool_call_id="call_1"),
]
content = get_last_message_content(messages)
assert content == "Last message"
def test_format_agent_tool_message_history():
messages = [
AIMessage(content="Let me analyze this"),
AIMessage(
content="I'll use a tool for this",
additional_kwargs={"tool_calls": [{"function": "analyze_data"}]},
),
ToolMessage(content="Analysis results: Success", tool_call_id="call_1"),
]
result = format_agent_tool_message_history(messages)
expected = (
"Assistant internal thought: Let me analyze this\n\n"
"Assistant internal thought: I'll use a tool for this\n\n"
"Assistant executed tool: analyze_data\n\n"
"Tool output: Analysis results: Success"
)
assert result == expected

View File

@@ -0,0 +1,391 @@
from unittest.mock import MagicMock, patch
import pytest
import requests
from prometheus.models.context import Context
from prometheus.models.query import Query
from prometheus.utils.memory_utils import AthenaMemoryClient
@pytest.fixture
def athena_client():
"""Create an AthenaMemoryClient instance for testing."""
return AthenaMemoryClient(base_url="http://test-athena-service.com")
@pytest.fixture
def sample_contexts():
"""Create sample Context objects for testing."""
return [
Context(
relative_path="src/main.py",
content="def main():\n print('Hello')",
start_line_number=1,
end_line_number=2,
),
Context(
relative_path="src/utils.py",
content="def helper():\n return True",
start_line_number=10,
end_line_number=11,
),
]
@pytest.fixture
def sample_query():
"""Create a sample Query object for testing."""
return Query(
essential_query="How to implement authentication?",
extra_requirements="Using JWT tokens",
purpose="Feature implementation",
)
class TestAthenaMemoryClient:
"""Test suite for AthenaMemoryClient class."""
def test_init(self):
"""Test AthenaMemoryClient initialization."""
client = AthenaMemoryClient(base_url="http://example.com/")
assert client.base_url == "http://example.com"
assert client.timeout == 30
def test_init_strips_trailing_slash(self):
"""Test that trailing slashes are removed from base_url."""
client = AthenaMemoryClient(base_url="http://example.com///")
assert client.base_url == "http://example.com"
@patch("prometheus.utils.memory_utils.requests.post")
def test_store_memory_success(self, mock_post, athena_client, sample_contexts):
"""Test successful memory storage."""
# Setup mock response
mock_response = MagicMock()
mock_response.json.return_value = {"status": "success", "memory_id": "123"}
mock_post.return_value = mock_response
# Call the method
result = athena_client.store_memory(
repository_id=42,
essential_query="How to use async?",
extra_requirements="Python 3.11",
purpose="Learning",
contexts=sample_contexts,
)
# Assertions
assert result == {"status": "success", "memory_id": "123"}
mock_post.assert_called_once()
# Verify the call arguments
call_args = mock_post.call_args
assert call_args[0][0] == "http://test-athena-service.com/semantic-memory/store/"
assert call_args[1]["timeout"] == 30
# Verify payload structure
payload = call_args[1]["json"]
assert payload["repository_id"] == 42
assert payload["query"]["essential_query"] == "How to use async?"
assert payload["query"]["extra_requirements"] == "Python 3.11"
assert payload["query"]["purpose"] == "Learning"
assert len(payload["contexts"]) == 2
assert payload["contexts"][0]["relative_path"] == "src/main.py"
@patch("prometheus.utils.memory_utils.requests.post")
def test_store_memory_request_exception(self, mock_post, athena_client, sample_contexts):
"""Test store_memory raises RequestException on failure."""
# Setup mock to raise an exception
mock_post.side_effect = requests.RequestException("Connection error")
# Verify exception is raised
with pytest.raises(requests.RequestException) as exc_info:
athena_client.store_memory(
repository_id=42,
essential_query="test query",
extra_requirements="",
purpose="",
contexts=sample_contexts,
)
assert "Connection error" in str(exc_info.value)
@patch("prometheus.utils.memory_utils.requests.post")
def test_store_memory_http_error(self, mock_post, athena_client, sample_contexts):
"""Test store_memory handles HTTP errors."""
# Setup mock response with error status
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
mock_post.return_value = mock_response
# Verify exception is raised
with pytest.raises(requests.HTTPError):
athena_client.store_memory(
repository_id=42,
essential_query="test query",
extra_requirements="",
purpose="",
contexts=sample_contexts,
)
@patch("prometheus.utils.memory_utils.requests.get")
def test_retrieve_memory_success(self, mock_get, athena_client, sample_query):
"""Test successful memory retrieval."""
# Setup mock response
mock_response = MagicMock()
mock_response.json.return_value = {
"data": [
{"content": "Authentication code snippet 1"},
{"content": "Authentication code snippet 2"},
]
}
mock_get.return_value = mock_response
# Call the method
result = athena_client.retrieve_memory(repository_id=42, query=sample_query)
# Assertions
assert len(result) == 2
assert result[0]["content"] == "Authentication code snippet 1"
mock_get.assert_called_once()
# Verify the call arguments
call_args = mock_get.call_args
assert call_args[0][0] == "http://test-athena-service.com/semantic-memory/retrieve/42/"
assert call_args[1]["timeout"] == 30
# Verify query parameters
params = call_args[1]["params"]
assert params["essential_query"] == "How to implement authentication?"
assert params["extra_requirements"] == "Using JWT tokens"
assert params["purpose"] == "Feature implementation"
@patch("prometheus.utils.memory_utils.requests.get")
def test_retrieve_memory_with_optional_fields(self, mock_get, athena_client):
"""Test memory retrieval with empty optional query fields."""
# Setup mock response
mock_response = MagicMock()
mock_response.json.return_value = {"data": []}
mock_get.return_value = mock_response
# Create query with empty optional fields
query = Query(essential_query="test query", extra_requirements="", purpose="")
# Call the method
result = athena_client.retrieve_memory(repository_id=42, query=query)
# Verify empty strings are passed correctly
call_args = mock_get.call_args
params = call_args[1]["params"]
assert params["extra_requirements"] == ""
assert params["purpose"] == ""
assert result == []
@patch("prometheus.utils.memory_utils.requests.get")
def test_retrieve_memory_request_exception(self, mock_get, athena_client, sample_query):
"""Test retrieve_memory raises RequestException on failure."""
# Setup mock to raise an exception
mock_get.side_effect = requests.RequestException("Timeout error")
# Verify exception is raised
with pytest.raises(requests.RequestException) as exc_info:
athena_client.retrieve_memory(repository_id=42, query=sample_query)
assert "Timeout error" in str(exc_info.value)
@patch("prometheus.utils.memory_utils.requests.get")
def test_retrieve_memory_http_error(self, mock_get, athena_client, sample_query):
"""Test retrieve_memory handles HTTP errors."""
# Setup mock response with error status
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
mock_get.return_value = mock_response
# Verify exception is raised
with pytest.raises(requests.HTTPError):
athena_client.retrieve_memory(repository_id=42, query=sample_query)
@patch("prometheus.utils.memory_utils.requests.delete")
def test_delete_repository_memory_success(self, mock_delete, athena_client):
"""Test successful repository memory deletion."""
# Setup mock response
mock_response = MagicMock()
mock_response.json.return_value = {"status": "success", "deleted_count": 15}
mock_delete.return_value = mock_response
# Call the method
result = athena_client.delete_repository_memory(repository_id=42)
# Assertions
assert result == {"status": "success", "deleted_count": 15}
mock_delete.assert_called_once()
# Verify the call arguments
call_args = mock_delete.call_args
assert call_args[0][0] == "http://test-athena-service.com/semantic-memory/42/"
assert call_args[1]["timeout"] == 30
@patch("prometheus.utils.memory_utils.requests.delete")
def test_delete_repository_memory_request_exception(self, mock_delete, athena_client):
"""Test delete_repository_memory raises RequestException on failure."""
# Setup mock to raise an exception
mock_delete.side_effect = requests.RequestException("Network error")
# Verify exception is raised
with pytest.raises(requests.RequestException) as exc_info:
athena_client.delete_repository_memory(repository_id=42)
assert "Network error" in str(exc_info.value)
@patch("prometheus.utils.memory_utils.requests.delete")
def test_delete_repository_memory_http_error(self, mock_delete, athena_client):
"""Test delete_repository_memory handles HTTP errors."""
# Setup mock response with error status
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = requests.HTTPError("403 Forbidden")
mock_delete.return_value = mock_response
# Verify exception is raised
with pytest.raises(requests.HTTPError):
athena_client.delete_repository_memory(repository_id=42)
class TestModuleLevelFunctions:
"""Test suite for module-level convenience functions."""
@patch("prometheus.utils.memory_utils.athena_client.store_memory")
def test_store_memory_function(self, mock_store, sample_contexts):
"""Test module-level store_memory function."""
from prometheus.utils.memory_utils import store_memory
# Setup mock
mock_store.return_value = {"status": "success"}
# Call the function
result = store_memory(
repository_id=123,
essential_query="test query",
extra_requirements="requirements",
purpose="testing",
contexts=sample_contexts,
)
# Verify it delegates to the client
assert result == {"status": "success"}
mock_store.assert_called_once_with(
repository_id=123,
essential_query="test query",
extra_requirements="requirements",
purpose="testing",
contexts=sample_contexts,
)
@patch("prometheus.utils.memory_utils.athena_client.retrieve_memory")
def test_retrieve_memory_function(self, mock_retrieve, sample_query):
"""Test module-level retrieve_memory function."""
from prometheus.utils.memory_utils import retrieve_memory
# Setup mock
mock_retrieve.return_value = [{"content": "test"}]
# Call the function
result = retrieve_memory(repository_id=123, query=sample_query)
# Verify it delegates to the client
assert result == [{"content": "test"}]
mock_retrieve.assert_called_once_with(repository_id=123, query=sample_query)
@patch("prometheus.utils.memory_utils.athena_client.delete_repository_memory")
def test_delete_repository_memory_function(self, mock_delete):
"""Test module-level delete_repository_memory function."""
from prometheus.utils.memory_utils import delete_repository_memory
# Setup mock
mock_delete.return_value = {"status": "deleted"}
# Call the function
result = delete_repository_memory(repository_id=123)
# Verify it delegates to the client
assert result == {"status": "deleted"}
mock_delete.assert_called_once_with(repository_id=123)
class TestEdgeCases:
"""Test edge cases and boundary conditions."""
@patch("prometheus.utils.memory_utils.requests.post")
def test_store_memory_with_empty_contexts_list(self, mock_post, athena_client):
"""Test storing memory with empty contexts list."""
mock_response = MagicMock()
mock_response.json.return_value = {"status": "success"}
mock_post.return_value = mock_response
result = athena_client.store_memory(
repository_id=1,
essential_query="test",
extra_requirements="",
purpose="",
contexts=[],
)
# Verify empty list is handled
payload = mock_post.call_args[1]["json"]
assert payload["contexts"] == []
assert result == {"status": "success"}
@patch("prometheus.utils.memory_utils.requests.get")
def test_retrieve_memory_empty_results(self, mock_get, athena_client):
"""Test retrieving memory when no results are found."""
mock_response = MagicMock()
mock_response.json.return_value = {"data": []}
mock_get.return_value = mock_response
query = Query(
essential_query="nonexistent query",
extra_requirements="",
purpose="",
)
result = athena_client.retrieve_memory(repository_id=999, query=query)
assert result == []
def test_context_serialization(self, sample_contexts):
"""Test that Context objects are properly serialized."""
context = sample_contexts[0]
serialized = context.model_dump()
assert serialized["relative_path"] == "src/main.py"
assert serialized["content"] == "def main():\n print('Hello')"
assert serialized["start_line_number"] == 1
assert serialized["end_line_number"] == 2
@patch("prometheus.utils.memory_utils.requests.post")
def test_store_memory_with_special_characters(self, mock_post, athena_client):
"""Test storing memory with special characters in content."""
mock_response = MagicMock()
mock_response.json.return_value = {"status": "success"}
mock_post.return_value = mock_response
special_context = Context(
relative_path="test/file.py",
content="def test():\n # Special chars: @#$%^&*(){}[]|\\<>?~`",
start_line_number=1,
end_line_number=2,
)
result = athena_client.store_memory(
repository_id=1,
essential_query="test with special chars: @#$%",
extra_requirements="requirements with 中文",
purpose="purpose with émojis 🚀",
contexts=[special_context],
)
assert result == {"status": "success"}
# Verify special characters are preserved in the payload
payload = mock_post.call_args[1]["json"]
assert "@#$%" in payload["query"]["essential_query"]
assert "中文" in payload["query"]["extra_requirements"]
assert "🚀" in payload["query"]["purpose"]

View File

@@ -0,0 +1,86 @@
from unittest.mock import MagicMock
import pytest
from prometheus.utils.knowledge_graph_utils import (
EMPTY_DATA_MESSAGE,
format_knowledge_graph_data,
)
class MockResult:
def __init__(self, data_list):
self._data = data_list
def data(self):
return self._data
class MockSession:
def __init__(self):
self.execute_read = MagicMock()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
class MockDriver:
def __init__(self, session):
self._session = session
def session(self):
return self._session
@pytest.fixture
def mock_neo4j_driver():
session = MockSession()
driver = MockDriver(session)
return driver, session
def test_format_neo4j_data_single_row():
data = [{"name": "John", "age": 30}]
formatted = format_knowledge_graph_data(data)
expected = "Result 1:\nage: 30\nname: John"
assert formatted == expected
def test_format_neo4j_result_multiple_rows():
data = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}]
formatted = format_knowledge_graph_data(data)
expected = "Result 1:\nage: 30\nname: John\n\n\nResult 2:\nage: 25\nname: Jane"
assert formatted == expected
def test_format_neo4j_result_empty():
data = []
formatted = format_knowledge_graph_data(data)
assert formatted == EMPTY_DATA_MESSAGE
def test_format_neo4j_result_different_keys():
data = [{"name": "John", "age": 30}, {"city": "New York", "country": "USA"}]
formatted = format_knowledge_graph_data(data)
expected = "Result 1:\nage: 30\nname: John\n\n\nResult 2:\ncity: New York\ncountry: USA"
assert formatted == expected
def test_format_neo4j_result_complex_values():
data = [
{"numbers": [1, 2, 3], "metadata": {"type": "user", "active": True}, "date": "2024-01-01"}
]
formatted = format_knowledge_graph_data(data)
expected = "Result 1:\ndate: 2024-01-01\nmetadata: {'type': 'user', 'active': True}\nnumbers: [1, 2, 3]"
assert formatted == expected

View File

@@ -0,0 +1,120 @@
from pathlib import Path
from prometheus.utils.patch_util import get_updated_files
def test_get_updated_files_empty_diff():
diff = ""
added, modified, removed = get_updated_files(diff)
assert len(added) == 0
assert len(modified) == 0
assert len(removed) == 0
def test_get_updated_files_added_only():
diff = """
diff --git a/new_file.txt b/new_file.txt
new file mode 100644
index 0000000..1234567
--- /dev/null
+++ b/new_file.txt
@@ -0,0 +1 @@
+New content
"""
added, modified, removed = get_updated_files(diff)
assert len(added) == 1
assert len(modified) == 0
assert len(removed) == 0
assert added[0] == Path("new_file.txt")
def test_get_updated_files_modified_only():
diff = """
diff --git a/modified_file.txt b/modified_file.txt
index 1234567..89abcdef
--- a/modified_file.txt
+++ b/modified_file.txt
@@ -1 +1 @@
-Old content
+Modified content
"""
added, modified, removed = get_updated_files(diff)
assert len(added) == 0
assert len(modified) == 1
assert len(removed) == 0
assert modified[0] == Path("modified_file.txt")
def test_get_updated_files_removed_only():
diff = """
diff --git a/removed_file.txt b/removed_file.txt
deleted file mode 100644
index 1234567..0000000
--- a/removed_file.txt
+++ /dev/null
@@ -1 +0,0 @@
-Content to be removed
"""
added, modified, removed = get_updated_files(diff)
assert len(added) == 0
assert len(modified) == 0
assert len(removed) == 1
assert removed[0] == Path("removed_file.txt")
def test_get_updated_files_multiple_changes():
diff = """
diff --git a/new_file.txt b/new_file.txt
new file mode 100644
index 0000000..1234567
--- /dev/null
+++ b/new_file.txt
@@ -0,0 +1 @@
+New content
diff --git a/modified_file.txt b/modified_file.txt
index 1234567..89abcdef
--- a/modified_file.txt
+++ b/modified_file.txt
@@ -1 +1 @@
-Old content
+Modified content
diff --git a/removed_file.txt b/removed_file.txt
deleted file mode 100644
index 1234567..0000000
--- a/removed_file.txt
+++ /dev/null
@@ -1 +0,0 @@
-Content to be removed
"""
added, modified, removed = get_updated_files(diff)
assert len(added) == 1
assert len(modified) == 1
assert len(removed) == 1
assert added[0] == Path("new_file.txt")
assert modified[0] == Path("modified_file.txt")
assert removed[0] == Path("removed_file.txt")
def test_get_updated_files_with_subfolders():
diff = """
diff --git a/folder1/new_file.txt b/folder1/new_file.txt
new file mode 100644
index 0000000..1234567
--- /dev/null
+++ b/folder1/new_file.txt
@@ -0,0 +1 @@
+New content
diff --git a/folder2/subfolder/modified_file.txt b/folder2/subfolder/modified_file.txt
index 1234567..89abcdef
--- a/folder2/subfolder/modified_file.txt
+++ b/folder2/subfolder/modified_file.txt
@@ -1 +1 @@
-Old content
+Modified content
"""
added, modified, removed = get_updated_files(diff)
assert len(added) == 1
assert len(modified) == 1
assert len(removed) == 0
assert added[0] == Path("folder1/new_file.txt")
assert modified[0] == Path("folder2/subfolder/modified_file.txt")

View File

@@ -0,0 +1,21 @@
from prometheus.utils.str_util import (
pre_append_line_numbers,
)
def test_single_line():
text = "Hello world"
result = pre_append_line_numbers(text, start_line=1)
assert result == "1. Hello world"
def test_multiple_lines():
text = "First line\nSecond line\nThird line"
result = pre_append_line_numbers(text, start_line=1)
assert result == "1. First line\n2. Second line\n3. Third line"
def test_empty_string():
text = ""
result = pre_append_line_numbers(text, start_line=1)
assert result == ""