123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346 |
- import subprocess
- import os
- import google.generativeai as genai
- import re
- import argparse
def _amend_diff_base():
    """Return the revision the staged diff should be compared against when amending.

    Normally this is ``HEAD~1``, the parent of the commit being amended. A root
    commit has no parent, so the empty tree is used instead. The diff then
    covers everything the root commit introduces plus any newly staged changes.

    Raises:
        subprocess.CalledProcessError: if git cannot compute the empty-tree id.
        FileNotFoundError: if git is not installed.
    """
    probe = subprocess.run(
        ["git", "rev-parse", "--verify", "--quiet", "HEAD~1"],
        capture_output=True,
        text=True,
    )
    if probe.returncode == 0:
        return "HEAD~1"
    # Hashing empty input as a tree yields the empty-tree object id for the
    # repository's hash algorithm, so no SHA-1 constant is hard-coded.
    empty_tree = subprocess.run(
        ["git", "hash-object", "-t", "tree", "--stdin"],
        input="",
        capture_output=True,
        text=True,
        check=True,
    )
    return empty_tree.stdout.strip()


def get_staged_diff(amend=False):
    """
    Retrieves the diff of staged files using git.

    Args:
        amend (bool): When true, compare the index with the parent of HEAD
            instead of HEAD. The result then contains both the changes of the
            commit being amended and any newly staged changes, i.e. the full
            content of the amended commit. For a root commit the empty tree is
            used as the base.

    Returns:
        str: The diff of the staged files (empty when nothing is staged), or
        None on error (the reason is printed).
    """
    command = ["git", "diff", "--staged"]
    try:
        if amend:
            command.append(_amend_diff_base())
        process = subprocess.run(
            command,
            capture_output=True,
            text=True,
            check=True,
        )
        return process.stdout
    except subprocess.CalledProcessError as e:
        print(f"Error getting staged diff: {e}")
        print(f" stderr: {e.stderr}")
        return None
    except FileNotFoundError:
        print(
            "Error: git command not found. Please ensure Git is installed and in your PATH."
        )
        return None
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return None
def get_project_files():
    """Return the paths of all files tracked in the latest commit (HEAD).

    Runs ``git ls-tree -r --name-only HEAD`` in the current working directory
    and returns one entry per output line. Any failure (git error, git missing,
    anything unexpected) is reported with a printed message and yields an empty
    list rather than an exception.
    """
    listing_cmd = ["git", "ls-tree", "-r", "--name-only", "HEAD"]
    try:
        result = subprocess.run(
            listing_cmd,
            capture_output=True,
            text=True,
            check=True,
            cwd=os.getcwd(),
        )
    except subprocess.CalledProcessError as err:
        print(f"Error getting project file list: {err}")
        print(f" stderr: {err.stderr}")
        return []
    except FileNotFoundError:
        print("Error: git command not found. Is Git installed and in your PATH?")
        return []
    except Exception as err:
        print(f"An unexpected error occurred while listing files: {err}")
        return []
    return result.stdout.splitlines()
def get_file_content(filepath):
    """Return the UTF-8 text of *filepath*, resolved against the current working directory.

    A missing file, a directory, or any other read/decode failure is reported
    with a printed warning and produces ``None`` instead of raising.
    """
    try:
        with open(filepath, "r", encoding="utf-8") as handle:
            contents = handle.read()
    except FileNotFoundError:
        print(f"Warning: File not found: {filepath}")
    except IsADirectoryError:
        print(f"Warning: Path is a directory, not a file: {filepath}")
    except Exception as exc:
        print(f"Warning: Error reading file {filepath}: {exc}")
    else:
        return contents
    return None
