Skip to content

Add OpenAI integration with CLI for generating Mermaid diagrams. #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,9 @@ NEXT_PUBLIC_API_DEV_URL=http://localhost:8000
ANTHROPIC_API_KEY=

# OPTIONAL: providing your own GitHub PAT increases rate limits from 60/hr to 5000/hr to the GitHub API
GITHUB_PAT=
GITHUB_PAT=

# OpenAI API configuration for CLI usage
OPENAI_API_KEY=""
OPENAI_BASE_URL="https://api.openai.com/v1"
OPENAI_MODEL="gpt-4o-mini"
2 changes: 1 addition & 1 deletion .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest

if: github.repository_owner == 'ahmedkhaleel2004'
# Add concurrency to prevent multiple deployments running at once
concurrency:
group: production
Expand Down
65 changes: 65 additions & 0 deletions backend/app/services/openai_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
from dotenv import load_dotenv
import openai

load_dotenv()


class OpenAIService:
def __init__(self):
self.api_key = os.getenv("OPENAI_API_KEY")
self.base_url = os.getenv("OPENAI_BASE_URL")
self.model = os.getenv("OPENAI_MODEL")
self.client = openai.OpenAI(
api_key=self.api_key,
base_url=self.base_url,
)

def call_openai_api(self, system_prompt: str, data: dict) -> str:
"""
Makes an API call to OpenAI and returns the response.

Args:
system_prompt (str): The instruction/system prompt
data (dict): Dictionary of variables to format into the user message

Returns:
str: OpenAI's response text
"""
# Format the user message
user_message = self._format_user_message(data)

messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_message}
]

try:
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
max_tokens=4096,
temperature=0,
)
return response.choices[0].message.content
except Exception as e:
raise Exception(f"API call failed: {str(e)}")

def _format_user_message(self, data: dict[str, str]) -> str:
"""Helper method to format the data into a user message"""
parts = []
for key, value in data.items():
if key == 'file_tree':
parts.append(f"<file_tree>\n{value}\n</file_tree>")
elif key == 'readme':
parts.append(f"<readme>\n{value}\n</readme>")
elif key == 'explanation':
parts.append(f"<explanation>\n{value}\n</explanation>")
elif key == 'component_mapping':
parts.append(
f"<component_mapping>\n{value}\n</component_mapping>")
elif key == 'instructions' and value != "":
parts.append(f"<instructions>\n{value}\n</instructions>")
elif key == 'diagram':
parts.append(f"<diagram>\n{value}\n</diagram>")
return "\n\n".join(parts)
1 change: 1 addition & 0 deletions backend/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .local_git import *
128 changes: 128 additions & 0 deletions backend/cli/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import os
import argparse
from app.services.openai_service import OpenAIService
from app.prompts import SYSTEM_FIRST_PROMPT, SYSTEM_SECOND_PROMPT, SYSTEM_THIRD_PROMPT, ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
import sys
from cli import build_file_tree, get_readme, print_stat


def main():
parser = argparse.ArgumentParser(
description="Generate Mermaid diagrams from local Git repositories.")
parser.add_argument("repo_path", help="Path to the local Git repository")
parser.add_argument(
"--instructions", help="Instructions for diagram generation", default=None)
parser.add_argument(
"--output", help="Output file for the Mermaid diagram", default="diagram.mmd")
parser.add_argument(
"--stat",
help="Only outputs the file list and statistics",
action="store_true"
)

args = parser.parse_args()

repo_path = args.repo_path
instructions = args.instructions
output_file = args.output

if not os.path.isdir(repo_path):
print(f"Error: The path '{repo_path}' is not a valid directory.")
sys.exit(1)

openai_service = OpenAIService()

if (args.stat):
print_stat(repo_path)
return

# Build file tree and get README
file_tree = build_file_tree(repo_path)
readme = get_readme(repo_path)

if not file_tree and not readme:
print("Error: The repository is empty or unreadable.")
sys.exit(1)

# Prepare system prompts with instructions if provided
first_system_prompt = SYSTEM_FIRST_PROMPT
third_system_prompt = SYSTEM_THIRD_PROMPT
if instructions:
first_system_prompt += "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
third_system_prompt += "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
else:
instructions = ""

# Call OpenAI API to get explanation
try:
explanation = openai_service.call_openai_api(
system_prompt=first_system_prompt,
data={
"file_tree": file_tree,
"readme": readme,
"instructions": instructions
},
)
except Exception as e:
print(f"Error generating explanation: {e}")
sys.exit(1)

if "BAD_INSTRUCTIONS" in explanation:
print("Error: Invalid or unclear instructions provided.")
sys.exit(1)

