Skip to content

Commit 09ad3d3

Browse files
authored
Merge pull request #27 from rmb938/jwtflag
use rsa key for jwt
2 parents 59bb2bf + 9cfb17a commit 09ad3d3

File tree

6 files changed

+428
-29
lines changed

6 files changed

+428
-29
lines changed

cmd/gaia/main.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package main
22

33
import (
4+
"crypto/rand"
45
"flag"
56
"fmt"
67
"os"
78
"path/filepath"
89

10+
"io/ioutil"
11+
12+
"github.com/dgrijalva/jwt-go"
913
"github.com/gaia-pipeline/gaia"
1014
"github.com/gaia-pipeline/gaia/handlers"
1115
"github.com/gaia-pipeline/gaia/pipeline"
@@ -35,6 +39,7 @@ func init() {
3539
flag.StringVar(&gaia.Cfg.ListenPort, "port", "8080", "Listen port for gaia")
3640
flag.StringVar(&gaia.Cfg.HomePath, "homepath", "", "Path to the gaia home folder")
3741
flag.StringVar(&gaia.Cfg.Worker, "worker", "2", "Number of worker gaia will use to execute pipelines in parallel")
42+
flag.StringVar(&gaia.Cfg.JwtPrivateKeyPath, "jwtPrivateKeyPath", "", "A RSA private key used to sign JWT tokens")
3843
flag.BoolVar(&gaia.Cfg.DevMode, "dev", false, "If true, gaia will be started in development mode. Don't use this in production!")
3944
flag.BoolVar(&gaia.Cfg.VersionSwitch, "version", false, "If true, will print the version and immediately exit")
4045

@@ -59,6 +64,30 @@ func main() {
5964
Name: "Gaia",
6065
})
6166

67+
var jwtKey interface{}
68+
// Check JWT key is set
69+
if gaia.Cfg.JwtPrivateKeyPath == "" {
70+
gaia.Cfg.Logger.Warn("using auto-generated key to sign jwt tokens, do not use in production")
71+
jwtKey = make([]byte, 64)
72+
_, err := rand.Read(jwtKey.([]byte))
73+
if err != nil {
74+
gaia.Cfg.Logger.Error("error auto-generating jwt key", "error", err.Error())
75+
os.Exit(1)
76+
}
77+
} else {
78+
keyData, err := ioutil.ReadFile(gaia.Cfg.JwtPrivateKeyPath)
79+
if err != nil {
80+
gaia.Cfg.Logger.Error("could not read jwt key file", "error", err.Error())
81+
os.Exit(1)
82+
}
83+
jwtKey, err = jwt.ParseRSAPrivateKeyFromPEM(keyData)
84+
if err != nil {
85+
gaia.Cfg.Logger.Error("could not parse jwt key file", "error", err.Error())
86+
os.Exit(1)
87+
}
88+
}
89+
gaia.Cfg.JWTKey = jwtKey
90+
6291
// Find path for gaia home folder if not given by parameter
6392
if gaia.Cfg.HomePath == "" {
6493
// Find executeable path

gaia.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,17 @@ var Cfg *Config
145145

146146
// Config holds all config options
147147
type Config struct {
148-
DevMode bool
149-
VersionSwitch bool
150-
ListenPort string
151-
HomePath string
152-
DataPath string
153-
PipelinePath string
154-
WorkspacePath string
155-
Worker string
156-
Logger hclog.Logger
148+
DevMode bool
149+
VersionSwitch bool
150+
ListenPort string
151+
HomePath string
152+
DataPath string
153+
PipelinePath string
154+
WorkspacePath string
155+
Worker string
156+
JwtPrivateKeyPath string
157+
JWTKey interface{}
158+
Logger hclog.Logger
157159

158160
Bolt struct {
159161
Mode os.FileMode

handlers/User.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ import (
66

77
"github.com/labstack/echo"
88

9-
jwt "github.com/dgrijalva/jwt-go"
9+
"crypto/rsa"
10+
11+
"github.com/dgrijalva/jwt-go"
1012
"github.com/gaia-pipeline/gaia"
1113
)
1214

@@ -45,11 +47,20 @@ func UserLogin(c echo.Context) error {
4547
},
4648
}
4749

50+
var token *jwt.Token
4851
// Generate JWT token
49-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
52+
switch t := gaia.Cfg.JWTKey.(type) {
53+
case []byte:
54+
token = jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
55+
case *rsa.PrivateKey:
56+
token = jwt.NewWithClaims(jwt.SigningMethodRS512, claims)
57+
default:
58+
gaia.Cfg.Logger.Error("invalid jwt key type", "type", t)
59+
return c.String(http.StatusInternalServerError, "error creating jwt token: invalid jwt key type")
60+
}
5061

5162
// Sign and get encoded token
52-
tokenstring, err := token.SignedString(jwtKey)
63+
tokenstring, err := token.SignedString(gaia.Cfg.JWTKey)
5364
if err != nil {
5465
gaia.Cfg.Logger.Error("error signing jwt token", "error", err.Error())
5566
return c.String(http.StatusInternalServerError, err.Error())

handlers/User_test.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
package handlers
2+
3+
import (
4+
"testing"
5+
6+
"bytes"
7+
"encoding/json"
8+
"net/http"
9+
"net/http/httptest"
10+
11+
"io/ioutil"
12+
"os"
13+
14+
"crypto/rand"
15+
"crypto/rsa"
16+
17+
jwt "github.com/dgrijalva/jwt-go"
18+
"github.com/gaia-pipeline/gaia"
19+
"github.com/gaia-pipeline/gaia/store"
20+
"github.com/hashicorp/go-hclog"
21+
"github.com/labstack/echo"
22+
)
23+
24+
func TestUserLoginHMACKey(t *testing.T) {
25+
26+
dataDir, err := ioutil.TempDir("", "hmac")
27+
if err != nil {
28+
t.Fatalf("error creating data dir %v", err.Error())
29+
}
30+
31+
defer func() {
32+
gaia.Cfg = nil
33+
os.RemoveAll(dataDir)
34+
}()
35+
36+
gaia.Cfg = &gaia.Config{
37+
JWTKey: []byte("hmac-jwt-key"),
38+
Logger: hclog.New(&hclog.LoggerOptions{
39+
Level: hclog.Trace,
40+
Output: hclog.DefaultOutput,
41+
Name: "Gaia",
42+
}),
43+
DataPath: dataDir,
44+
}
45+
46+
dataStore := store.NewStore()
47+
err = dataStore.Init()
48+
if err != nil {
49+
t.Fatalf("cannot initialize store: %v", err.Error())
50+
}
51+
52+
e := echo.New()
53+
InitHandlers(e, dataStore, nil)
54+
55+
body := map[string]string{
56+
"username": "admin",
57+
"password": "admin",
58+
}
59+
bodyBytes, _ := json.Marshal(body)
60+
req := httptest.NewRequest(echo.POST, "/api/"+apiVersion+"/login", bytes.NewBuffer(bodyBytes))
61+
req.Header.Set("Content-Type", "application/json")
62+
rec := httptest.NewRecorder()
63+
e.ServeHTTP(rec, req)
64+
65+
if rec.Code != http.StatusOK {
66+
t.Fatalf("expected response code %v got %v", http.StatusOK, rec.Code)
67+
}
68+
69+
data, err := ioutil.ReadAll(rec.Body)
70+
user := &gaia.User{}
71+
err = json.Unmarshal(data, user)
72+
if err != nil {
73+
t.Fatalf("error unmarshaling responce %v", err.Error())
74+
}
75+
token, _, err := new(jwt.Parser).ParseUnverified(user.Tokenstring, jwt.MapClaims{})
76+
if err != nil {
77+
t.Fatalf("error parsing the token %v", err.Error())
78+
}
79+
alg := "HS256"
80+
if token.Header["alg"] != alg {
81+
t.Fatalf("expected token alg %v got %v", alg, token.Header["alg"])
82+
}
83+
84+
}
85+
86+
func TestUserLoginRSAKey(t *testing.T) {
87+
dataDir, err := ioutil.TempDir("", "rsa")
88+
if err != nil {
89+
t.Fatalf("error creating data dir %v", err.Error())
90+
}
91+
92+
defer func() {
93+
gaia.Cfg = nil
94+
os.RemoveAll(dataDir)
95+
}()
96+
97+
key, _ := rsa.GenerateKey(rand.Reader, 2048)
98+
gaia.Cfg = &gaia.Config{
99+
JWTKey: key,
100+
Logger: hclog.New(&hclog.LoggerOptions{
101+
Level: hclog.Trace,
102+
Output: hclog.DefaultOutput,
103+
Name: "Gaia",
104+
}),
105+
DataPath: dataDir,
106+
}
107+
108+
dataStore := store.NewStore()
109+
err = dataStore.Init()
110+
if err != nil {
111+
t.Fatalf("cannot initialize store: %v", err.Error())
112+
}
113+
114+
e := echo.New()
115+
InitHandlers(e, dataStore, nil)
116+
117+
body := map[string]string{
118+
"username": "admin",
119+
"password": "admin",
120+
}
121+
bodyBytes, _ := json.Marshal(body)
122+
req := httptest.NewRequest(echo.POST, "/api/"+apiVersion+"/login", bytes.NewBuffer(bodyBytes))
123+
req.Header.Set("Content-Type", "application/json")
124+
rec := httptest.NewRecorder()
125+
e.ServeHTTP(rec, req)
126+
127+
if rec.Code != http.StatusOK {
128+
t.Fatalf("expected response code %v got %v", http.StatusOK, rec.Code)
129+
}
130+
131+
data, err := ioutil.ReadAll(rec.Body)
132+
user := &gaia.User{}
133+
err = json.Unmarshal(data, user)
134+
if err != nil {
135+
t.Fatalf("error unmarshaling responce %v", err.Error())
136+
}
137+
token, _, err := new(jwt.Parser).ParseUnverified(user.Tokenstring, jwt.MapClaims{})
138+
if err != nil {
139+
t.Fatalf("error parsing the token %v", err.Error())
140+
}
141+
alg := "RS512"
142+
if token.Header["alg"] != alg {
143+
t.Fatalf("expected token alg %v got %v", alg, token.Header["alg"])
144+
}
145+
}

handlers/handler.go

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
package handlers
22

33
import (
4-
"crypto/rand"
54
"errors"
65
"fmt"
76
"net/http"
87
"strings"
98

109
"github.com/GeertJohan/go.rice"
1110

11+
"crypto/rsa"
12+
1213
jwt "github.com/dgrijalva/jwt-go"
1314
"github.com/gaia-pipeline/gaia"
1415
scheduler "github.com/gaia-pipeline/gaia/scheduler"
@@ -48,22 +49,12 @@ var storeService *store.Store
4849

4950
var schedulerService *scheduler.Scheduler
5051

51-
// jwtKey is a random generated key for jwt signing
52-
var jwtKey []byte
53-
5452
// InitHandlers initializes(registers) all handlers
5553
func InitHandlers(e *echo.Echo, store *store.Store, scheduler *scheduler.Scheduler) error {
5654
// Set instances
5755
storeService = store
5856
schedulerService = scheduler
5957

60-
// Generate signing key for jwt
61-
jwtKey = make([]byte, 64)
62-
_, err := rand.Read(jwtKey)
63-
if err != nil {
64-
return err
65-
}
66-
6758
// Define prefix
6859
p := "/api/" + apiVersion + "/"
6960

@@ -142,13 +133,21 @@ func authBarrier(next echo.HandlerFunc) echo.HandlerFunc {
142133

143134
// Parse token
144135
token, err := jwt.Parse(jwtString, func(token *jwt.Token) (interface{}, error) {
145-
// Validate signing method
146-
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
147-
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
136+
signingMethodError := fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
137+
switch token.Method.(type) {
138+
case *jwt.SigningMethodHMAC:
139+
if _, ok := gaia.Cfg.JWTKey.([]byte); !ok {
140+
return nil, signingMethodError
141+
}
142+
return gaia.Cfg.JWTKey, nil
143+
case *jwt.SigningMethodRSA:
144+
if _, ok := gaia.Cfg.JWTKey.(*rsa.PrivateKey); !ok {
145+
return nil, signingMethodError
146+
}
147+
return gaia.Cfg.JWTKey.(*rsa.PrivateKey).Public(), nil
148+
default:
149+
return nil, signingMethodError
148150
}
149-
150-
// return secret
151-
return jwtKey, nil
152151
})
153152
if err != nil {
154153
return c.String(http.StatusForbidden, err.Error())

0 commit comments

Comments
 (0)