# Prompt sent to Gemini. File requests must use paths exactly as listed,
# because the list comes from `git ls-tree` (relative to the current working
# directory) and get_file_content() resolves paths against that same directory.
_COMMIT_PROMPT_TEMPLATE = """
You are an expert assistant that generates Git commit messages following conventional commit standards.
Analyze the following diff of staged files and generate ONLY the commit message (subject and body) adhering to standard Git conventions.
1. **Subject Line:** Write a concise, imperative subject line summarizing the change (max 50 characters). Start with a capital letter. Do not end with a period. Use standard commit types like 'feat:', 'fix:', 'refactor:', 'docs:', 'test:', 'chore:', etc.
2. **Blank Line:** Leave a single blank line between the subject and the body.
3. **Body:** Write a detailed but precise body explaining the 'what' and 'why' of the changes. Wrap lines at 72 characters. Focus on the motivation for the change and contrast its implementation with the previous behavior. If the change is trivial, the body can be omitted.
**Project Files:**
Here is a list of files in the project:
```
{project_files_list}
```
**Contextual Understanding:**
* The diff shows changes in the context of the project files listed above.
* If understanding the relationship between the changed files and other parts of the project is necessary to write an accurate commit message, you may request the content of specific files from the list above.
* To request file content, respond *only* with the exact phrase: `Request content for file: <path/to/file>` where `<path/to/file>` is the path exactly as it appears in the project file list above. Do not add any other text to your response if you are requesting a file.
Diff:
```diff
{diff}
```
Generate ONLY the commit message text, without any introductory phrases like "Here is the commit message:", unless you need to request file content.
"""

# A single-line reply of this form is treated as a file-content request.
_FILE_REQUEST_RE = re.compile(r"Request content for file:[ \t]*(.*)")

# At most this many file requests are answered before a final answer is forced.
_MAX_FILE_REQUESTS = 5


def _parse_file_request(message):
    """Return the requested path if *message* is a file request, else None.

    The prompt shows the phrase wrapped in backticks, so echoed backticks and
    quotes/angle brackets around the path are tolerated. Only a single-line
    reply qualifies, so a multi-line commit message is never mistaken for a
    request. A request without a path yields an empty string.
    """
    match = _FILE_REQUEST_RE.fullmatch(message.strip().strip("`").strip())
    if match is None:
        return None
    return match.group(1).strip().strip("`'\"<>").strip()


def _answer_file_request(filepath):
    """Ask the user whether the model may read *filepath*; return the note to send back.

    The note carries the file content or explains why it is missing. Either way
    it tells the model to continue with the commit message.
    """
    if not filepath:
        return "The file request did not name a file. Continue generating the commit message based on the original diff."
    print(f"AI requests content for: {filepath}")
    try:
        answer = input(f"Allow access to '{filepath}'? (y/n): ")
    except EOFError:
        # No interactive stdin: treat as a refusal instead of aborting the run.
        answer = ""
    if answer.strip().lower() not in ("y", "yes"):
        return f"User denied access to file '{filepath}'. Continue generating the commit message based on the original diff."
    file_content = get_file_content(filepath)
    # None means unreadable; an empty string is a legitimately empty file.
    if file_content is None:
        return f"File '{filepath}' could not be read or was not found. Continue generating the commit message based on the original diff."
    return f"Response for file '{filepath}':\n```\n{file_content}\n```\nNow, generate the commit message based on the diff and this context."


def _strip_code_fences(message):
    """Remove Markdown code-fence lines the model may wrap around its answer, then strip."""
    message = re.sub(r"^\s*```[a-zA-Z]*\s*\n?", "", message, flags=re.MULTILINE)
    message = re.sub(r"\n?```\s*$", "", message, flags=re.MULTILINE)
    return message.strip()


def generate_commit_message(diff, gemini_api_key):
    """
    Generates a commit message using the Gemini API, given the diff.

    The model may ask for the content of project files. Each request is
    confirmed interactively with the user, and at most ``_MAX_FILE_REQUESTS``
    requests are answered before a final message is demanded.

    Args:
        diff (str): The diff of the staged files.
        gemini_api_key (str): Your Gemini API key.

    Returns:
        str: The generated commit message with Markdown code fences removed,
        or None on error (empty diff, API failure, empty reply, or a model
        that keeps requesting files). Errors are printed, not raised.
    """
    if not diff:
        print("Error: No diff provided to generate commit message.")
        return None

    # Bound before the try so the error handler can inspect it even when the
    # failure happens before (or during) the first API call.
    response = None
    try:
        genai.configure(api_key=gemini_api_key)
        model = genai.GenerativeModel("gemini-2.0-flash")

        project_files = get_project_files()
        project_files_list = (
            "\n".join(project_files)
            if project_files
            else "(Could not list project files)"
        )
        conversation = [
            _COMMIT_PROMPT_TEMPLATE.format(
                diff=diff, project_files_list=project_files_list
            )
        ]

        for _ in range(_MAX_FILE_REQUESTS):
            response = model.generate_content("\n".join(conversation))
            message = response.text.strip()
            filepath = _parse_file_request(message)
            if filepath is None:
                break
            # Keep the model's own request in the transcript so it knows which
            # file the following note answers and does not ask for it again.
            conversation.append(message)
            conversation.append(_answer_file_request(filepath))
        else:
            print(
                "Warning: Maximum number of file requests reached. Generating commit message without further context."
            )
            # The whole transcript is kept, including the last answer, which
            # may carry file content the user already approved.
            conversation.append(
                "Generate the commit message now based on the available information. Do not request any more files."
            )
            response = model.generate_content("\n".join(conversation))
            message = response.text.strip()
            if _parse_file_request(message) is not None:
                print(
                    "Error: The AI requested another file instead of writing a commit message."
                )
                return None

        message = _strip_code_fences(message)
        if not message:
            print(
                "Error: Failed to get a valid response from the AI after handling requests."
            )
            return None
        if len(message) < 5:
            print(
                f"Warning: Generated commit message seems too short or empty: '{message}'"
            )
        return message
    except Exception as e:
        print(f"Error generating commit message with Gemini: {e}")
        # Blocked prompts expose the reason via prompt_feedback, when available.
        if hasattr(response, "prompt_feedback"):
            print(f"Prompt Feedback: {response.prompt_feedback}")
        return None
