|
@@ -2,9 +2,10 @@ import subprocess
|
|
import os
|
|
import os
|
|
import google.generativeai as genai
|
|
import google.generativeai as genai
|
|
import re
|
|
import re
|
|
|
|
+import argparse # Add argparse import
|
|
|
|
|
|
|
|
|
|
-def get_staged_diff():
|
|
|
|
|
|
+def get_staged_diff(amend=False):
|
|
"""
|
|
"""
|
|
Retrieves the diff of staged files using git.
|
|
Retrieves the diff of staged files using git.
|
|
|
|
|
|
@@ -13,12 +14,29 @@ def get_staged_diff():
|
|
"""
|
|
"""
|
|
try:
|
|
try:
|
|
# Use subprocess.run for better control and error handling
|
|
# Use subprocess.run for better control and error handling
|
|
- process = subprocess.run(
|
|
|
|
- ["git", "diff", "--staged"], # Corrected: --staged is the correct option
|
|
|
|
- capture_output=True,
|
|
|
|
- text=True, # Ensure output is returned as text
|
|
|
|
- check=True, # Raise an exception for non-zero exit codes
|
|
|
|
- )
|
|
|
|
|
|
+ if amend:
|
|
|
|
+ process = subprocess.run(
|
|
|
|
+ [
|
|
|
|
+ "git",
|
|
|
|
+ "diff",
|
|
|
|
+ "HEAD~1",
|
|
|
|
+ "--staged",
|
|
|
|
+ ], # Corrected: --staged is the correct option
|
|
|
|
+ capture_output=True,
|
|
|
|
+ text=True, # Ensure output is returned as text
|
|
|
|
+ check=True, # Raise an exception for non-zero exit codes
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ process = subprocess.run(
|
|
|
|
+ [
|
|
|
|
+ "git",
|
|
|
|
+ "diff",
|
|
|
|
+ "--staged",
|
|
|
|
+ ], # Corrected: --staged is the correct option
|
|
|
|
+ capture_output=True,
|
|
|
|
+ text=True, # Ensure output is returned as text
|
|
|
|
+ check=True, # Raise an exception for non-zero exit codes
|
|
|
|
+ )
|
|
return process.stdout
|
|
return process.stdout
|
|
except subprocess.CalledProcessError as e:
|
|
except subprocess.CalledProcessError as e:
|
|
print(f"Error getting staged diff: {e}")
|
|
print(f"Error getting staged diff: {e}")
|
|
@@ -139,7 +157,7 @@ def generate_commit_message(diff, gemini_api_key):
|
|
|
|
|
|
# Use a conversation history for potential back-and-forth
|
|
# Use a conversation history for potential back-and-forth
|
|
conversation = [formatted_prompt]
|
|
conversation = [formatted_prompt]
|
|
- max_requests = 3 # Limit the number of file requests
|
|
|
|
|
|
+ max_requests = 5 # Limit the number of file requests
|
|
requests_made = 0
|
|
requests_made = 0
|
|
|
|
|
|
while requests_made < max_requests:
|
|
while requests_made < max_requests:
|
|
@@ -223,9 +241,9 @@ def generate_commit_message(diff, gemini_api_key):
|
|
return None
|
|
return None
|
|
|
|
|
|
|
|
|
|
-def create_commit(message):
|
|
|
|
|
|
+def create_commit(message, amend=False): # Add amend parameter
|
|
"""
|
|
"""
|
|
- Creates a git commit with the given message.
|
|
|
|
|
|
+ Creates a git commit with the given message, optionally amending the previous commit.
|
|
|
|
|
|
Args:
|
|
Args:
|
|
message (str): The commit message.
|
|
message (str): The commit message.
|
|
@@ -234,12 +252,18 @@ def create_commit(message):
|
|
bool: True if the commit was successful, False otherwise.
|
|
bool: True if the commit was successful, False otherwise.
|
|
"""
|
|
"""
|
|
if not message:
|
|
if not message:
|
|
- print("Error: No commit message provided to create commit.")
|
|
|
|
|
|
+ print("Error: No commit message provided.")
|
|
return False
|
|
return False
|
|
|
|
|
|
try:
|
|
try:
|
|
|
|
+ # Build the command list
|
|
|
|
+ command = ["git", "commit"]
|
|
|
|
+ if amend:
|
|
|
|
+ command.append("--amend")
|
|
|
|
+ command.extend(["-m", message])
|
|
|
|
+
|
|
process = subprocess.run(
|
|
process = subprocess.run(
|
|
- ["git", "commit", "-m", message],
|
|
|
|
|
|
+ command, # Use the dynamically built command
|
|
check=True, # Important: Raise exception on non-zero exit
|
|
check=True, # Important: Raise exception on non-zero exit
|
|
capture_output=True, # capture the output
|
|
capture_output=True, # capture the output
|
|
text=True,
|
|
text=True,
|
|
@@ -261,10 +285,24 @@ def create_commit(message):
|
|
def main():
|
|
def main():
|
|
"""
|
|
"""
|
|
Main function to orchestrate the process of:
|
|
Main function to orchestrate the process of:
|
|
- 1. Getting the staged diff.
|
|
|
|
- 2. Generating a commit message using Gemini.
|
|
|
|
- 3. Creating a git commit with the generated message.
|
|
|
|
|
|
+ 1. Parsing arguments (for --amend).
|
|
|
|
+ 2. Getting the staged diff.
|
|
|
|
+ 3. Generating a commit message using Gemini.
|
|
|
|
+ 4. Creating or amending a git commit with the generated message.
|
|
"""
|
|
"""
|
|
|
|
+ # --- Argument Parsing ---
|
|
|
|
+ parser = argparse.ArgumentParser(
|
|
|
|
+ description="Generate Git commit messages using AI."
|
|
|
|
+ )
|
|
|
|
+ parser.add_argument(
|
|
|
|
+ "-a",
|
|
|
|
+ "--amend",
|
|
|
|
+ action="store_true",
|
|
|
|
+ help="Amend the previous commit instead of creating a new one.",
|
|
|
|
+ )
|
|
|
|
+ args = parser.parse_args()
|
|
|
|
+ # --- ---
|
|
|
|
+
|
|
gemini_api_key = os.environ.get("GEMINI_API_KEY")
|
|
gemini_api_key = os.environ.get("GEMINI_API_KEY")
|
|
if not gemini_api_key:
|
|
if not gemini_api_key:
|
|
print(
|
|
print(
|
|
@@ -274,7 +312,7 @@ def main():
|
|
)
|
|
)
|
|
return
|
|
return
|
|
|
|
|
|
- diff = get_staged_diff()
|
|
|
|
|
|
+ diff = get_staged_diff(amend=args.amend)
|
|
if diff is None:
|
|
if diff is None:
|
|
print("Aborting commit due to error getting diff.")
|
|
print("Aborting commit due to error getting diff.")
|
|
return # Exit the script
|
|
return # Exit the script
|
|
@@ -290,17 +328,18 @@ def main():
|
|
|
|
|
|
print(f"Generated commit message:\n{message}") # Print the message for review
|
|
print(f"Generated commit message:\n{message}") # Print the message for review
|
|
|
|
|
|
- # Prompt the user for confirmation before committing
|
|
|
|
- user_input = input(
|
|
|
|
- "Do you want to create the commit with this message? (y/n): "
|
|
|
|
- ).lower()
|
|
|
|
|
|
+ # --- Confirmation ---
|
|
|
|
+ action = "amend the last commit" if args.amend else "create a new commit"
|
|
|
|
+ user_input = input(f"Do you want to {action} with this message? (y/n): ").lower()
|
|
|
|
+
|
|
if user_input == "y":
|
|
if user_input == "y":
|
|
- if create_commit(message):
|
|
|
|
- print("Commit created successfully.")
|
|
|
|
|
|
+ # Pass the amend flag to create_commit
|
|
|
|
+ if create_commit(message, amend=args.amend):
|
|
|
|
+ print(f"Commit {'amended' if args.amend else 'created'} successfully.")
|
|
else:
|
|
else:
|
|
- print("Commit failed.")
|
|
|
|
|
|
+ print(f"Commit {'amendment' if args.amend else 'creation'} failed.")
|
|
else:
|
|
else:
|
|
- print("Commit aborted by user.")
|
|
|
|
|
|
+ print(f"Commit {'amendment' if args.amend else 'creation'} aborted by user.")
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|