Skip to content

Commit e789c23

Browse files
committed
Persist signing cert and key in database and allow custom
1 parent 0d78000 commit e789c23

File tree

4 files changed

+138
-7
lines changed

4 files changed

+138
-7
lines changed

api/external.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ func (a *API) Provider(ctx context.Context, name string) (provider.Provider, err
277277
case "facebook":
278278
return provider.NewFacebookProvider(config.External.Facebook)
279279
case "saml":
280-
return provider.NewSamlProvider(config.External.Saml)
280+
return provider.NewSamlProvider(config.External.Saml, a.db, getInstanceID(ctx))
281281
default:
282282
return nil, fmt.Errorf("Provider %s could not be found", name)
283283
}

api/external_saml.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func (a *API) loadSAMLState(w http.ResponseWriter, r *http.Request) (context.Con
2222
func (a *API) samlCallback(r *http.Request, ctx context.Context) (*provider.UserProvidedData, error) {
2323
config := a.getConfig(ctx)
2424

25-
samlProvider, err := provider.NewSamlProvider(config.External.Saml)
25+
samlProvider, err := provider.NewSamlProvider(config.External.Saml, a.db, getInstanceID(ctx))
2626
if err != nil {
2727
return nil, badRequestError("Could not initialize SAML provider: %+v", err).WithInternalError(err)
2828
}
@@ -59,7 +59,7 @@ func (a *API) SAMLMetadata(w http.ResponseWriter, r *http.Request) error {
5959
ctx := r.Context()
6060
config := getConfig(ctx)
6161

62-
samlProvider, err := provider.NewSamlProvider(config.External.Saml)
62+
samlProvider, err := provider.NewSamlProvider(config.External.Saml, a.db, getInstanceID(ctx))
6363
if err != nil {
6464
return internalServerError("Could not create SAML Provider: %+v", err).WithInternalError(err)
6565
}

api/provider/saml.go

Lines changed: 133 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,47 @@
11
package provider
22

33
import (
4+
"crypto/rand"
5+
"crypto/rsa"
6+
"crypto/tls"
47
"crypto/x509"
58
"encoding/base64"
9+
"encoding/pem"
610
"encoding/xml"
711
"errors"
812
"fmt"
913
"io/ioutil"
14+
"math/big"
1015
"net/http"
1116
"net/url"
1217
"strings"
18+
"time"
19+
20+
"github.com/netlify/gotrue/models"
21+
"github.com/netlify/gotrue/storage"
1322

1423
"github.com/netlify/gotrue/conf"
1524
saml2 "github.com/russellhaering/gosaml2"
1625
"github.com/russellhaering/gosaml2/types"
1726
dsig "github.com/russellhaering/goxmldsig"
27+
uuid "github.com/satori/go.uuid"
1828
"golang.org/x/oauth2"
1929
)
2030

2131
type SamlProvider struct {
2232
ServiceProvider *saml2.SAMLServiceProvider
2333
}
2434

35+
type SamlCertCreation struct {
36+
instanceId uuid.UUID
37+
db *storage.Connection
38+
}
39+
40+
type MemoryX509KeyStore struct {
41+
privateKey *rsa.PrivateKey
42+
cert []byte
43+
}
44+
2545
func getMetadata(url string) (*types.EntityDescriptor, error) {
2646
res, err := http.Get(url)
2747
if err != nil {
@@ -45,7 +65,7 @@ func getMetadata(url string) (*types.EntityDescriptor, error) {
4565
}
4666

4767
// NewSamlProvider creates a Saml account provider.
48-
func NewSamlProvider(ext conf.SamlProviderConfiguration) (*SamlProvider, error) {
68+
func NewSamlProvider(ext conf.SamlProviderConfiguration, db *storage.Connection, instanceId uuid.UUID) (*SamlProvider, error) {
4969
if !ext.Enabled {
5070
return nil, errors.New("SAML Provider is not enabled")
5171
}
@@ -100,8 +120,15 @@ func NewSamlProvider(ext conf.SamlProviderConfiguration) (*SamlProvider, error)
100120
}
101121
}
102122

103-
// TODO: generate keys once, save them in the database and use here
104-
randomKeyStore := dsig.RandomKeyStoreForTest()
123+
certCreation := &SamlCertCreation{
124+
instanceId: instanceId,
125+
db: db,
126+
}
127+
128+
err, keyStore := certCreation.PrepareKeystore(ext)
129+
if err != nil {
130+
return nil, err
131+
}
105132

106133
sp := &saml2.SAMLServiceProvider{
107134
IdentityProviderSSOURL: ssoService.Location,
@@ -111,7 +138,7 @@ func NewSamlProvider(ext conf.SamlProviderConfiguration) (*SamlProvider, error)
111138
SignAuthnRequests: true,
112139
AudienceURI: baseURI.String() + "/saml",
113140
IDPCertificateStore: &certStore,
114-
SPKeyStore: randomKeyStore,
141+
SPKeyStore: keyStore,
115142
AllowMissingAttributes: true,
116143
}
117144

@@ -142,3 +169,105 @@ func (p SamlProvider) SPMetadata() ([]byte, error) {
142169

143170
return rawMetadata, nil
144171
}
172+
173+
func (s SamlCertCreation) PrepareKeystore(conf conf.SamlProviderConfiguration) (error, dsig.X509KeyStore) {
174+
if conf.SigningCert == "" && conf.SigningKey == "" {
175+
return s.CreateSigningCert()
176+
}
177+
178+
keyPair, err := tls.X509KeyPair([]byte(conf.SigningCert), []byte(conf.SigningKey))
179+
if err != nil {
180+
return fmt.Errorf("Parsing key pair failed: %+v", err), nil
181+
}
182+
183+
var privKey *rsa.PrivateKey
184+
switch key := keyPair.PrivateKey.(type) {
185+
case *rsa.PrivateKey:
186+
privKey = key
187+
default:
188+
return errors.New("Private key is not an RSA key"), nil
189+
}
190+
191+
return nil, &MemoryX509KeyStore{
192+
privateKey: privKey,
193+
cert: keyPair.Certificate[0],
194+
}
195+
}
196+
197+
func (s SamlCertCreation) CreateSigningCert() (error, dsig.X509KeyStore) {
198+
key, err := rsa.GenerateKey(rand.Reader, 2048)
199+
if err != nil {
200+
return err, nil
201+
}
202+
203+
currentTime := time.Now()
204+
205+
certBody := &x509.Certificate{
206+
SerialNumber: big.NewInt(1),
207+
NotBefore: currentTime.Add(-5 * time.Minute),
208+
NotAfter: currentTime.Add(365 * 24 * time.Hour),
209+
210+
KeyUsage: x509.KeyUsageDigitalSignature,
211+
ExtKeyUsage: []x509.ExtKeyUsage{},
212+
BasicConstraintsValid: true,
213+
}
214+
215+
cert, err := x509.CreateCertificate(rand.Reader, certBody, certBody, &key.PublicKey, key)
216+
if err != nil {
217+
return fmt.Errorf("Failed to create certificate: %+v", err), nil
218+
}
219+
220+
if err := s.SaveConfig(cert, key); err != nil {
221+
return fmt.Errorf("Saving signing keypair failed: %+v", err), nil
222+
}
223+
224+
return nil, &MemoryX509KeyStore{
225+
privateKey: key,
226+
cert: cert,
227+
}
228+
}
229+
230+
func (s SamlCertCreation) SaveConfig(cert []byte, key *rsa.PrivateKey) error {
231+
if uuid.Equal(s.instanceId, uuid.Nil) {
232+
return nil
233+
}
234+
235+
pemCert := &pem.Block{
236+
Type: "CERTIFICATE",
237+
Bytes: cert,
238+
}
239+
240+
certBytes := pem.EncodeToMemory(pemCert)
241+
if certBytes == nil {
242+
return errors.New("Could not encode certificate")
243+
}
244+
245+
pemKey := &pem.Block{
246+
Type: "PRIVATE KEY",
247+
Bytes: x509.MarshalPKCS1PrivateKey(key),
248+
}
249+
250+
keyBytes := pem.EncodeToMemory(pemKey)
251+
if keyBytes == nil {
252+
return errors.New("Could not encode key")
253+
}
254+
255+
instance, err := models.GetInstance(s.db, s.instanceId)
256+
if err != nil {
257+
return err
258+
}
259+
260+
conf := instance.BaseConfig
261+
conf.External.Saml.SigningCert = string(certBytes)
262+
conf.External.Saml.SigningKey = string(keyBytes)
263+
264+
if err := instance.UpdateConfig(s.db, conf); err != nil {
265+
return err
266+
}
267+
268+
return nil
269+
}
270+
271+
func (ks *MemoryX509KeyStore) GetKeyPair() (*rsa.PrivateKey, []byte, error) {
272+
return ks.privateKey, ks.cert, nil
273+
}

conf/configuration.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ type SamlProviderConfiguration struct {
3030
MetadataURL string `json:"metadata_url" envconfig:"METADATA_URL"`
3131
APIBase string `json:"api_base" envconfig:"API_BASE"`
3232
Name string `json:"name"`
33+
SigningCert string `json:"signing_cert" envconfig:"SIGNING_CERT"`
34+
SigningKey string `json:"signing_key" envconfig:"SIGNING_KEY"`
3335
}
3436

3537
// DBConfiguration holds all the database related configuration.

0 commit comments

Comments
 (0)