Skip to content

Commit 659aa0f

Browse files
authoredOct 12, 2024··
Merge pull request #438 from numtide/feat/improve-paths-behaviour
Improve path argument handling
2 parents d25cd46 + 6bfe249 commit 659aa0f

File tree

13 files changed

+393
-212
lines changed

13 files changed

+393
-212
lines changed
 

‎.github/workflows/golangci-lint.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,5 @@ jobs:
2020
- name: golangci-lint
2121
uses: golangci/golangci-lint-action@v6
2222
with:
23-
version: v1.59.1
23+
version: v1.61.0
2424
args: --timeout=2m

‎cmd/format/format.go

+20-61
Original file line numberDiff line numberDiff line change
@@ -63,49 +63,6 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string)
6363
<-time.After(time.Until(startAfter))
6464
}
6565

66-
if cfg.Stdin {
67-
// check we have only received one path arg which we use for the file extension / matching to formatters
68-
if len(paths) != 1 {
69-
return fmt.Errorf("exactly one path should be specified when using the --stdin flag")
70-
}
71-
72-
// read stdin into a temporary file with the same file extension
73-
pattern := fmt.Sprintf("*%s", filepath.Ext(paths[0]))
74-
75-
file, err := os.CreateTemp("", pattern)
76-
if err != nil {
77-
return fmt.Errorf("failed to create a temporary file for processing stdin: %w", err)
78-
}
79-
80-
if _, err = io.Copy(file, os.Stdin); err != nil {
81-
return fmt.Errorf("failed to copy stdin into a temporary file")
82-
}
83-
84-
// set the tree root to match the temp directory
85-
cfg.TreeRoot, err = filepath.Abs(filepath.Dir(file.Name()))
86-
if err != nil {
87-
return fmt.Errorf("failed to get absolute path for tree root: %w", err)
88-
}
89-
90-
// configure filesystem walker to traverse the temporary tree root
91-
cfg.Walk = "filesystem"
92-
93-
// update paths with temp file
94-
paths[0], err = filepath.Rel(os.TempDir(), file.Name())
95-
if err != nil {
96-
return fmt.Errorf("failed to get relative path for temp file: %w", err)
97-
}
98-
99-
} else {
100-
// checks all paths are contained within the tree root
101-
for _, path := range paths {
102-
rootPath := filepath.Join(cfg.TreeRoot, path)
103-
if _, err = os.Stat(rootPath); err != nil {
104-
return fmt.Errorf("path %s not found within the tree root %s", path, cfg.TreeRoot)
105-
}
106-
}
107-
}
108-
10966
// cpu profiling
11067
if cfg.CpuProfile != "" {
11168
cpuProfile, err := os.Create(cfg.CpuProfile)
@@ -204,13 +161,29 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string)
204161
eg.Go(postProcessing(ctx, cfg, statz, formattedCh))
205162
eg.Go(applyFormatters(ctx, cfg, statz, globalExcludes, formatters, filesCh, formattedCh))
206163

207-
//
164+
// parse the walk type
208165
walkType, err := walk.TypeString(cfg.Walk)
209166
if err != nil {
210167
return fmt.Errorf("invalid walk type: %w", err)
211168
}
212169

213-
reader, err := walk.NewReader(walkType, cfg.TreeRoot, paths, db, statz)
170+
if walkType == walk.Stdin {
171+
// check we have only received one path arg which we use for the file extension / matching to formatters
172+
if len(paths) != 1 {
173+
return fmt.Errorf("exactly one path should be specified when using the --stdin flag")
174+
}
175+
} else {
176+
// checks all paths are contained within the tree root
177+
for _, path := range paths {
178+
rootPath := filepath.Join(cfg.TreeRoot, path)
179+
if _, err = os.Stat(rootPath); err != nil {
180+
return fmt.Errorf("path %s not found within the tree root %s", path, cfg.TreeRoot)
181+
}
182+
}
183+
}
184+
185+
// create a new reader for traversing the paths
186+
reader, err := walk.NewCompositeReader(walkType, cfg.TreeRoot, paths, db, statz)
214187
if err != nil {
215188
return fmt.Errorf("failed to create walker: %w", err)
216189
}
@@ -440,22 +413,8 @@ func postProcessing(
440413
file.Info = newInfo
441414
}
442415

