Skip to content

Commit 7a2915a

Browse files
authored
Simplify tests with T.TempDir (sashabaranov#929)
1 parent 2a0ff5a commit 7a2915a

8 files changed

+24
-70
lines changed

.golangci.yml

+1
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ linters:
206206
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
207207
- unconvert # Remove unnecessary type conversions
208208
- unparam # Reports unused function parameters
209+
- usetesting # Reports uses of functions with replacement inside the testing package
209210
- wastedassign # wastedassign finds wasted assignment statements.
210211
- whitespace # Tool for detection of leading and trailing whitespace
211212
## you may want to enable

audio_api_test.go

+2-8
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,9 @@ func TestAudio(t *testing.T) {
4040

4141
ctx := context.Background()
4242

43-
dir, cleanup := test.CreateTestDirectory(t)
44-
defer cleanup()
45-
4643
for _, tc := range testcases {
4744
t.Run(tc.name, func(t *testing.T) {
48-
path := filepath.Join(dir, "fake.mp3")
45+
path := filepath.Join(t.TempDir(), "fake.mp3")
4946
test.CreateTestFile(t, path)
5047

5148
req := openai.AudioRequest{
@@ -90,12 +87,9 @@ func TestAudioWithOptionalArgs(t *testing.T) {
9087

9188
ctx := context.Background()
9289

93-
dir, cleanup := test.CreateTestDirectory(t)
94-
defer cleanup()
95-
9690
for _, tc := range testcases {
9791
t.Run(tc.name, func(t *testing.T) {
98-
path := filepath.Join(dir, "fake.mp3")
92+
path := filepath.Join(t.TempDir(), "fake.mp3")
9993
test.CreateTestFile(t, path)
10094

10195
req := openai.AudioRequest{

audio_test.go

+2-6
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ import (
1313
)
1414

1515
func TestAudioWithFailingFormBuilder(t *testing.T) {
16-
dir, cleanup := test.CreateTestDirectory(t)
17-
defer cleanup()
18-
path := filepath.Join(dir, "fake.mp3")
16+
path := filepath.Join(t.TempDir(), "fake.mp3")
1917
test.CreateTestFile(t, path)
2018

2119
req := AudioRequest{
@@ -63,9 +61,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
6361

6462
func TestCreateFileField(t *testing.T) {
6563
t.Run("createFileField failing file", func(t *testing.T) {
66-
dir, cleanup := test.CreateTestDirectory(t)
67-
defer cleanup()
68-
path := filepath.Join(dir, "fake.mp3")
64+
path := filepath.Join(t.TempDir(), "fake.mp3")
6965
test.CreateTestFile(t, path)
7066

7167
req := AudioRequest{

image_api_test.go

+13-29
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"net/http"
99
"os"
10+
"path/filepath"
1011
"testing"
1112
"time"
1213

@@ -86,24 +87,17 @@ func TestImageEdit(t *testing.T) {
8687
defer teardown()
8788
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
8889

89-
origin, err := os.Create("image.png")
90+
origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
9091
if err != nil {
91-
t.Error("open origin file error")
92-
return
92+
t.Fatalf("open origin file error: %v", err)
9393
}
94+
defer origin.Close()
9495

95-
mask, err := os.Create("mask.png")
96+
mask, err := os.Create(filepath.Join(t.TempDir(), "mask.png"))
9697
if err != nil {
97-
t.Error("open mask file error")
98-
return
98+
t.Fatalf("open mask file error: %v", err)
9999
}
100-
101-
defer func() {
102-
mask.Close()
103-
origin.Close()
104-
os.Remove("mask.png")
105-
os.Remove("image.png")
106-
}()
100+
defer mask.Close()
107101

108102
_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
109103
Image: origin,
@@ -121,16 +115,11 @@ func TestImageEditWithoutMask(t *testing.T) {
121115
defer teardown()
122116
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
123117

124-
origin, err := os.Create("image.png")
118+
origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
125119
if err != nil {
126-
t.Error("open origin file error")
127-
return
120+
t.Fatalf("open origin file error: %v", err)
128121
}
129-
130-
defer func() {
131-
origin.Close()
132-
os.Remove("image.png")
133-
}()
122+
defer origin.Close()
134123

135124
_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
136125
Image: origin,
@@ -178,16 +167,11 @@ func TestImageVariation(t *testing.T) {
178167
defer teardown()
179168
server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
180169

181-
origin, err := os.Create("image.png")
170+
origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
182171
if err != nil {
183-
t.Error("open origin file error")
184-
return
172+
t.Fatalf("open origin file error: %v", err)
185173
}
186-
187-
defer func() {
188-
origin.Close()
189-
os.Remove("image.png")
190-
}()
174+
defer origin.Close()
191175

192176
_, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{
193177
Image: origin,

internal/form_builder_test.go

+4-13
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package openai //nolint:testpackage // testing private field
22

33
import (
4-
"github.com/sashabaranov/go-openai/internal/test"
54
"github.com/sashabaranov/go-openai/internal/test/checks"
65

76
"bytes"
@@ -20,31 +19,23 @@ func (*failingWriter) Write([]byte) (int, error) {
2019
}
2120

2221
func TestFormBuilderWithFailingWriter(t *testing.T) {
23-
dir, cleanup := test.CreateTestDirectory(t)
24-
defer cleanup()
25-
26-
file, err := os.CreateTemp(dir, "")
22+
file, err := os.CreateTemp(t.TempDir(), "")
2723
if err != nil {
28-
t.Errorf("Error creating tmp file: %v", err)
24+
t.Fatalf("Error creating tmp file: %v", err)
2925
}
3026
defer file.Close()
31-
defer os.Remove(file.Name())
3227

3328
builder := NewFormBuilder(&failingWriter{})
3429
err = builder.CreateFormFile("file", file)
3530
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
3631
}
3732

3833
func TestFormBuilderWithClosedFile(t *testing.T) {
39-
dir, cleanup := test.CreateTestDirectory(t)
40-
defer cleanup()
41-
42-
file, err := os.CreateTemp(dir, "")
34+
file, err := os.CreateTemp(t.TempDir(), "")
4335
if err != nil {
44-
t.Errorf("Error creating tmp file: %v", err)
36+
t.Fatalf("Error creating tmp file: %v", err)
4537
}
4638
file.Close()
47-
defer os.Remove(file.Name())
4839

4940
body := &bytes.Buffer{}
5041
builder := NewFormBuilder(body)

internal/test/helpers.go

-10
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,6 @@ func CreateTestFile(t *testing.T, path string) {
1919
file.Close()
2020
}
2121

22-
// CreateTestDirectory creates a temporary folder which will be deleted when cleanup is called.
23-
func CreateTestDirectory(t *testing.T) (path string, cleanup func()) {
24-
t.Helper()
25-
26-
path, err := os.MkdirTemp(os.TempDir(), "")
27-
checks.NoError(t, err)
28-
29-
return path, func() { os.RemoveAll(path) }
30-
}
31-
3222
// TokenRoundTripper is a struct that implements the RoundTripper
3323
// interface, specifically to handle the authentication token by adding a token
3424
// to the request header. We need this because the API requires that each

openai_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
3131
// This function approximates based on the rule of thumb stated by OpenAI:
3232
// https://beta.openai.com/tokenizer
3333
//
34-
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available)
34+
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
3535
func numTokens(s string) int {
3636
return int(float32(len(s)) / 4)
3737
}

speech_test.go

+1-3
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@ func TestSpeechIntegration(t *testing.T) {
2121
defer teardown()
2222

2323
server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
24-
dir, cleanup := test.CreateTestDirectory(t)
25-
path := filepath.Join(dir, "fake.mp3")
24+
path := filepath.Join(t.TempDir(), "fake.mp3")
2625
test.CreateTestFile(t, path)
27-
defer cleanup()
2826

2927
// audio endpoints only accept POST requests
3028
if r.Method != "POST" {

0 commit comments

Comments
 (0)