Skip to content

Commit 8c604b0

Browse files
committed
refactor
1 parent f3eb835 commit 8c604b0

File tree

9 files changed

+57
-46
lines changed

9 files changed

+57
-46
lines changed

cmd/generate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func runGenerateInternalToken(c *cli.Context) error {
7070
}
7171

7272
func runGenerateLfsJwtSecret(c *cli.Context) error {
73-
_, jwtSecretBase64, err := generate.NewJwtSecretBase64()
73+
_, jwtSecretBase64, err := generate.NewJwtSecretWithBase64()
7474
if err != nil {
7575
return err
7676
}

modules/generate/generate.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package generate
77
import (
88
"crypto/rand"
99
"encoding/base64"
10+
"fmt"
1011
"io"
1112
"time"
1213

@@ -38,19 +39,24 @@ func NewInternalToken() (string, error) {
3839
return internalToken, nil
3940
}
4041

41-
// NewJwtSecret generates a new value intended to be used for JWT secrets.
42-
func NewJwtSecret() ([]byte, error) {
43-
bytes := make([]byte, 32)
44-
_, err := io.ReadFull(rand.Reader, bytes)
45-
if err != nil {
42+
const defaultJwtSecretLen = 32
43+
44+
// DecodeJwtSecretBase64 decodes a base64 encoded jwt secret into bytes, and check its length
45+
func DecodeJwtSecretBase64(src string) ([]byte, error) {
46+
encoding := base64.RawURLEncoding
47+
decoded := make([]byte, encoding.DecodedLen(len(src))+3)
48+
if n, err := encoding.Decode(decoded, []byte(src)); err != nil {
4649
return nil, err
50+
} else if n != defaultJwtSecretLen {
51+
return nil, fmt.Errorf("invalid base64 decoded length: %d, expects: %d", n, defaultJwtSecretLen)
4752
}
48-
return bytes, nil
53+
return decoded[:defaultJwtSecretLen], nil
4954
}
5055

51-
// NewJwtSecretBase64 generates a new base64 encoded value intended to be used for JWT secrets.
52-
func NewJwtSecretBase64() ([]byte, string, error) {
53-
bytes, err := NewJwtSecret()
56+
// NewJwtSecretWithBase64 generates a jwt secret with its base64 encoded value intended to be used for saving into config file
57+
func NewJwtSecretWithBase64() ([]byte, string, error) {
58+
bytes := make([]byte, defaultJwtSecretLen)
59+
_, err := io.ReadFull(rand.Reader, bytes)
5460
if err != nil {
5561
return nil, "", err
5662
}

modules/generate/generate_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2024 The Gitea Authors. All rights reserved.
2+
// SPDX-License-Identifier: MIT
3+
4+
package generate
5+
6+
import (
7+
"encoding/base64"
8+
"strings"
9+
"testing"
10+
11+
"github.com/stretchr/testify/assert"
12+
)
13+
14+
func TestDecodeJwtSecretBase64(t *testing.T) {
15+
_, err := DecodeJwtSecretBase64("abcd")
16+
assert.ErrorContains(t, err, "invalid base64 decoded length")
17+
_, err = DecodeJwtSecretBase64(strings.Repeat("a", 64))
18+
assert.ErrorContains(t, err, "invalid base64 decoded length")
19+
20+
str32 := strings.Repeat("x", 32)
21+
encoded32 := base64.RawURLEncoding.EncodeToString([]byte(str32))
22+
decoded32, err := DecodeJwtSecretBase64(encoded32)
23+
assert.NoError(t, err)
24+
assert.Equal(t, str32, string(decoded32))
25+
}
26+
27+
func TestNewJwtSecretWithBase64(t *testing.T) {
28+
secret, encoded, err := NewJwtSecretWithBase64()
29+
assert.NoError(t, err)
30+
assert.Len(t, secret, 32)
31+
decoded, err := DecodeJwtSecretBase64(encoded)
32+
assert.NoError(t, err)
33+
assert.Equal(t, secret, decoded)
34+
}

modules/setting/lfs.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
package setting
55

66
import (
7-
"encoding/base64"
87
"fmt"
98
"time"
109

1110
"code.gitea.io/gitea/modules/generate"
12-
"code.gitea.io/gitea/modules/util"
1311
)
1412

1513
// LFS represents the configuration for Git LFS
@@ -62,9 +60,9 @@ func loadLFSFrom(rootCfg ConfigProvider) error {
6260
}
6361

6462
LFS.JWTSecretBase64 = loadSecret(rootCfg.Section("server"), "LFS_JWT_SECRET_URI", "LFS_JWT_SECRET")
65-
LFS.JWTSecretBytes, err = util.Base64FixedDecode(base64.RawURLEncoding, []byte(LFS.JWTSecretBase64), 32)
63+
LFS.JWTSecretBytes, err = generate.DecodeJwtSecretBase64(LFS.JWTSecretBase64)
6664
if err != nil {
67-
LFS.JWTSecretBytes, LFS.JWTSecretBase64, err = generate.NewJwtSecretBase64()
65+
LFS.JWTSecretBytes, LFS.JWTSecretBase64, err = generate.NewJwtSecretWithBase64()
6866
if err != nil {
6967
return fmt.Errorf("error generating JWT Secret for custom config: %v", err)
7068
}

modules/setting/oauth2.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44
package setting
55

66
import (
7-
"encoding/base64"
87
"math"
98
"path/filepath"
109

1110
"code.gitea.io/gitea/modules/generate"
1211
"code.gitea.io/gitea/modules/log"
13-
"code.gitea.io/gitea/modules/util"
1412
)
1513

1614
// OAuth2UsernameType is enum describing the way gitea 'name' should be generated from oauth2 data
@@ -137,13 +135,12 @@ func loadOAuth2From(rootCfg ConfigProvider) {
137135
}
138136

139137
if InstallLock {
140-
if _, err := util.Base64FixedDecode(base64.RawURLEncoding, []byte(OAuth2.JWTSecretBase64), 32); err != nil {
141-
key, err := generate.NewJwtSecret()
138+
if _, err := generate.DecodeJwtSecretBase64(OAuth2.JWTSecretBase64); err != nil {
139+
_, OAuth2.JWTSecretBase64, err = generate.NewJwtSecretWithBase64()
142140
if err != nil {
143141
log.Fatal("error generating JWT secret: %v", err)
144142
}
145143

146-
OAuth2.JWTSecretBase64 = base64.RawURLEncoding.EncodeToString(key)
147144
saveCfg, err := rootCfg.PrepareSaving()
148145
if err != nil {
149146
log.Fatal("save oauth2.JWT_SECRET failed: %v", err)

modules/util/util.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ package util
66
import (
77
"bytes"
88
"crypto/rand"
9-
"encoding/base64"
109
"fmt"
1110
"math/big"
1211
"strconv"
@@ -246,13 +245,3 @@ func ToFloat64(number any) (float64, error) {
246245
func ToPointer[T any](val T) *T {
247246
return &val
248247
}
249-
250-
func Base64FixedDecode(encoding *base64.Encoding, src []byte, length int) ([]byte, error) {
251-
decoded := make([]byte, encoding.DecodedLen(len(src))+3)
252-
if n, err := encoding.Decode(decoded, src); err != nil {
253-
return nil, err
254-
} else if n != length {
255-
return nil, fmt.Errorf("invalid base64 decoded length: %d, expects: %d", n, length)
256-
}
257-
return decoded[:length], nil
258-
}

modules/util/util_test.go

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
package util
55

66
import (
7-
"encoding/base64"
87
"regexp"
98
"strings"
109
"testing"
@@ -234,16 +233,3 @@ func TestToPointer(t *testing.T) {
234233
val123 := 123
235234
assert.False(t, &val123 == ToPointer(val123))
236235
}
237-
238-
func TestBase64FixedDecode(t *testing.T) {
239-
_, err := Base64FixedDecode(base64.RawURLEncoding, []byte("abcd"), 32)
240-
assert.ErrorContains(t, err, "invalid base64 decoded length")
241-
_, err = Base64FixedDecode(base64.RawURLEncoding, []byte(strings.Repeat("a", 64)), 32)
242-
assert.ErrorContains(t, err, "invalid base64 decoded length")
243-
244-
str32 := strings.Repeat("x", 32)
245-
encoded32 := base64.RawURLEncoding.EncodeToString([]byte(str32))
246-
decoded32, err := Base64FixedDecode(base64.RawURLEncoding, []byte(encoded32), 32)
247-
assert.NoError(t, err)
248-
assert.Equal(t, str32, string(decoded32))
249-
}

routers/install/install.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ func SubmitInstall(ctx *context.Context) {
409409
cfg.Section("server").Key("LFS_START_SERVER").SetValue("true")
410410
cfg.Section("lfs").Key("PATH").SetValue(form.LFSRootPath)
411411
var lfsJwtSecret string
412-
if _, lfsJwtSecret, err = generate.NewJwtSecretBase64(); err != nil {
412+
if _, lfsJwtSecret, err = generate.NewJwtSecretWithBase64(); err != nil {
413413
ctx.RenderWithErr(ctx.Tr("install.lfs_jwt_secret_failed", err), tplInstall, &form)
414414
return
415415
}

services/auth/source/oauth2/jwtsigningkey.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"path/filepath"
1919
"strings"
2020

21+
"code.gitea.io/gitea/modules/generate"
2122
"code.gitea.io/gitea/modules/log"
2223
"code.gitea.io/gitea/modules/setting"
2324
"code.gitea.io/gitea/modules/util"
@@ -336,7 +337,7 @@ func InitSigningKey() error {
336337
// loadSymmetricKey checks if the configured secret is valid.
337338
// If it is not valid, it will return an error.
338339
func loadSymmetricKey() (any, error) {
339-
return util.Base64FixedDecode(base64.RawURLEncoding, []byte(setting.OAuth2.JWTSecretBase64), 32)
340+
return generate.DecodeJwtSecretBase64(setting.OAuth2.JWTSecretBase64)
340341
}
341342

342343
// loadOrCreateAsymmetricKey checks if the configured private key exists.

0 commit comments

Comments
 (0)