443-
if file.Release != nil {
444-
file.Release()
445-
}
446-
447-
if cfg.Stdin {
448-
// dump file into stdout
449-
f, err := os.Open(file.Path)
450-
if err != nil {
451-
return fmt.Errorf("failed to open %s: %w", file.Path, err)
452-
}
453-
if _, err = io.Copy(os.Stdout, f); err != nil {
454-
return fmt.Errorf("failed to copy %s to stdout: %w", file.Path, err)
455-
}
456-
if err = os.Remove(f.Name()); err != nil {
457-
return fmt.Errorf("failed to remove temp file %s: %w", file.Path, err)
458-
}
416+
if err := file.Release(); err != nil {
417+
return fmt.Errorf("failed to release file: %w", err)
459418
}
460419
}
461420
}

‎cmd/root_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -628,11 +628,11 @@ func TestGitWorktree(t *testing.T) {
628628

629629
_, statz, err = treefmt(t, "-C", tempDir, "-c", "haskell", "foo.txt")
630630
as.NoError(err)
631-
assertStats(t, as, statz, 7, 7, 7, 0)
631+
assertStats(t, as, statz, 8, 8, 8, 0)
632632

633633
_, statz, err = treefmt(t, "-C", tempDir, "-c", "foo.txt")
634634
as.NoError(err)
635-
assertStats(t, as, statz, 0, 0, 0, 0)
635+
assertStats(t, as, statz, 1, 1, 1, 0)
636636
}
637637

638638
func TestPathsArg(t *testing.T) {

‎config/config.go

+7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"path/filepath"
77
"strings"
88

9+
"github.com/numtide/treefmt/walk"
10+
911
"github.com/spf13/pflag"
1012
"github.com/spf13/viper"
1113
)
@@ -175,6 +177,11 @@ func FromViper(v *viper.Viper) (*Config, error) {
175177
return nil, fmt.Errorf("failed to get absolute path for working directory: %w", err)
176178
}
177179

180+
// if the stdin flag was passed, we force the stdin walk type
181+
if cfg.Stdin {
182+
cfg.Walk = walk.Stdin.String()
183+
}
184+
178185
// determine the tree root
179186
if cfg.TreeRoot == "" {
180187
// if none was specified, we first try with tree-root-file

‎walk/cached.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,10 @@ func (c *CachedReader) Read(ctx context.Context, files []*File) (n int, err erro
9595
}
9696

9797
// set a release function which inserts this file into the release channel for updating
98-
file.Release = func() {
98+
file.AddReleaseFunc(func() error {
9999
c.releaseCh <- file
100-
}
100+
return nil
101+
})
101102
}
102103

