Pārlūkot izejas kodu

feat: Add file content retrieval for improved AI context

This commit introduces functionality to retrieve the content
of project files and include it in the prompt sent to the AI.
This allows the AI to generate more informed and accurate
commit messages by considering the broader context of the
changes. The AI can now request specific file contents, which
are then passed back to the AI if the user approves, enhancing
the quality of generated commit messages. Also updates the
Gemini model name to the intended version for better accuracy.
seno 6 dienas atpakaļ
vecāks
revīzija
62fbf62764
1 mainītis faili ar 155 papildinājumiem un 11 dzēšanām
  1. 155 11
      git_commit_ai.py

+ 155 - 11
git_commit_ai.py

@@ -1,6 +1,7 @@
 import subprocess
 import os
 import google.generativeai as genai
+import re
 
 
 def get_staged_diff():
@@ -33,6 +34,48 @@ def get_staged_diff():
         return None
 
 
def get_project_files():
    """Gets a list of all files tracked in the latest commit (HEAD)."""
    ls_tree_cmd = ["git", "ls-tree", "-r", "--name-only", "HEAD"]
    try:
        result = subprocess.run(
            ls_tree_cmd,
            capture_output=True,
            text=True,
            check=True,
            cwd=os.getcwd(),  # run git relative to the current working directory
        )
    except subprocess.CalledProcessError as e:
        # git ran but exited non-zero (e.g. not a repo, no HEAD yet)
        print(f"Error getting project file list: {e}")
        print(f"  stderr: {e.stderr}")
        return []
    except FileNotFoundError:
        print("Error: git command not found. Is Git installed and in your PATH?")
        return []
    except Exception as e:
        print(f"An unexpected error occurred while listing files: {e}")
        return []
    # One tracked file path per line of git output.
    return result.stdout.splitlines()