# Call API to get component mapping
try:
full_second_response = openai_service.call_openai_api(
system_prompt=SYSTEM_SECOND_PROMPT,
data={
"explanation": explanation,
"file_tree": file_tree
}
)
except Exception as e:
print(f"Error generating component mapping: {e}")
sys.exit(1)

# Extract component mapping from the response
start_tag = "<component_mapping>"
end_tag = "</component_mapping>"
try:
component_mapping_text = full_second_response[
full_second_response.find(start_tag):
full_second_response.find(end_tag) + len(end_tag)
]
except Exception:
print("Error extracting component mapping.")
sys.exit(1)

# Call API to get Mermaid diagram
try:
mermaid_code = openai_service.call_openai_api(
system_prompt=third_system_prompt,
data={
"explanation": explanation,
"component_mapping": component_mapping_text,
"instructions": instructions
}
)
except Exception as e:
print(f"Error generating Mermaid diagram: {e}")
sys.exit(1)

if "BAD_INSTRUCTIONS" in mermaid_code:
print("Error: Invalid or unclear instructions provided.")
sys.exit(1)

# Save the diagram to the output file
try:
with open(output_file, 'w', encoding='utf-8') as f:
f.write(mermaid_code)
print(f"Mermaid diagram generated and saved to '{output_file}'.")
except Exception as e:
print(f"Error saving diagram: {e}")
sys.exit(1)


if __name__ == "__main__":
main()
121 changes: 121 additions & 0 deletions backend/cli/local_git.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import os
import pathspec
from collections import Counter


def build_file_tree(repo_path):
"""
Traverse the local repository and build a file tree list.
"""
excluded_patterns = [
'node_modules/',
'vendor/',
'venv/',
'__pycache__/',
'.cache/',
'.tmp/',
'.vscode/',
'.idea/',
'.git/',
'test/',
'*.min.*',
'*.pyc',
'*.pyo',
'*.pyd',
'*.so',
'*.dll',
'*.class',
'*.o',
'*.jpg',
'*.jpeg',
'*.png',
'*.gif',
'*.ico',
'*.svg',
'*.ttf',
'*.woff',
'*.webp',
'*.pdf',
'*.xml',
'*.wav',
'*.doc',
'*.docx',
'*.xls',
'*.xlsx',
'*.ppt',
'*.pptx',
'*.txt',
'*.log',
'yarn.lock',
'poetry.lock',
]

gitignore_path = os.path.join(repo_path, '.gitignore')
if os.path.exists(gitignore_path):
with open(gitignore_path) as f:
for line in f:
line = line.strip()
if line and not line.startswith('#'):
excluded_patterns.append(line)

spec = pathspec.PathSpec.from_lines('gitwildmatch', excluded_patterns)

file_paths = []
for root, dirs, files in os.walk(repo_path):
dirs[:] = [
d for d in dirs
if not spec.match_file(os.path.relpath(os.path.join(root, d), repo_path) + '/')
]

for file in files:
file_path = os.path.join(root, file)
rel_path = os.path.relpath(file_path, repo_path).replace("\\", "/")
if not spec.match_file(rel_path):
file_paths.append(rel_path)

return file_paths


def get_readme(repo_path):
"""
Fetch the README content from the local repository.
"""
readme_path = os.path.join(repo_path, "README.md")
if os.path.exists(readme_path):
with open(readme_path, 'r', encoding='utf-8') as f:
return f.read()
return ""


def analyze_extension_percentage(file_paths):
"""
Analyze the percentage distribution of file extensions in the provided file list.

Args:
file_paths (list): List of file paths.

Returns:
dict: Dictionary mapping file extensions to their percentage occurrence.
"""
extensions = [os.path.splitext(file)[1].lower()
for file in file_paths if os.path.splitext(file)[1]]
total = len(extensions)
if total == 0:
return {}
counts = Counter(extensions)
percentages = {ext: (count / total) * 100 for ext, count in counts.items()}

sorted_percentages = dict(
sorted(percentages.items(), key=lambda item: item[1], reverse=True))
return sorted_percentages


def print_stat(repo_path):
file_list = build_file_tree(repo_path)
for f in file_list:
print(f)
extension_percentages = analyze_extension_percentage(file_list)

print("File Extension Percentage Distribution:")
for ext, percent in extension_percentages.items():
print(f"{ext or 'No Extension'}: {percent:.2f}%")
2 changes: 2 additions & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,5 @@ uvloop==0.21.0
watchfiles==1.0.3
websockets==14.1
wrapt==1.17.0
openai==1.58.1
pathspec==0.12.1