103104
if errors.Is(err, io.EOF) {

‎walk/cached_test.go

+21-10
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ func TestCachedReader(t *testing.T) {
2222
batchSize := 1024
2323
tempDir := test.TempExamples(t)
2424

25-
readAll := func(paths []string) (totalCount, newCount, changeCount int, statz stats.Stats) {
25+
readAll := func(path string) (totalCount, newCount, changeCount int, statz stats.Stats) {
2626
statz = stats.New()
2727

2828
db, err := cache.Open(tempDir)
2929
as.NoError(err)
3030
defer db.Close()
3131

32-
delegate := walk.NewFilesystemReader(tempDir, paths, &statz, batchSize)
32+
delegate := walk.NewFilesystemReader(tempDir, path, &statz, batchSize)
3333
reader, err := walk.NewCachedReader(db, batchSize, delegate)
3434
as.NoError(err)
3535

@@ -49,7 +49,8 @@ func TestCachedReader(t *testing.T) {
4949
} else if file.Cache.HasChanged(file.Info) {
5050
changeCount++
5151
}
52-
file.Release()
52+
53+
as.NoError(file.Release())
5354
}
5455

5556
cancel()
@@ -64,13 +65,13 @@ func TestCachedReader(t *testing.T) {
6465
return totalCount, newCount, changeCount, statz
6566
}
6667

67-
totalCount, newCount, changeCount, _ := readAll([]string{"."})
68+
totalCount, newCount, changeCount, _ := readAll("")
6869
as.Equal(32, totalCount)
6970
as.Equal(32, newCount)
7071
as.Equal(0, changeCount)
7172

7273
// read again, should be no changes
73-
totalCount, newCount, changeCount, _ = readAll([]string{"."})
74+
totalCount, newCount, changeCount, _ = readAll("")
7475
as.Equal(32, totalCount)
7576
as.Equal(0, newCount)
7677
as.Equal(0, changeCount)
@@ -83,7 +84,7 @@ func TestCachedReader(t *testing.T) {
8384
as.NoError(os.Chtimes(filepath.Join(tempDir, "shell/foo.sh"), time.Now(), modTime))
8485
as.NoError(os.Chtimes(filepath.Join(tempDir, "haskell/Nested/Foo.hs"), time.Now(), modTime))
8586

86-
totalCount, newCount, changeCount, _ = readAll([]string{"."})
87+
totalCount, newCount, changeCount, _ = readAll("")
8788
as.Equal(32, totalCount)
8889
as.Equal(0, newCount)
8990
as.Equal(3, changeCount)
@@ -95,7 +96,7 @@ func TestCachedReader(t *testing.T) {
9596
_, err = os.Create(filepath.Join(tempDir, "fizz.go"))
9697
as.NoError(err)
9798

98-
totalCount, newCount, changeCount, _ = readAll([]string{"."})
99+
totalCount, newCount, changeCount, _ = readAll("")
99100
as.Equal(34, totalCount)
100101
as.Equal(2, newCount)
101102
as.Equal(0, changeCount)
@@ -113,14 +114,24 @@ func TestCachedReader(t *testing.T) {
113114
as.NoError(err)
114115
as.NoError(f.Close())
115116

116-
totalCount, newCount, changeCount, _ = readAll([]string{"."})
117+
totalCount, newCount, changeCount, _ = readAll("")
117118
as.Equal(34, totalCount)
118119
as.Equal(0, newCount)
119120
as.Equal(2, changeCount)
120121

121122
// read some paths within the root
122-
totalCount, newCount, changeCount, _ = readAll([]string{"go", "elm/src", "haskell"})
123-
as.Equal(10, totalCount)
123+
totalCount, newCount, changeCount, _ = readAll("go")
124+
as.Equal(2, totalCount)
125+
as.Equal(0, newCount)
126+
as.Equal(0, changeCount)
127+
128+
totalCount, newCount, changeCount, _ = readAll("elm/src")
129+
as.Equal(1, totalCount)
130+
as.Equal(0, newCount)
131+
as.Equal(0, changeCount)
132+
133+
totalCount, newCount, changeCount, _ = readAll("haskell")
134+
as.Equal(7, totalCount)
124135
as.Equal(0, newCount)
125136
as.Equal(0, changeCount)
126137
}

‎walk/filesystem.go

+15-29
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919
type FilesystemReader struct {
2020
log *log.Logger
2121
root string
22-
paths []string
22+
path string
2323
batchSize int
2424

2525
eg *errgroup.Group
@@ -35,7 +35,17 @@ func (f *FilesystemReader) process() error {
3535
close(f.filesCh)
3636
}()
3737

38-
walkFn := func(path string, info fs.FileInfo, err error) error {
38+
// f.path is relative to the root, so we create a fully qualified version
39+
// we also clean the path up in case there are any ../../ components etc.
40+
path := filepath.Clean(filepath.Join(f.root, f.path))
41+
42+
// ensure the path is within the root
43+
if !strings.HasPrefix(path, f.root) {
44+
return fmt.Errorf("path '%s' is outside of the root '%s'", path, f.root)
45+
}
46+
47+
// walk the path
48+
return filepath.Walk(path, func(path string, info fs.FileInfo, err error) error {
3949
// return errors immediately
4050
if err != nil {
4151
return err
@@ -64,26 +74,7 @@ func (f *FilesystemReader) process() error {
6474
f.log.Debugf("file queued %s", file.RelPath)
6575

6676
return nil
67-
}
68-
69-
// walk each path specified
70-
for idx := range f.paths {
71-
// f.paths are relative to the root, so we create a fully qualified version
72-
// we also clean the path up in case there are any ../../ components etc.
73-
path := filepath.Clean(filepath.Join(f.root, f.paths[idx]))
74-
75-
// ensure the path is within the root
76-
if !strings.HasPrefix(path, f.root) {
77-
return fmt.Errorf("path '%s' is outside of the root '%s'", path, f.root)
78-
}
79-
80-
// walk the path
81-
if err := filepath.Walk(path, walkFn); err != nil {
82-
return err
83-
}
84-
}
85-
86-
return nil
77+
})
8778
}
8879

8980
// Read populates the provided files array with as many files as are available until the provided context is cancelled.
@@ -129,22 +120,17 @@ func (f *FilesystemReader) Close() error {
129120
// and root.
130121
func NewFilesystemReader(
131122
root string,
132-
paths []string,
123+
path string,
133124
statz *stats.Stats,
134125
batchSize int,
135126
) *FilesystemReader {
136-
// if no paths are specified, we default to the root path
137-
if len(paths) == 0 {
138-
paths = []string{"."}
139-
}
140-
141127
// create an error group for managing the processing loop
142128
eg := errgroup.Group{}
143129

144130
r := FilesystemReader{
145131
log: log.WithPrefix("walk[filesystem]"),
146132
root: root,
147-
paths: paths,
133+
path: path,
148134
batchSize: batchSize,
149135

150136
eg: &eg,

‎walk/filesystem_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func TestFilesystemReader(t *testing.T) {
5656
tempDir := test.TempExamples(t)
5757
statz := stats.New()
5858

59-
r := walk.NewFilesystemReader(tempDir, nil, &statz, 1024)
59+
r := walk.NewFilesystemReader(tempDir, "", &statz, 1024)
6060

6161
count := 0
6262

‎walk/git.go

+75-82
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919

2020
type GitReader struct {
2121
root string
22-
paths []string
22+
path string
2323
stats *stats.Stats
2424
batchSize int
2525

@@ -45,103 +45,100 @@ func (g *GitReader) process() error {
4545
// git index into memory for faster lookups
4646
var idxCache *filetree
4747

48-
for pathIdx := range g.paths {
48+
path := filepath.Clean(filepath.Join(g.root, g.path))
49+
if !strings.HasPrefix(path, g.root) {
50+
return fmt.Errorf("path '%s' is outside of the root '%s'", path, g.root)
51+
}
4952

50-
path := filepath.Clean(filepath.Join(g.root, g.paths[pathIdx]))
51-
if !strings.HasPrefix(path, g.root) {
52-
return fmt.Errorf("path '%s' is outside of the root '%s'", path, g.root)
53-
}
53+
switch path {
5454

55-
switch path {
55+
case g.root:
5656

57-
case g.root:
57+
// we can just iterate the index entries
58+
for _, entry := range gitIndex.Entries {
5859

59-
// we can just iterate the index entries
60-
for _, entry := range gitIndex.Entries {
60+
// we only want regular files, not directories or symlinks
61+
if entry.Mode == filemode.Dir || entry.Mode == filemode.Symlink {
62+
continue
63+
}
6164

62-
// we only want regular files, not directories or symlinks
63-
if entry.Mode == filemode.Dir || entry.Mode == filemode.Symlink {
64-
continue
65-
}
65+
// stat the file
66+
path := filepath.Join(g.root, entry.Name)
6667

67-
// stat the file
68-
path := filepath.Join(g.root, entry.Name)
68+
info, err := os.Lstat(path)
69+
if os.IsNotExist(err) {
70+
// the underlying file might have been removed without the change being staged yet
71+
g.log.Warnf("Path %s is in the index but appears to have been removed from the filesystem", path)
72+
continue
73+
} else if err != nil {
74+
return fmt.Errorf("failed to stat %s: %w", path, err)
75+
}
6976

70-
info, err := os.Lstat(path)
71-
if os.IsNotExist(err) {
72-
// the underlying file might have been removed without the change being staged yet
73-
g.log.Warnf("Path %s is in the index but appears to have been removed from the filesystem", path)
74-
continue
75-
} else if err != nil {
76-
return fmt.Errorf("failed to stat %s: %w", path, err)
77-
}
77+
// determine a relative path
78+
relPath, err := filepath.Rel(g.root, path)
79+
if err != nil {
80+
return fmt.Errorf("failed to determine a relative path for %s: %w", path, err)
81+
}
7882

79-
// determine a relative path
80-
relPath, err := filepath.Rel(g.root, path)
81-
if err != nil {
82-
return fmt.Errorf("failed to determine a relative path for %s: %w", path, err)
83-
}
83+
file := File{
84+
Path: path,
85+
RelPath: relPath,
86+
Info: info,
87+
}
8488

85-
file := File{
86-
Path: path,
87-
RelPath: relPath,
88-
Info: info,
89-
}
89+
g.stats.Add(stats.Traversed, 1)
90+
g.filesCh <- &file
91+
}
9092

91-
g.stats.Add(stats.Traversed, 1)
92-
g.filesCh <- &file
93-
}
93+
default:
9494

95-
default:
95+
// read the git index into memory if it hasn't already
96+
if idxCache == nil {
97+
idxCache = &filetree{name: ""}
98+
idxCache.readIndex(gitIndex)
99+
}
96100

97-
// read the git index into memory if it hasn't already
98-
if idxCache == nil {
99-
idxCache = &filetree{name: ""}
100-
idxCache.readIndex(gitIndex)
101+
// git index entries are relative to the repository root, so we need to determine a relative path for the
102+
// one we are currently processing before checking if it exists within the git index
103+
relPath, err := filepath.Rel(g.root, path)
104+
if err != nil {
105+
return fmt.Errorf("failed to find root relative path for %v: %w", path, err)
106+
}
107+
108+
if !idxCache.hasPath(relPath) {
109+
log.Debugf("path %s not found in git index, skipping", relPath)
110+
return nil
111+
}
112+
113+
err = filepath.Walk(path, func(path string, info fs.FileInfo, _ error) error {
114+
// skip directories
115+
if info.IsDir() {
116+
return nil
101117
}
102118

103-
// git index entries are relative to the repository root, so we need to determine a relative path for the
104-
// one we are currently processing before checking if it exists within the git index
119+
// determine a path relative to g.root before checking presence in the git index
105120
relPath, err := filepath.Rel(g.root, path)
106121
if err != nil {
107-
return fmt.Errorf("failed to find root relative path for %v: %w", path, err)
122+
return fmt.Errorf("failed to determine a relative path for %s: %w", path, err)
108123
}
109124

110125
if !idxCache.hasPath(relPath) {
111-
log.Debugf("path %s not found in git index, skipping", relPath)
112-
continue
126+
log.Debugf("path %v not found in git index, skipping", relPath)
127+
return nil
113128
}
114129

115-
err = filepath.Walk(path, func(path string, info fs.FileInfo, _ error) error {
116-
// skip directories
117-
if info.IsDir() {
118-
return nil
119-
}
120-
121-
// determine a path relative to g.root before checking presence in the git index
122-
relPath, err := filepath.Rel(g.root, path)
123-
if err != nil {
124-
return fmt.Errorf("failed to determine a relative path for %s: %w", path, err)
125-
}
126-
127-
if !idxCache.hasPath(relPath) {
128-
log.Debugf("path %v not found in git index, skipping", relPath)
129-
return nil
130-
}
131-
132-
file := File{
133-
Path: path,
134-
RelPath: relPath,
135-
Info: info,
136-
}
137-
138-
g.stats.Add(stats.Traversed, 1)
139-
g.filesCh <- &file
140-
return nil
141-
})
142-
if err != nil {
143-
return fmt.Errorf("failed to walk %s: %w", path, err)
130+
file := File{
131+
Path: path,
132+
RelPath: relPath,
133+
Info: info,
144134
}
135+
136+
g.stats.Add(stats.Traversed, 1)
137+
g.filesCh <- &file
138+
return nil
139+
})
140+
if err != nil {
141+
return fmt.Errorf("failed to walk %s: %w", path, err)
145142
}
146143
}
147144

@@ -175,7 +172,7 @@ func (g *GitReader) Close() error {
175172

176173
func NewGitReader(
177174
root string,
178-
paths []string,
175+
path string,
179176
statz *stats.Stats,
180177
batchSize int,
181178
) (*GitReader, error) {
@@ -186,13 +183,9 @@ func NewGitReader(
186183

187184
eg := &errgroup.Group{}
188185

189-
if len(paths) == 0 {
190-
paths = []string{"."}
191-
}
192-
193186
r := &GitReader{
194187
root: root,
195-
paths: paths,
188+
path: path,
196189
stats: statz,
197190
batchSize: batchSize,
198191
log: log.WithPrefix("walk[git]"),

‎walk/git_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func TestGitReader(t *testing.T) {
4040

4141
statz := stats.New()
4242

43-
reader, err := walk.NewGitReader(tempDir, nil, &statz, 1024)
43+
reader, err := walk.NewGitReader(tempDir, "", &statz, 1024)
4444
as.NoError(err)
4545

4646
count := 0

‎walk/stdin.go

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package walk
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"io"
7+
"os"
8+
"path/filepath"
9+
10+
"github.com/numtide/treefmt/stats"
11+
)
12+
13+
type StdinReader struct {
14+
root string
15+
path string
16+
stats stats.Stats
17+
input *os.File
18+
19+
complete bool
20+
}
21+
22+
func (s StdinReader) Read(_ context.Context, files []*File) (n int, err error) {
23+
if s.complete {
24+
return 0, io.EOF
25+
}
26+
27+
// read stdin into a temporary file with the same file extension
28+
pattern := fmt.Sprintf("*%s", filepath.Ext(s.path))
29+
30+
file, err := os.CreateTemp(s.root, pattern)
31+
if err != nil {
32+
return 0, fmt.Errorf("failed to create a temporary file for processing stdin: %w", err)
33+
}
34+
defer file.Close()
35+
36+
if _, err = io.Copy(file, s.input); err != nil {
37+
return 0, fmt.Errorf("failed to copy stdin into a temporary file")
38+
}
39+
40+
info, err := file.Stat()
41+
if err != nil {
42+
return 0, fmt.Errorf("failed to get file info for temporary file: %w", err)
43+
}
44+
45+
relPath, err := filepath.Rel(s.root, file.Name())
46+
if err != nil {
47+
return 0, fmt.Errorf("failed to get relative path for temporary file: %w", err)
48+
}
49+
50+
files[0] = &File{
51+
Path: file.Name(),
52+
RelPath: relPath,
53+
Info: info,
54+
}
55+
56+
// dump the temp file to stdout and remove it once the file is finished being processed
57+
files[0].AddReleaseFunc(func() error {
58+
// open the temp file
59+
file, err := os.Open(file.Name())
60+
if err != nil {
61+
return fmt.Errorf("failed to open temp file %s: %w", file.Name(), err)
62+
}
63+
64+
// dump file into stdout
65+
if _, err = io.Copy(os.Stdout, file); err != nil {
66+
return fmt.Errorf("failed to copy %s to stdout: %w", file.Name(), err)
67+
}
68+
69+
if err = file.Close(); err != nil {
70+
return fmt.Errorf("failed to close temp file %s: %w", file.Name(), err)
71+
}
72+
73+
if err = os.Remove(file.Name()); err != nil {
74+
return fmt.Errorf("failed to remove temp file %s: %w", file.Name(), err)
75+
}
76+
77+
return nil
78+
})
79+
80+
s.complete = true
81+
s.stats.Add(stats.Traversed, 1)
82+
83+
return 1, io.EOF
84+
}
85+
86+
func (s StdinReader) Close() error {
87+
return nil
88+
}
89+
90+
func NewStdinReader(root string, path string, statz *stats.Stats) Reader {
91+
return StdinReader{
92+
root: root,
93+
path: path,
94+
stats: *statz,
95+
input: os.Stdin,
96+
}
97+
}

‎walk/type_enum.go

+14-10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎walk/walk.go

+136-13
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@ package walk
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
7+
"io"
68
"io/fs"
79
"os"
10+
"path/filepath"
811
"time"
912

1013
"github.com/numtide/treefmt/stats"
@@ -19,12 +22,14 @@ const (
1922
Auto Type = iota
2023
Git
2124
Filesystem
25+
Stdin
2226

2327
BatchSize = 1024
2428
)
2529

26-
// File represents a file object with its path, relative path, file info, and potential cached entry.
27-
// It provides an optional release function to trigger a cache update after processing.
30+
type ReleaseFunc func() error
31+
32+
// File represents a file object with its path, relative path, file info, and potential cache entry.
2833
type File struct {
2934
Path string
3035
RelPath string
@@ -33,14 +38,28 @@ type File struct {
3338
// Cache is the latest entry found for this file, if one exists.
3439
Cache *cache.Entry
3540

36-
// An optional function to be invoked when this File has finished processing.
37-
// Typically used to trigger a cache update.
38-
Release func()
41+
releaseFuncs []ReleaseFunc
42+
}
43+
44+
// Release invokes all registered release functions for the File.
45+
// If any release function returns an error, Release stops and returns that error.
46+
func (f *File) Release() error {
47+
for _, fn := range f.releaseFuncs {
48+
if err := fn(); err != nil {
49+
return err
50+
}
51+
}
52+
return nil
53+
}
54+
55+
// AddReleaseFunc adds a release function to the File's list of release functions.
56+
func (f *File) AddReleaseFunc(fn ReleaseFunc) {
57+
f.releaseFuncs = append(f.releaseFuncs, fn)
3958
}
4059

4160
// Stat checks if the file has changed by comparing its current state (size, mod time) to when it was first read.
4261
// It returns a boolean indicating if the file has changed, the current file info, and an error if any.
43-
func (f File) Stat() (bool, fs.FileInfo, error) {
62+
func (f *File) Stat() (changed bool, info fs.FileInfo, err error) {
4463
// Get the file's current state
4564
current, err := os.Stat(f.Path)
4665
if err != nil {
@@ -64,7 +83,7 @@ func (f File) Stat() (bool, fs.FileInfo, error) {
6483
}
6584

6685
// String returns the file's path as a string.
67-
func (f File) String() string {
86+
func (f *File) String() string {
6887
return f.Path
6988
}
7089

@@ -74,11 +93,56 @@ type Reader interface {
7493
Close() error
7594
}
7695

77-
// NewReader creates a new instance of Reader based on the given walkType (Auto, Git, Filesystem).
96+
// CompositeReader combines multiple Readers into one.
97+
// It iterates over the given readers, reading each until completion.
98+
type CompositeReader struct {
99+
idx int
100+
current Reader
101+
readers []Reader
102+
}
103+
104+
func (c *CompositeReader) Read(ctx context.Context, files []*File) (n int, err error) {
105+
if c.current == nil {
106+
// check if we have exhausted all the readers
107+
if c.idx >= len(c.readers) {
108+
return 0, io.EOF
109+
}
110+
111+
// if not, select the next reader
112+
c.current = c.readers[c.idx]
113+
c.idx++
114+
}
115+
116+
// attempt a read
117+
n, err = c.current.Read(ctx, files)
118+
119+
// check if the current reader has been exhausted
120+
if errors.Is(err, io.EOF) {
121+
// reset the error if it's EOF
122+
err = nil
123+
// set the current reader to nil so we try to read from the next reader on the next call
124+
c.current = nil
125+
} else if err != nil {
126+
err = fmt.Errorf("failed to read from current reader: %w", err)
127+
}
128+
129+
// return the number of files read in this call and any error
130+
return n, err
131+
}
132+
133+
func (c *CompositeReader) Close() error {
134+
for _, reader := range c.readers {
135+
if err := reader.Close(); err != nil {
136+
return fmt.Errorf("failed to close reader: %w", err)
137+
}
138+
}
139+
return nil
140+
}
141+
78142
func NewReader(
79143
walkType Type,
80144
root string,
81-
paths []string,
145+
path string,
82146
db *bolt.DB,
83147
statz *stats.Stats,
84148
) (Reader, error) {
@@ -90,15 +154,17 @@ func NewReader(
90154
switch walkType {
91155
case Auto:
92156
// for now, we keep it simple and try git first, filesystem second
93-
reader, err = NewReader(Git, root, paths, db, statz)
157+
reader, err = NewReader(Git, root, path, db, statz)
94158
if err != nil {
95-
reader, err = NewReader(Filesystem, root, paths, db, statz)
159+
reader, err = NewReader(Filesystem, root, path, db, statz)
96160
}
97161
return reader, err
98162
case Git:
99-
reader, err = NewGitReader(root, paths, statz, BatchSize)
163+
reader, err = NewGitReader(root, path, statz, BatchSize)
100164
case Filesystem:
101-
reader = NewFilesystemReader(root, paths, statz, BatchSize)
165+
reader = NewFilesystemReader(root, path, statz, BatchSize)
166+
case Stdin:
167+
return nil, fmt.Errorf("stdin walk type is not supported")
102168
default:
103169
return nil, fmt.Errorf("unknown walk type: %v", walkType)
104170
}
@@ -115,3 +181,60 @@ func NewReader(
115181

116182
return reader, err
117183
}
184+
185+
func NewCompositeReader(
186+
walkType Type,
187+
root string,
188+
paths []string,
189+
db *bolt.DB,
190+
statz *stats.Stats,
191+
) (Reader, error) {
192+
// if not paths are provided we default to processing the tree root
193+
if len(paths) == 0 {
194+
return NewReader(walkType, root, "", db, statz)
195+
}
196+
197+
readers := make([]Reader, len(paths))
198+
199+
// check we have received 1 path for the stdin walk type
200+
if walkType == Stdin {
201+
if len(paths) != 1 {
202+
return nil, fmt.Errorf("stdin walk requires exactly one path")
203+
}
204+
205+
return NewStdinReader(root, paths[0], statz), nil
206+
}
207+
208+
// create a reader for each provided path
209+
for idx, relPath := range paths {
210+
var (
211+
err error
212+
info os.FileInfo
213+
)
214+
215+
// create a clean absolute path
216+
path := filepath.Clean(filepath.Join(root, relPath))
217+
218+
// check the path exists
219+
info, err = os.Lstat(path)
220+
if err != nil {
221+
return nil, fmt.Errorf("failed to stat %s: %w", path, err)
222+
}
223+
224+
if info.IsDir() {
225+
// for directories, we honour the walk type as we traverse them
226+
readers[idx], err = NewReader(walkType, root, relPath, db, statz)
227+
} else {
228+
// for files, we enforce a simple filesystem read
229+
readers[idx], err = NewReader(Filesystem, root, relPath, db, statz)
230+
}
231+
232+
if err != nil {
233+
return nil, fmt.Errorf("failed to create reader for %s: %w", relPath, err)
234+
}
235+
}
236+
237+
return &CompositeReader{
238+
readers: readers,
239+
}, nil
240+
}

0 commit comments

Comments
 (0)
Please sign in to comment.