1
1
package provider
2
2
3
3
import (
4
+ "crypto/rand"
5
+ "crypto/rsa"
6
+ "crypto/tls"
4
7
"crypto/x509"
5
8
"encoding/base64"
9
+ "encoding/pem"
6
10
"encoding/xml"
7
11
"errors"
8
12
"fmt"
9
13
"io/ioutil"
14
+ "math/big"
10
15
"net/http"
11
16
"net/url"
12
17
"strings"
18
+ "time"
19
+
20
+ "github.com/netlify/gotrue/models"
21
+ "github.com/netlify/gotrue/storage"
13
22
14
23
"github.com/netlify/gotrue/conf"
15
24
saml2 "github.com/russellhaering/gosaml2"
16
25
"github.com/russellhaering/gosaml2/types"
17
26
dsig "github.com/russellhaering/goxmldsig"
27
+ uuid "github.com/satori/go.uuid"
18
28
"golang.org/x/oauth2"
19
29
)
20
30
21
31
type SamlProvider struct {
22
32
ServiceProvider * saml2.SAMLServiceProvider
23
33
}
24
34
35
+ type SamlCertCreation struct {
36
+ instanceId uuid.UUID
37
+ db * storage.Connection
38
+ }
39
+
40
+ type MemoryX509KeyStore struct {
41
+ privateKey * rsa.PrivateKey
42
+ cert []byte
43
+ }
44
+
25
45
func getMetadata (url string ) (* types.EntityDescriptor , error ) {
26
46
res , err := http .Get (url )
27
47
if err != nil {
@@ -45,7 +65,7 @@ func getMetadata(url string) (*types.EntityDescriptor, error) {
45
65
}
46
66
47
67
// NewSamlProvider creates a Saml account provider.
48
- func NewSamlProvider (ext conf.SamlProviderConfiguration ) (* SamlProvider , error ) {
68
+ func NewSamlProvider (ext conf.SamlProviderConfiguration , db * storage. Connection , instanceId uuid. UUID ) (* SamlProvider , error ) {
49
69
if ! ext .Enabled {
50
70
return nil , errors .New ("SAML Provider is not enabled" )
51
71
}
@@ -100,8 +120,15 @@ func NewSamlProvider(ext conf.SamlProviderConfiguration) (*SamlProvider, error)
100
120
}
101
121
}
102
122
103
- // TODO: generate keys once, save them in the database and use here
104
- randomKeyStore := dsig .RandomKeyStoreForTest ()
123
+ certCreation := & SamlCertCreation {
124
+ instanceId : instanceId ,
125
+ db : db ,
126
+ }
127
+
128
+ err , keyStore := certCreation .PrepareKeystore (ext )
129
+ if err != nil {
130
+ return nil , err
131
+ }
105
132
106
133
sp := & saml2.SAMLServiceProvider {
107
134
IdentityProviderSSOURL : ssoService .Location ,
@@ -111,7 +138,7 @@ func NewSamlProvider(ext conf.SamlProviderConfiguration) (*SamlProvider, error)
111
138
SignAuthnRequests : true ,
112
139
AudienceURI : baseURI .String () + "/saml" ,
113
140
IDPCertificateStore : & certStore ,
114
- SPKeyStore : randomKeyStore ,
141
+ SPKeyStore : keyStore ,
115
142
AllowMissingAttributes : true ,
116
143
}
117
144
@@ -142,3 +169,105 @@ func (p SamlProvider) SPMetadata() ([]byte, error) {
142
169
143
170
return rawMetadata , nil
144
171
}
172
+
173
+ func (s SamlCertCreation ) PrepareKeystore (conf conf.SamlProviderConfiguration ) (error , dsig.X509KeyStore ) {
174
+ if conf .SigningCert == "" && conf .SigningKey == "" {
175
+ return s .CreateSigningCert ()
176
+ }
177
+
178
+ keyPair , err := tls .X509KeyPair ([]byte (conf .SigningCert ), []byte (conf .SigningKey ))
179
+ if err != nil {
180
+ return fmt .Errorf ("Parsing key pair failed: %+v" , err ), nil
181
+ }
182
+
183
+ var privKey * rsa.PrivateKey
184
+ switch key := keyPair .PrivateKey .(type ) {
185
+ case * rsa.PrivateKey :
186
+ privKey = key
187
+ default :
188
+ return errors .New ("Private key is not an RSA key" ), nil
189
+ }
190
+
191
+ return nil , & MemoryX509KeyStore {
192
+ privateKey : privKey ,
193
+ cert : keyPair .Certificate [0 ],
194
+ }
195
+ }
196
+
197
+ func (s SamlCertCreation ) CreateSigningCert () (error , dsig.X509KeyStore ) {
198
+ key , err := rsa .GenerateKey (rand .Reader , 2048 )
199
+ if err != nil {
200
+ return err , nil
201
+ }
202
+
203
+ currentTime := time .Now ()
204
+
205
+ certBody := & x509.Certificate {
206
+ SerialNumber : big .NewInt (1 ),
207
+ NotBefore : currentTime .Add (- 5 * time .Minute ),
208
+ NotAfter : currentTime .Add (365 * 24 * time .Hour ),
209
+
210
+ KeyUsage : x509 .KeyUsageDigitalSignature ,
211
+ ExtKeyUsage : []x509.ExtKeyUsage {},
212
+ BasicConstraintsValid : true ,
213
+ }
214
+
215
+ cert , err := x509 .CreateCertificate (rand .Reader , certBody , certBody , & key .PublicKey , key )
216
+ if err != nil {
217
+ return fmt .Errorf ("Failed to create certificate: %+v" , err ), nil
218
+ }
219
+
220
+ if err := s .SaveConfig (cert , key ); err != nil {
221
+ return fmt .Errorf ("Saving signing keypair failed: %+v" , err ), nil
222
+ }
223
+
224
+ return nil , & MemoryX509KeyStore {
225
+ privateKey : key ,
226
+ cert : cert ,
227
+ }
228
+ }
229
+
230
+ func (s SamlCertCreation ) SaveConfig (cert []byte , key * rsa.PrivateKey ) error {
231
+ if uuid .Equal (s .instanceId , uuid .Nil ) {
232
+ return nil
233
+ }
234
+
235
+ pemCert := & pem.Block {
236
+ Type : "CERTIFICATE" ,
237
+ Bytes : cert ,
238
+ }
239
+
240
+ certBytes := pem .EncodeToMemory (pemCert )
241
+ if certBytes == nil {
242
+ return errors .New ("Could not encode certificate" )
243
+ }
244
+
245
+ pemKey := & pem.Block {
246
+ Type : "PRIVATE KEY" ,
247
+ Bytes : x509 .MarshalPKCS1PrivateKey (key ),
248
+ }
249
+
250
+ keyBytes := pem .EncodeToMemory (pemKey )
251
+ if keyBytes == nil {
252
+ return errors .New ("Could not encode key" )
253
+ }
254
+
255
+ instance , err := models .GetInstance (s .db , s .instanceId )
256
+ if err != nil {
257
+ return err
258
+ }
259
+
260
+ conf := instance .BaseConfig
261
+ conf .External .Saml .SigningCert = string (certBytes )
262
+ conf .External .Saml .SigningKey = string (keyBytes )
263
+
264
+ if err := instance .UpdateConfig (s .db , conf ); err != nil {
265
+ return err
266
+ }
267
+
268
+ return nil
269
+ }
270
+
271
+ func (ks * MemoryX509KeyStore ) GetKeyPair () (* rsa.PrivateKey , []byte , error ) {
272
+ return ks .privateKey , ks .cert , nil
273
+ }
0 commit comments