Skip to content

Commit 8101a5e

Browse files
committed
chore: add origin flag to config cors
1 parent b5893aa commit 8101a5e

File tree

3 files changed

+32
-12
lines changed

3 files changed

+32
-12
lines changed

bin/memos/main.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,19 @@ const (
3131
)
3232

3333
var (
34-
profile *_profile.Profile
35-
mode string
36-
addr string
37-
port int
38-
data string
39-
driver string
40-
dsn string
41-
serveFrontend bool
34+
profile *_profile.Profile
35+
mode string
36+
addr string
37+
port int
38+
data string
39+
driver string
40+
dsn string
41+
serveFrontend bool
42+
allowedOrigins []string
4243

4344
rootCmd = &cobra.Command{
4445
Use: "memos",
45-
Short: `An open-source, self-hosted memo hub with knowledge management and social networking.`,
46+
Short: `An open source, lightweight note-taking service. Easily capture and share your great thoughts.`,
4647
Run: func(_cmd *cobra.Command, _args []string) {
4748
ctx, cancel := context.WithCancel(context.Background())
4849
dbDriver, err := db.NewDBDriver(profile)
@@ -114,6 +115,7 @@ func init() {
114115
rootCmd.PersistentFlags().StringVarP(&driver, "driver", "", "", "database driver")
115116
rootCmd.PersistentFlags().StringVarP(&dsn, "dsn", "", "", "database source name(aka. DSN)")
116117
rootCmd.PersistentFlags().BoolVarP(&serveFrontend, "frontend", "", true, "serve frontend files")
118+
rootCmd.PersistentFlags().StringArrayVarP(&allowedOrigins, "origins", "", []string{}, "CORS allowed domain origins")
117119

118120
err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode"))
119121
if err != nil {
@@ -143,12 +145,17 @@ func init() {
143145
if err != nil {
144146
panic(err)
145147
}
148+
err = viper.BindPFlag("origins", rootCmd.PersistentFlags().Lookup("origins"))
149+
if err != nil {
150+
panic(err)
151+
}
146152

147153
viper.SetDefault("mode", "demo")
148154
viper.SetDefault("driver", "sqlite")
149155
viper.SetDefault("addr", "")
150156
viper.SetDefault("port", 8081)
151157
viper.SetDefault("frontend", true)
158+
viper.SetDefault("origins", []string{})
152159
viper.SetEnvPrefix("memos")
153160
}
154161

server/profile/profile.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ type Profile struct {
3232
Version string `json:"version"`
3333
// Frontend indicate the frontend is enabled or not
3434
Frontend bool `json:"-"`
35+
// Origins is the list of allowed origins
36+
Origins []string `json:"-"`
3537
}
3638

3739
func (p *Profile) IsDev() bool {

server/server.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
4949
}
5050

5151
// Register CORS middleware.
52-
e.Use(CORSMiddleware())
52+
e.Use(CORSMiddleware(s.Profile.Origins))
5353

5454
serverID, err := s.getSystemServerID(ctx)
5555
if err != nil {
@@ -160,7 +160,7 @@ func grpcRequestSkipper(c echo.Context) bool {
160160
return strings.HasPrefix(c.Request().URL.Path, "/memos.api.v2.")
161161
}
162162

163-
func CORSMiddleware() echo.MiddlewareFunc {
163+
func CORSMiddleware(origins []string) echo.MiddlewareFunc {
164164
return func(next echo.HandlerFunc) echo.HandlerFunc {
165165
return func(c echo.Context) error {
166166
if grpcRequestSkipper(c) {
@@ -170,7 +170,18 @@ func CORSMiddleware() echo.MiddlewareFunc {
170170
r := c.Request()
171171
w := c.Response().Writer
172172

173-
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
173+
requestOrigin := r.Header.Get("Origin")
174+
if len(origins) == 0 {
175+
w.Header().Set("Access-Control-Allow-Origin", requestOrigin)
176+
} else {
177+
for _, origin := range origins {
178+
if origin == requestOrigin {
179+
w.Header().Set("Access-Control-Allow-Origin", origin)
180+
break
181+
}
182+
}
183+
}
184+
174185
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
175186
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
176187
w.Header().Set("Access-Control-Allow-Credentials", "true")

0 commit comments

Comments
 (0)