Skip to content

Commit 5535c96

Browse files
committedOct 27, 2024
Ensure diff mode works
1 parent 3b06a85 commit 5535c96

File tree

4 files changed

+100
-41
lines changed

4 files changed

+100
-41
lines changed
 

‎src/copychat/cli.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pyperclip
66
from enum import Enum
77

8-
from .core import scan_directory, DiffMode
8+
from .core import scan_directory, DiffMode, get_file_content
99
from .format import (
1010
estimate_tokens,
1111
format_files as format_files_xml,
@@ -38,6 +38,8 @@ def parse_source(source: str) -> tuple[SourceType, str]:
3838
def diff_mode_callback(value: str) -> DiffMode:
3939
"""Convert string value to DiffMode enum."""
4040
try:
41+
if isinstance(value, DiffMode):
42+
return value
4143
return DiffMode(value)
4244
except ValueError:
4345
valid_values = [mode.value for mode in DiffMode]
@@ -94,8 +96,8 @@ def main(
9496
"-x",
9597
help="Glob patterns to exclude",
9698
),
97-
diff_mode: DiffMode = typer.Option(
98-
DiffMode.FULL.value,
99+
diff_mode: str = typer.Option(
100+
"full", # Pass the string value instead of enum
99101
"--diff-mode",
100102
"-d",
101103
help="How to handle git diffs",
@@ -127,7 +129,8 @@ def main(
127129

128130
# Handle file vs directory source
129131
if source_dir.is_file():
130-
all_files = {source_dir: None} # Use None as placeholder for git info
132+
content = get_file_content(source_dir, diff_mode)
133+
all_files = {source_dir: content} if content is not None else {}
131134
else:
132135
# For directories, scan all paths
133136
if not paths:
@@ -138,11 +141,13 @@ def main(
138141
for path in paths:
139142
target = source_dir / path if source_dir != Path(".") else Path(path)
140143
if target.is_file():
141-
all_files[target] = None
144+
content = get_file_content(target, diff_mode)
145+
if content is not None:
146+
all_files[target] = content
142147
else:
143148
files = scan_directory(
144149
target,
145-
include=include,
150+
include=include.split(",") if include else None,
146151
exclude_patterns=exclude,
147152
diff_mode=diff_mode,
148153
)
@@ -152,8 +157,10 @@ def main(
152157
error_console.print("[yellow]No matching files found[/]")
153158
raise typer.Exit(1)
154159

155-
# Format files
156-
result = format_files_xml(list(all_files.keys()))
160+
# Format files - pass both paths and content
161+
result = format_files_xml(
162+
[(path, content) for path, content in all_files.items()]
163+
)
157164

158165
# Handle outputs
159166
if outfile:

‎src/copychat/core.py

+33-15
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,26 @@ def get_gitignore_spec(
5454
def get_git_diff(path: Path) -> str:
5555
"""Get git diff for the given path."""
5656
try:
57+
# First check if file is tracked by git
5758
result = subprocess.run(
58-
["git", "diff", str(path)],
59+
["git", "ls-files", "--error-unmatch", str(path)],
5960
capture_output=True,
6061
text=True,
61-
check=True,
62+
check=False, # Don't raise error for untracked files
6263
)
63-
return result.stdout
64+
if result.returncode != 0:
65+
return "" # File is not tracked by git
66+
67+
# Get the diff for tracked files
68+
result = subprocess.run(
69+
["git", "diff", "--exit-code", str(path)],
70+
capture_output=True,
71+
text=True,
72+
check=False, # Don't raise error for no changes
73+
)
74+
# exit-code 0 means no changes, 1 means changes present
75+
return result.stdout if result.returncode == 1 else ""
76+
6477
except subprocess.CalledProcessError:
6578
return ""
6679

@@ -70,32 +83,36 @@ def get_file_content(path: Path, diff_mode: DiffMode) -> Optional[str]:
7083
if not path.is_file():
7184
return None
7285

73-
if diff_mode == DiffMode.FULL:
74-
return path.read_text()
75-
86+
# Get content and diff
87+
content = path.read_text()
7688
diff = get_git_diff(path)
77-
if not diff and diff_mode in (DiffMode.CHANGED_WITH_DIFF, DiffMode.DIFF_ONLY):
78-
return None
7989

80-
if diff_mode == DiffMode.DIFF_ONLY:
81-
return diff
82-
83-
content = path.read_text()
84-
return f"{content}\n\n# Git Diff:\n{diff}" if diff else content
90+
# Handle different modes
91+
if diff_mode == DiffMode.FULL:
92+
return content
93+
elif diff_mode == DiffMode.FULL_WITH_DIFF:
94+
return f"{content}\n\n# Git Diff:\n{diff}" if diff else content
95+
elif diff_mode == DiffMode.CHANGED_WITH_DIFF:
96+
return f"{content}\n\n# Git Diff:\n{diff}" if diff else None
97+
elif diff_mode == DiffMode.DIFF_ONLY:
98+
return diff if diff else None
99+
else:
100+
return None # Shouldn't reach here, but makes mypy happy
85101

86102

87103
def scan_directory(
88104
path: Path,
89105
include: Optional[list[str]] = None,
90106
exclude_patterns: Optional[list[str]] = None,
91107
diff_mode: DiffMode = DiffMode.FULL,
92-
) -> dict[Path, Optional[str]]:
108+
) -> dict[Path, str]:
93109
"""Scan directory for files to process."""
94110
if path.is_file():
95111
# For single files, just check if it matches filters
96112
if include and path.suffix.lstrip(".") not in include:
97113
return {}
98-
return {path: None}
114+
content = get_file_content(path, diff_mode)
115+
return {path: content} if content is not None else {}
99116

100117
# Convert to absolute path first
101118
abs_path = path.absolute()
@@ -131,6 +148,7 @@ def scan_directory(
131148
if ext not in include_set:
132149
continue
133150

151+
# Get content based on diff mode
134152
content = get_file_content(file_path, diff_mode)
135153
if content is not None:
136154
result[file_path] = content

‎src/copychat/format.py

+17-18
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,13 @@ def guess_language(file_path: Path) -> Optional[str]:
4646
return language_map.get(ext)
4747

4848

49-
def format_file(
50-
file_path: Path, root_path: Path, add_line_numbers: bool = False
51-
) -> str:
52-
"""Format a single file as XML-style markdown with line numbers."""
49+
def format_file(file_path: Path, root_path: Path, content: Optional[str] = None) -> str:
50+
"""Format a single file as XML-style markdown."""
5351
try:
54-
content = file_path.read_text()
52+
# Use provided content or read from file
53+
if content is None:
54+
content = file_path.read_text()
55+
5556
# Use string paths for comparison to handle symlinks and different path formats
5657
file_str = str(file_path.absolute())
5758
root_str = str(root_path)
@@ -67,13 +68,6 @@ def format_file(
6768

6869
attrs_str = " ".join(tag_attrs)
6970

70-
# Add line numbers to content
71-
if add_line_numbers:
72-
numbered_lines = []
73-
for i, line in enumerate(content.splitlines(), 1):
74-
numbered_lines.append(f"{i}| {line}")
75-
content = "\n".join(numbered_lines)
76-
7771
return f"""<file {attrs_str}>
7872
{content}
7973
</file>"""
@@ -110,21 +104,26 @@ def estimate_tokens(text: str) -> int:
110104
return len(text) // 4 # Rough estimate: ~4 chars per token
111105

112106

113-
def format_files(files: list[Path]) -> str:
114-
"""Format files into markdown with XML-style tags."""
107+
def format_files(files: list[tuple[Path, str]]) -> str:
108+
"""Format files into markdown with XML-style tags.
109+
110+
Args:
111+
files: List of (path, content) tuples to format
112+
"""
115113
if not files:
116114
return "<!-- No files found matching criteria -->\n"
117115

118116
# Find common root path using os.path.commonpath
119-
str_paths = [str(f.absolute()) for f in files]
117+
paths = [f[0] for f in files]
118+
str_paths = [str(f.absolute()) for f in paths]
120119
root_path = Path(commonpath(str_paths))
121120

122121
# Create header
123-
result = [create_header(files, root_path)]
122+
result = [create_header(paths, root_path)]
124123

125124
# Format each file
126-
for file_path in files:
127-
result.append(format_file(file_path, root_path))
125+
for file_path, content in files:
126+
result.append(format_file(file_path, root_path, content))
128127

129128
final_result = "\n".join(result)
130129
char_count = len(final_result)

‎uv.lock

+35
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)
Please sign in to comment.