+
+
def get_file_content(filepath):
    """Reads a UTF-8 text file located inside the current working directory.

    The path typically comes from an untrusted source (the AI's file
    request), so any path that resolves outside the current working
    directory is refused instead of read.

    Args:
        filepath: Path relative to the repository root (assumed to be
            the current working directory).

    Returns:
        The file's text content as a string, or None if the path escapes
        the repository, the file is missing, is a directory, or cannot
        be read for any other reason.
    """
    # Security: resolve symlinks and ".." components and make sure the
    # target stays inside the repo root before touching the filesystem.
    repo_root = os.path.realpath(os.getcwd())
    target = os.path.realpath(os.path.join(repo_root, filepath))
    if os.path.commonpath([repo_root, target]) != repo_root:
        print(f"Warning: Refusing to read path outside the repository: {filepath}")
        return None
    try:
        with open(target, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        print(f"Warning: File not found: {filepath}")
        return None
    except IsADirectoryError:
        print(f"Warning: Path is a directory, not a file: {filepath}")
        return None
    except Exception as e:
        print(f"Warning: Error reading file {filepath}: {e}")
        return None
+
+
 def generate_commit_message(diff, gemini_api_key):
     """
     Generates a commit message using the Gemini API, given the diff.
@@ -49,35 +92,136 @@ def generate_commit_message(diff, gemini_api_key):
         return None
 
     genai.configure(api_key=gemini_api_key)
-    model = genai.GenerativeModel("gemini-2.0-flash")
+    # Use the intended model name
+    model = genai.GenerativeModel("gemini-1.5-flash")
 
+    project_files_list = get_project_files()
+
+    # Define prompt as a regular string, not f-string, placeholders will be filled by .format()
     prompt = f"""
-    You are a helpful assistant that generates Git commit messages.
-    Analyze the following diff of staged files and generate a commit message adhering to standard Git conventions:
+    You are an expert assistant that generates Git commit messages following conventional commit standards.
+    Analyze the following diff of staged files and generate ONLY the commit message (subject and body) adhering to standard Git conventions.
 
-    1.  **Subject Line:** Write a concise, imperative subject line summarizing the change (max 50 characters). Start with a capital letter. Do not end with a period.
+    1.  **Subject Line:** Write a concise, imperative subject line summarizing the change (max 50 characters). Start with a capital letter. Do not end with a period. Use standard commit types like 'feat:', 'fix:', 'refactor:', 'docs:', 'test:', 'chore:', etc.
     2.  **Blank Line:** Leave a single blank line between the subject and the body.
-    3.  **Body:** Write a detailed but precise body explaining the 'what' and 'why' of the changes. Wrap lines at 72 characters. Focus on the motivation for the change and contrast its implementation with the previous behavior.
+    3.  **Body:** Write a detailed but precise body explaining the 'what' and 'why' of the changes. Wrap lines at 72 characters. Focus on the motivation for the change and contrast its implementation with the previous behavior. If the change is trivial, the body can be omitted.
+
+    **Project Files:**
+    Here is a list of files in the project:
+    ```
+    {project_files_list}
+    ```
+
+    **Contextual Understanding:**
+    *   The diff shows changes in the context of the project files listed above.
+    *   If understanding the relationship between the changed files and other parts of the project is necessary to write an accurate commit message, you may request the content of specific files from the list above.
+    *   To request file content, respond *only* with the exact phrase: `Request content for file: <path/to/file>` where `<path/to/file>` is the relative path from the repository root. Do not add any other text to your response if you are requesting a file.
 
     Diff:
     ```diff
     {diff}
     ```
 
-    Generate only the commit message (subject and body).
+    Generate ONLY the commit message text, without any introductory phrases like "Here is the commit message:", unless you need to request file content.
     """
 
     try:
-        response = model.generate_content(prompt)
-        # Check for a successful response.
-        if response and response.text:
-            return response.text.strip()  # Remove leading/trailing whitespace
+        # Get project files to include in the prompt
+        project_files = get_project_files()
+        project_files_list = (
+            "\n".join(project_files)
+            if project_files
+            else "(Could not list project files)"
+        )
+
+        # Format the prompt with the diff and file list
+        formatted_prompt = prompt.format(
+            diff=diff, project_files_list=project_files_list
+        )
+
+        # Use a conversation history for potential back-and-forth
+        conversation = [formatted_prompt]
+        max_requests = 3  # Limit the number of file requests
+        requests_made = 0
+
+        while requests_made < max_requests:
+            response = model.generate_content("\n".join(conversation))
+            message = response.text.strip()
+
+            # Check if the AI is requesting a file
+            request_match = re.match(r"^Request content for file: (.*)$", message)
+
+            if request_match:
+                filepath = request_match.group(1).strip()
+                print(f"AI requests content for: {filepath}")
+
+                user_input = input(f"Allow access to '{filepath}'? (y/n): ").lower()
+
+                if user_input == "y":
+                    file_content = get_file_content(filepath)
+                    if file_content:
+                        # Provide content to AI
+                        conversation.append(
+                            f"Response for file '{filepath}':\n```\n{file_content}\n```\nNow, generate the commit message based on the diff and this context."
+                        )
+                    else:
+                        # Inform AI file couldn't be read
+                        conversation.append(
+                            f"File '{filepath}' could not be read or was not found. Continue generating the commit message based on the original diff."
+                        )
+                else:
+                    # Inform AI permission denied
+                    conversation.append(
+                        f"User denied access to file '{filepath}'. Continue generating the commit message based on the original diff."
+                    )
+
+                requests_made += 1
+            else:
+                # AI did not request a file, assume it's the commit message
+                break  # Exit the loop
+        else:
+            # Max requests reached
+            print(
+                "Warning: Maximum number of file requests reached. Generating commit message without further context."
+            )
+            # Make one last attempt to generate the message without the last request fulfilled
+            response = model.generate_content(
+                "\n".join(conversation[:-1])
+                + "\nGenerate the commit message now based on the available information."
+            )  # Use conversation up to the last request
+            message = response.text.strip()
+
+        # Extract the final message, remove potential markdown code blocks, and strip whitespace
+        # Ensure message is not None before processing
+        if message:
+            message = re.sub(
+                r"^\s*```[a-zA-Z]*\s*\n?", "", message, flags=re.MULTILINE
+            )  # Remove leading code block start
+            message = re.sub(
+                r"\n?```\s*$", "", message, flags=re.MULTILINE
+            )  # Remove trailing code block end
+            message = message.strip()  # Strip leading/trailing whitespace
         else:
-            print("Error: Gemini API returned an empty or invalid response.")
+            # Handle case where response.text might be None or empty after failed requests
+            print(
+                "Error: Failed to get a valid response from the AI after handling requests."
+            )
             return None
 
+        # Basic validation: Check if the message seems plausible (not empty, etc.)
+        if not message or len(message) < 5:  # Arbitrary short length check
+            print(
+                f"Warning: Generated commit message seems too short or empty: '{message}'"
+            )
+            # Optionally, you could add retry logic here or return None
+
+        return message
     except Exception as e:
+        # Provide more context in the error message
         print(f"Error generating commit message with Gemini: {e}")
+        # Consider logging response details if available, e.g., response.prompt_feedback
+        if hasattr(response, "prompt_feedback"):
+            print(f"Prompt Feedback: {response.prompt_feedback}")
         return None