浏览代码

Fix: Correctly pass arguments to Docker and add amend option

The ai-commit.sh script now correctly passes arguments to the
Docker container, allowing for more flexible usage. This was
achieved by appending `$@` to the docker run command.

The git_commit_ai.py script now supports amending the previous commit
using the `--amend` flag.  A new `argparse` dependency has been
added to handle command-line arguments. The `get_staged_diff`
function has been updated to retrieve the diff between HEAD~1 and
staged files when amending. The `create_commit` function and the
`main` function now correctly handle the amend flag and communicate
with the user accordingly. The number of allowed file requests from
the LLM was also increased from 3 to 5.
seno 1 周之前
父节点
当前提交
21deffe717
共有 2 个文件被更改,包括 68 次插入29 次删除
  1. 5 5
      ai-commit.sh
  2. 63 24
      git_commit_ai.py

+ 5 - 5
ai-commit.sh

@@ -2,13 +2,13 @@
 
 # Check for required Git configuration files
 if [ ! -f "$HOME/.gitconfig" ]; then
-    >&2 echo "Error: Git configuration file not found at $HOME/.gitconfig"
-    exit 1
+	>&2 echo "Error: Git configuration file not found at $HOME/.gitconfig"
+	exit 1
 fi
 
 if [ ! -f "$HOME/.git-credentials" ]; then
-    >&2 echo "Error: Git credentials file not found at $HOME/.git-credentials"
-    exit 1
+	>&2 echo "Error: Git credentials file not found at $HOME/.git-credentials"
+	exit 1
 fi
 
 docker run --rm -it \
@@ -17,4 +17,4 @@ docker run --rm -it \
 	-v "$HOME/.git-credentials:/home/appuser/.git-credentials:ro" \
 	-e GEMINI_API_KEY="$GEMINI_API_KEY" \
 	-u "$(id -u):$(id -g)" \
-	docker.senomas.com/commit:1.0
+	docker.senomas.com/commit:1.0 $@

+ 63 - 24
git_commit_ai.py

@@ -2,9 +2,10 @@ import subprocess
 import os
 import google.generativeai as genai
 import re
+import argparse  # Add argparse import
 
 
-def get_staged_diff():
+def get_staged_diff(amend=False):
     """
     Retrieves the diff of staged files using git.
 
@@ -13,12 +14,29 @@ def get_staged_diff():
     """
     try:
         # Use subprocess.run for better control and error handling
-        process = subprocess.run(
-            ["git", "diff", "--staged"],  # Corrected: --staged is the correct option
-            capture_output=True,
-            text=True,  # Ensure output is returned as text
-            check=True,  # Raise an exception for non-zero exit codes
-        )
+        if amend:
+            process = subprocess.run(
+                [
+                    "git",
+                    "diff",
+                    "HEAD~1",
+                    "--staged",
+                ],  # Corrected: --staged is the correct option
+                capture_output=True,
+                text=True,  # Ensure output is returned as text
+                check=True,  # Raise an exception for non-zero exit codes
+            )
+        else:
+            process = subprocess.run(
+                [
+                    "git",
+                    "diff",
+                    "--staged",
+                ],  # Corrected: --staged is the correct option
+                capture_output=True,
+                text=True,  # Ensure output is returned as text
+                check=True,  # Raise an exception for non-zero exit codes
+            )
         return process.stdout
     except subprocess.CalledProcessError as e:
         print(f"Error getting staged diff: {e}")
@@ -139,7 +157,7 @@ def generate_commit_message(diff, gemini_api_key):
 
         # Use a conversation history for potential back-and-forth
         conversation = [formatted_prompt]
-        max_requests = 3  # Limit the number of file requests
+        max_requests = 5  # Limit the number of file requests
         requests_made = 0
 
         while requests_made < max_requests:
@@ -223,9 +241,9 @@ def generate_commit_message(diff, gemini_api_key):
         return None
 
 
-def create_commit(message):
+def create_commit(message, amend=False):  # Add amend parameter
     """
-    Creates a git commit with the given message.
+    Creates a git commit with the given message, optionally amending the previous commit.
 
     Args:
         message (str): The commit message.
@@ -234,12 +252,18 @@ def create_commit(message):
         bool: True if the commit was successful, False otherwise.
     """
     if not message:
-        print("Error: No commit message provided to create commit.")
+        print("Error: No commit message provided.")
         return False
 
     try:
+        # Build the command list
+        command = ["git", "commit"]
+        if amend:
+            command.append("--amend")
+        command.extend(["-m", message])
+
         process = subprocess.run(
-            ["git", "commit", "-m", message],
+            command,  # Use the dynamically built command
             check=True,  # Important: Raise exception on non-zero exit
             capture_output=True,  # capture the output
             text=True,
@@ -261,10 +285,24 @@ def create_commit(message):
 def main():
     """
     Main function to orchestrate the process of:
-    1. Getting the staged diff.
-    2. Generating a commit message using Gemini.
-    3. Creating a git commit with the generated message.
+    1. Parsing arguments (for --amend).
+    2. Getting the staged diff.
+    3. Generating a commit message using Gemini.
+    4. Creating or amending a git commit with the generated message.
     """
+    # --- Argument Parsing ---
+    parser = argparse.ArgumentParser(
+        description="Generate Git commit messages using AI."
+    )
+    parser.add_argument(
+        "-a",
+        "--amend",
+        action="store_true",
+        help="Amend the previous commit instead of creating a new one.",
+    )
+    args = parser.parse_args()
+    # --- ---
+
     gemini_api_key = os.environ.get("GEMINI_API_KEY")
     if not gemini_api_key:
         print(
@@ -274,7 +312,7 @@ def main():
         )
         return
 
-    diff = get_staged_diff()
+    diff = get_staged_diff(amend=args.amend)
     if diff is None:
         print("Aborting commit due to error getting diff.")
         return  # Exit the script
@@ -290,17 +328,18 @@ def main():
 
     print(f"Generated commit message:\n{message}")  # Print the message for review
 
-    # Prompt the user for confirmation before committing
-    user_input = input(
-        "Do you want to create the commit with this message? (y/n): "
-    ).lower()
+    # --- Confirmation ---
+    action = "amend the last commit" if args.amend else "create a new commit"
+    user_input = input(f"Do you want to {action} with this message? (y/n): ").lower()
+
     if user_input == "y":
-        if create_commit(message):
-            print("Commit created successfully.")
+        # Pass the amend flag to create_commit
+        if create_commit(message, amend=args.amend):
+            print(f"Commit {'amended' if args.amend else 'created'} successfully.")
         else:
-            print("Commit failed.")
+            print(f"Commit {'amendment' if args.amend else 'creation'} failed.")
     else:
-        print("Commit aborted by user.")
+        print(f"Commit {'amendment' if args.amend else 'creation'} aborted by user.")
 
 
 if __name__ == "__main__":