def create_commit(message, amend=False):
    """Run ``git commit`` using *message* as the commit message.

    Args:
        message (str): The commit message. An empty or missing message is
            rejected before git is invoked.
        amend (bool): When true, ``--amend`` is passed so the previous commit
            is replaced rather than a new one created.

    Returns:
        bool: True if git reported success (its stdout is printed), False
        otherwise (the reason is printed).
    """
    if not message:
        print("Error: No commit message provided.")
        return False

    try:
        amend_flag = ["--amend"] if amend else []
        git_args = ["git", "commit", *amend_flag, "-m", message]
        completed = subprocess.run(
            git_args,
            check=True,
            capture_output=True,
            text=True,
        )
        print(completed.stdout)
        return True
    except subprocess.CalledProcessError as err:
        print(f"Error creating git commit: {err}")
        print(err.stderr)
    except FileNotFoundError:
        print("Error: git command not found. Is Git installed and in your PATH?")
    except Exception as err:
        print(f"An unexpected error occurred: {err}")
    return False
def main():
    """
    Main function to orchestrate the process of:
    1. Parsing arguments (for --amend).
    2. Reading GEMINI_API_KEY from the environment.
    3. Getting the staged diff.
    4. Generating a commit message using Gemini.
    5. Creating or amending a git commit with the generated message after
       the user confirms.

    Returns:
        int: Process exit status. 0 when the commit was created or amended;
        1 when nothing was committed (missing API key, diff error, nothing
        staged, generation failure, user refusal or a failed ``git commit``).
    """
    parser = argparse.ArgumentParser(
        description="Generate Git commit messages using AI."
    )
    parser.add_argument(
        "-a",
        "--amend",
        action="store_true",
        help="Amend the previous commit instead of creating a new one.",
    )
    args = parser.parse_args()

    gemini_api_key = os.environ.get("GEMINI_API_KEY")
    if not gemini_api_key:
        print(
            "Error: GEMINI_API_KEY environment variable not set.\n"
            " Please obtain an API key from Google Cloud and set the environment variable.\n"
            " For example: export GEMINI_API_KEY='YOUR_API_KEY'"
        )
        return 1
    diff = get_staged_diff(amend=args.amend)
    if diff is None:
        print("Aborting commit due to error getting diff.")
        return 1
    if not diff.strip():
        print("Aborting: No changes staged to commit.")
        return 1
    message = generate_commit_message(diff, gemini_api_key)
    if message is None:
        print("Aborting commit due to error generating message.")
        return 1
    print(f"Generated commit message:\n{message}")

    action = "amend the last commit" if args.amend else "create a new commit"
    try:
        user_input = input(f"Do you want to {action} with this message? (y/n): ")
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or Ctrl-C at the prompt counts as "no", not a crash.
        print()
        user_input = ""
    if user_input.strip().lower() not in ("y", "yes"):
        print(f"Commit {'amendment' if args.amend else 'creation'} aborted by user.")
        return 1

    if create_commit(message, amend=args.amend):
        print(f"Commit {'amended' if args.amend else 'created'} successfully.")
        return 0
    print(f"Commit {'amendment' if args.amend else 'creation'} failed.")
    return 1


if __name__ == "__main__":
    # Propagate main()'s status so shells and scripts can detect failures.
    raise SystemExit(main())
|