git_commit_ai.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
import argparse
import logging
import os
import re
import subprocess
import sys

import google.generativeai as genai
  7. # --- Configuration ---
  8. # Configure logging
  9. logging.basicConfig(level=logging.WARN, format="%(levelname)s: %(message)s")
  10. def get_staged_diff(amend=False):
  11. """
  12. Retrieves the diff of staged files using git.
  13. Returns:
  14. str: The diff of the staged files, or None on error.
  15. """
  16. try:
  17. # Use subprocess.run for better control and error handling
  18. if amend:
  19. process = subprocess.run(
  20. [
  21. "git",
  22. "diff",
  23. "HEAD~1",
  24. "--staged",
  25. ], # Corrected: --staged is the correct option
  26. capture_output=True,
  27. text=True, # Ensure output is returned as text
  28. check=True, # Raise an exception for non-zero exit codes
  29. )
  30. else:
  31. process = subprocess.run(
  32. [
  33. "git",
  34. "diff",
  35. "--staged",
  36. ], # Corrected: --staged is the correct option
  37. capture_output=True,
  38. text=True, # Ensure output is returned as text
  39. check=True, # Raise an exception for non-zero exit codes
  40. )
  41. return process.stdout
  42. except subprocess.CalledProcessError as e:
  43. print(f"Error getting staged diff: {e}")
  44. print(f" stderr: {e.stderr}") # Print stderr for more details
  45. return None
  46. except FileNotFoundError:
  47. print(
  48. "Error: git command not found. Please ensure Git is installed and in your PATH."
  49. )
  50. return None
  51. except Exception as e:
  52. print(f"An unexpected error occurred: {e}")
  53. return None
  54. def get_project_files():
  55. """Gets a list of all files tracked in the latest commit (HEAD)."""
  56. try:
  57. process = subprocess.run(
  58. # Changed command to list files in the last commit
  59. ["git", "ls-tree", "-r", "--name-only", "HEAD"],
  60. capture_output=True,
  61. text=True,
  62. check=True,
  63. cwd=os.getcwd(), # Ensure it runs in the correct directory
  64. )
  65. return process.stdout.splitlines()
  66. except subprocess.CalledProcessError as e:
  67. print(f"Error getting project file list: {e}")
  68. print(f" stderr: {e.stderr}")
  69. return [] # Return empty list on error
  70. except FileNotFoundError:
  71. print("Error: git command not found. Is Git installed and in your PATH?")
  72. return []
  73. except Exception as e:
  74. print(f"An unexpected error occurred while listing files: {e}")
  75. return []
  76. def get_file_content(filepath):
  77. """Reads the content of a file relative to the script's CWD."""
  78. # Consider adding checks to prevent reading files outside the repo
  79. try:
  80. # Assuming the script runs from the repo root
  81. with open(filepath, "r", encoding="utf-8") as f:
  82. return f.read()
  83. except FileNotFoundError:
  84. print(f"Warning: File not found: {filepath}")
  85. return None
  86. except IsADirectoryError:
  87. print(f"Warning: Path is a directory, not a file: {filepath}")
  88. return None
  89. except Exception as e:
  90. print(f"Warning: Error reading file {filepath}: {e}")
  91. return None
  92. def generate_commit_message(diff, gemini_api_key):
  93. """
  94. Generates a commit message using the Gemini API, given the diff.
  95. Args:
  96. diff (str): The diff of the staged files.
  97. gemini_api_key (str): Your Gemini API key.
  98. Returns:
  99. str: The generated commit message, or None on error.
  100. """
  101. if not diff:
  102. print("Error: No diff provided to generate commit message.")
  103. return None
  104. genai.configure(api_key=gemini_api_key)
  105. MODEL_NAME = os.getenv("GEMINI_MODEL")
  106. if not MODEL_NAME:
  107. logging.error("GEMINI_MODEL environment variable not set.")
  108. logging.error(
  109. "Please set the desired Gemini model name (e.g., 'gemini-1.5-flash-latest')."
  110. )
  111. logging.error(" export GEMINI_MODEL='gemini-1.5-flash-latest' (Linux/macOS)")
  112. logging.error(" set GEMINI_MODEL=gemini-1.5-flash-latest (Windows CMD)")
  113. logging.error(
  114. " $env:GEMINI_MODEL='gemini-1.5-flash-latest' (Windows PowerShell)"
  115. )
  116. sys.exit(1)
  117. model = genai.GenerativeModel(MODEL_NAME)
  118. logging.info(f"Using Gemini model: {MODEL_NAME}")
  119. # Define prompt as a regular string, not f-string, placeholders will be filled by .format()
  120. prompt = """
  121. You are an expert assistant that generates Git commit messages following conventional commit standards.
  122. Analyze the following diff of staged files and generate ONLY the commit message (subject and body) adhering to standard Git conventions.
  123. 1. **Subject Line:** Write a concise, imperative subject line summarizing the change (max 50 characters). Start with a capital letter. Do not end with a period. Use standard commit types like 'feat:', 'fix:', 'refactor:', 'docs:', 'test:', 'chore:', etc.
  124. 2. **Blank Line:** Leave a single blank line between the subject and the body.
  125. 3. **Body:** Write a detailed but precise body explaining the 'what' and 'why' of the changes. Wrap lines at 72 characters. Focus on the motivation for the change and contrast its implementation with the previous behavior. If the change is trivial, the body can be omitted.
  126. **Project Files:**
  127. Here is a list of files in the project:
  128. ```
  129. {project_files_list}
  130. ```
  131. **Contextual Understanding:**
  132. * The diff shows changes in the context of the project files listed above.
  133. * If understanding the relationship between the changed files and other parts of the project is necessary to write an accurate commit message, you may request the content of specific files from the list above.
  134. * To request file content, respond *only* with the exact phrase: `Request content for file: <path/to/file>` where `<path/to/file>` is the relative path from the repository root. Do not add any other text to your response if you are requesting a file.
  135. Diff:
  136. ```diff
  137. {diff}
  138. ```
  139. Generate ONLY the commit message text, without any introductory phrases like "Here is the commit message:", unless you need to request file content.
  140. """
  141. response = None
  142. try:
  143. # Get project files to include in the prompt
  144. project_files = get_project_files()
  145. project_files_list = (
  146. "\n".join(project_files)
  147. if project_files
  148. else "(Could not list project files)"
  149. )
  150. # Format the prompt with the diff and file list
  151. formatted_prompt = prompt.format(
  152. diff=diff, project_files_list=project_files_list
  153. )
  154. # Use a conversation history for potential back-and-forth
  155. conversation = [formatted_prompt]
  156. max_requests = 5 # Limit the number of file requests
  157. requests_made = 0
  158. while requests_made < max_requests:
  159. response = model.generate_content("\n".join(conversation))
  160. message = response.text.strip()
  161. # Check if the AI is requesting a file
  162. request_match = re.match(r"^Request content for file: (.*)$", message)
  163. if request_match:
  164. filepath = request_match.group(1).strip()
  165. print(f"AI requests content for: {filepath}")
  166. user_input = input(f"Allow access to '{filepath}'? (y/n): ").lower()
  167. if user_input == "y":
  168. file_content = get_file_content(filepath)
  169. if file_content:
  170. # Provide content to AI
  171. conversation.append(
  172. f"Response for file '{filepath}':\n```\n{file_content}\n```\nNow, generate the commit message based on the diff and this context."
  173. )
  174. else:
  175. # Inform AI file couldn't be read
  176. conversation.append(
  177. f"File '{filepath}' could not be read or was not found. Continue generating the commit message based on the original diff."
  178. )
  179. else:
  180. # Inform AI permission denied
  181. conversation.append(
  182. f"User denied access to file '{filepath}'. Continue generating the commit message based on the original diff."
  183. )
  184. requests_made += 1
  185. else:
  186. # AI did not request a file, assume it's the commit message
  187. break # Exit the loop
  188. else:
  189. # Max requests reached
  190. print(
  191. "Warning: Maximum number of file requests reached. Generating commit message without further context."
  192. )
  193. # Make one last attempt to generate the message without the last request fulfilled
  194. response = model.generate_content(
  195. "\n".join(conversation[:-1])
  196. + "\nGenerate the commit message now based on the available information."
  197. ) # Use conversation up to the last request
  198. message = response.text.strip()
  199. # Extract the final message, remove potential markdown code blocks, and strip whitespace
  200. # Ensure message is not None before processing
  201. if message:
  202. message = re.sub(
  203. r"^\s*```[a-zA-Z]*\s*\n?", "", message, flags=re.MULTILINE
  204. ) # Remove leading code block start
  205. message = re.sub(
  206. r"\n?```\s*$", "", message, flags=re.MULTILINE
  207. ) # Remove trailing code block end
  208. message = message.strip() # Strip leading/trailing whitespace
  209. else:
  210. # Handle case where response.text might be None or empty after failed requests
  211. print(
  212. "Error: Failed to get a valid response from the AI after handling requests."
  213. )
  214. return None
  215. # Basic validation: Check if the message seems plausible (not empty, etc.)
  216. if not message or len(message) < 5: # Arbitrary short length check
  217. print(
  218. f"Warning: Generated commit message seems too short or empty: '{message}'"
  219. )
  220. # Optionally, you could add retry logic here or return None
  221. return message
  222. except Exception as e:
  223. # Provide more context in the error message
  224. print(f"Error generating commit message with Gemini: {e}")
  225. # Consider logging response details if available, e.g., response.prompt_feedback
  226. if hasattr(response, "prompt_feedback"):
  227. print(f"Prompt Feedback: {response.prompt_feedback}")
  228. return None
  229. def create_commit(message, amend=False): # Add amend parameter
  230. """
  231. Creates a git commit with the given message, optionally amending the previous commit.
  232. Args:
  233. message (str): The commit message.
  234. Returns:
  235. bool: True if the commit was successful, False otherwise.
  236. """
  237. if not message:
  238. print("Error: No commit message provided.")
  239. return False
  240. try:
  241. # Build the command list
  242. command = ["git", "commit"]
  243. if amend:
  244. command.append("--amend")
  245. command.extend(["-m", message])
  246. process = subprocess.run(
  247. command, # Use the dynamically built command
  248. check=True, # Important: Raise exception on non-zero exit
  249. capture_output=True, # capture the output
  250. text=True,
  251. )
  252. print(process.stdout) # print the output
  253. return True
  254. except subprocess.CalledProcessError as e:
  255. print(f"Error creating git commit: {e}")
  256. print(e.stderr)
  257. return False
  258. except FileNotFoundError:
  259. print("Error: git command not found. Is Git installed and in your PATH?")
  260. return False
  261. except Exception as e:
  262. print(f"An unexpected error occurred: {e}")
  263. return False
  264. def main():
  265. """
  266. Main function to orchestrate the process of:
  267. 1. Parsing arguments (for --amend).
  268. 2. Getting the staged diff.
  269. 3. Generating a commit message using Gemini.
  270. 4. Creating or amending a git commit with the generated message.
  271. """
  272. # --- Argument Parsing ---
  273. parser = argparse.ArgumentParser(
  274. description="Generate Git commit messages using AI."
  275. )
  276. parser.add_argument(
  277. "-a",
  278. "--amend",
  279. action="store_true",
  280. help="Amend the previous commit instead of creating a new one.",
  281. )
  282. args = parser.parse_args()
  283. # --- ---
  284. gemini_api_key = os.environ.get("GEMINI_API_KEY")
  285. if not gemini_api_key:
  286. print(
  287. "Error: GEMINI_API_KEY environment variable not set.\n"
  288. " Please obtain an API key from Google Cloud and set the environment variable.\n"
  289. " For example: export GEMINI_API_KEY='YOUR_API_KEY'"
  290. )
  291. return
  292. diff = get_staged_diff(amend=args.amend)
  293. if diff is None:
  294. print("Aborting commit due to error getting diff.")
  295. return # Exit the script
  296. if not diff.strip(): # check if the diff is empty
  297. print("Aborting: No changes staged to commit.")
  298. return
  299. message = generate_commit_message(diff, gemini_api_key)
  300. if message is None:
  301. print("Aborting commit due to error generating message.")
  302. return # Exit if message generation failed
  303. print(f"Generated commit message:\n{message}") # Print the message for review
  304. # --- Confirmation ---
  305. action = "amend the last commit" if args.amend else "create a new commit"
  306. user_input = input(f"Do you want to {action} with this message? (y/n): ").lower()
  307. if user_input == "y":
  308. # Pass the amend flag to create_commit
  309. if create_commit(message, amend=args.amend):
  310. print(f"Commit {'amended' if args.amend else 'created'} successfully.")
  311. else:
  312. print(f"Commit {'amendment' if args.amend else 'creation'} failed.")
  313. else:
  314. print(f"Commit {'amendment' if args.amend else 'creation'} aborted by user.")
  315. if __name__ == "__main__":
  316. main()