Skip to content

Commit 76bde95

Browse files
committed
use a rsa private key to sign jwt
if not given will auto-generate a HMAC key
1 parent 59bb2bf commit 76bde95

File tree

4 files changed

+62
-21
lines changed

4 files changed

+62
-21
lines changed

cmd/gaia/main.go

+31-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package main
22

33
import (
4+
"crypto/rand"
45
"flag"
56
"fmt"
67
"os"
78
"path/filepath"
89

10+
"io/ioutil"
11+
12+
"github.com/dgrijalva/jwt-go"
913
"github.com/gaia-pipeline/gaia"
1014
"github.com/gaia-pipeline/gaia/handlers"
1115
"github.com/gaia-pipeline/gaia/pipeline"
@@ -16,7 +20,8 @@ import (
1620
)
1721

1822
var (
19-
echoInstance *echo.Echo
23+
echoInstance *echo.Echo
24+
jwtPrivateKeyPath string
2025
)
2126

2227
const (
@@ -35,6 +40,7 @@ func init() {
3540
flag.StringVar(&gaia.Cfg.ListenPort, "port", "8080", "Listen port for gaia")
3641
flag.StringVar(&gaia.Cfg.HomePath, "homepath", "", "Path to the gaia home folder")
3742
flag.StringVar(&gaia.Cfg.Worker, "worker", "2", "Number of worker gaia will use to execute pipelines in parallel")
43+
flag.StringVar(&jwtPrivateKeyPath, "jwtPrivateKeyPath", "", "A RSA private key used to sign JWT tokens")
3844
flag.BoolVar(&gaia.Cfg.DevMode, "dev", false, "If true, gaia will be started in development mode. Don't use this in production!")
3945
flag.BoolVar(&gaia.Cfg.VersionSwitch, "version", false, "If true, will print the version and immediately exit")
4046

@@ -59,6 +65,30 @@ func main() {
5965
Name: "Gaia",
6066
})
6167

68+
var jwtKey interface{}
69+
// Check JWT key is set
70+
if jwtPrivateKeyPath == "" {
71+
gaia.Cfg.Logger.Warn("using auto-generated key to sign jwt tokens, do not use in production")
72+
jwtKey = make([]byte, 64)
73+
_, err := rand.Read(jwtKey.([]byte))
74+
if err != nil {
75+
gaia.Cfg.Logger.Error("error auto-generating jwt key", "error", err.Error())
76+
os.Exit(1)
77+
}
78+
} else {
79+
keyData, err := ioutil.ReadFile(jwtPrivateKeyPath)
80+
if err != nil {
81+
gaia.Cfg.Logger.Error("could not read jwt key file", "error", err.Error())
82+
os.Exit(1)
83+
}
84+
jwtKey, err = jwt.ParseRSAPrivateKeyFromPEM(keyData)
85+
if err != nil {
86+
gaia.Cfg.Logger.Error("could not parse jwt key file", "error", err.Error())
87+
os.Exit(1)
88+
}
89+
}
90+
gaia.Cfg.JWTKey = jwtKey
91+
6292
// Find path for gaia home folder if not given by parameter
6393
if gaia.Cfg.HomePath == "" {
6494
// Find executeable path

gaia.go

+1
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ type Config struct {
153153
PipelinePath string
154154
WorkspacePath string
155155
Worker string
156+
JWTKey interface{}
156157
Logger hclog.Logger
157158

158159
Bolt struct {

handlers/User.go

+14-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ import (
66

77
"github.com/labstack/echo"
88

9-
jwt "github.com/dgrijalva/jwt-go"
9+
"crypto/rsa"
10+
11+
"github.com/dgrijalva/jwt-go"
1012
"github.com/gaia-pipeline/gaia"
1113
)
1214

@@ -45,11 +47,20 @@ func UserLogin(c echo.Context) error {
4547
},
4648
}
4749

50+
var token *jwt.Token
4851
// Generate JWT token
49-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
52+
switch t := gaia.Cfg.JWTKey.(type) {
53+
case []byte:
54+
token = jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
55+
case *rsa.PrivateKey:
56+
token = jwt.NewWithClaims(jwt.SigningMethodRS512, claims)
57+
default:
58+
gaia.Cfg.Logger.Error("invalid jwt key type", "type", t)
59+
return c.String(http.StatusInternalServerError, "error creating jwt token: invalid jwt key type")
60+
}
5061

5162
// Sign and get encoded token
52-
tokenstring, err := token.SignedString(jwtKey)
63+
tokenstring, err := token.SignedString(gaia.Cfg.JWTKey)
5364
if err != nil {
5465
gaia.Cfg.Logger.Error("error signing jwt token", "error", err.Error())
5566
return c.String(http.StatusInternalServerError, err.Error())

handlers/handler.go

+16-17
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
package handlers
22

33
import (
4-
"crypto/rand"
54
"errors"
65
"fmt"
76
"net/http"
87
"strings"
98

109
"github.com/GeertJohan/go.rice"
1110

11+
"crypto/rsa"
12+
1213
jwt "github.com/dgrijalva/jwt-go"
1314
"github.com/gaia-pipeline/gaia"
1415
scheduler "github.com/gaia-pipeline/gaia/scheduler"
@@ -48,22 +49,12 @@ var storeService *store.Store
4849

4950
var schedulerService *scheduler.Scheduler
5051

51-
// jwtKey is a random generated key for jwt signing
52-
var jwtKey []byte
53-
5452
// InitHandlers initializes(registers) all handlers
5553
func InitHandlers(e *echo.Echo, store *store.Store, scheduler *scheduler.Scheduler) error {
5654
// Set instances
5755
storeService = store
5856
schedulerService = scheduler
5957

60-
// Generate signing key for jwt
61-
jwtKey = make([]byte, 64)
62-
_, err := rand.Read(jwtKey)
63-
if err != nil {
64-
return err
65-
}
66-
6758
// Define prefix
6859
p := "/api/" + apiVersion + "/"
6960

@@ -142,13 +133,21 @@ func authBarrier(next echo.HandlerFunc) echo.HandlerFunc {
142133

143134
// Parse token
144135
token, err := jwt.Parse(jwtString, func(token *jwt.Token) (interface{}, error) {
145-
// Validate signing method
146-
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
147-
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
136+
signingMethodError := fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
137+
switch token.Method.(type) {
138+
case *jwt.SigningMethodHMAC:
139+
if _, ok := gaia.Cfg.JWTKey.([]byte); !ok {
140+
return nil, signingMethodError
141+
}
142+
return gaia.Cfg.JWTKey, nil
143+
case *jwt.SigningMethodRSA:
144+
if _, ok := gaia.Cfg.JWTKey.(*rsa.PrivateKey); !ok {
145+
return nil, signingMethodError
146+
}
147+
return gaia.Cfg.JWTKey.(*rsa.PrivateKey).Public(), nil
148+
default:
149+
return nil, signingMethodError
148150
}
149-
150-
// return secret
151-
return jwtKey, nil
152151
})
153152
if err != nil {
154153
return c.String(http.StatusForbidden, err.Error())

0 commit comments

Comments
 (0)