Skip to content
This repository was archived by the owner on Sep 7, 2021. It is now read-only.
This repository is currently being migrated. It's locked while the migration is in progress.

Commit cac7688

Browse files
committed
fix quote policy
1 parent ecc286a commit cac7688

9 files changed

+44
-83
lines changed

Diff for: engine.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ type Engine struct {
5656

5757
defaultContext context.Context
5858

59-
quotePolicy QuotePolicy
60-
quoteMode QuoteMode
59+
colQuoter Quoter
60+
tableQuoter Quoter
6161
}
6262

6363
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
@@ -419,7 +419,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
419419
return err
420420
}
421421

422-
quoter := newQuoter(dialect, engine.quoteMode, engine.quotePolicy)
422+
colQuoter := newQuoter(dialect, engine.colQuoter.QuotePolicy())
423423

424424
for i, table := range tables {
425425
if i > 0 {
@@ -440,8 +440,8 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
440440
}
441441

442442
cols := table.ColumnsSeq()
443-
colNames := quoteJoin(engine, cols)
444-
destColNames := quoteJoin(quoter, cols)
443+
colNames := quoteJoin(engine.colQuoter, cols)
444+
destColNames := quoteJoin(colQuoter, cols)
445445

446446
rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.quote(table.Name, false))
447447
if err != nil {

Diff for: engine_quote.go

+16-60
Original file line numberDiff line numberDiff line change
@@ -21,34 +21,21 @@ const (
2121
QuoteAddReserved
2222
)
2323

24-
// QuoteMode quote on which types
25-
type QuoteMode int
26-
27-
// All QuoteModes
28-
const (
29-
QuoteTableAndColumns QuoteMode = iota
30-
QuoteTableOnly
31-
QuoteColumnsOnly
32-
)
33-
3424
// Quoter represents an object has Quote method
3525
type Quoter interface {
3626
Quotes() (byte, byte)
3727
QuotePolicy() QuotePolicy
38-
QuoteMode() QuoteMode
3928
IsReserved(string) bool
4029
}
4130

4231
type quoter struct {
4332
dialect core.Dialect
44-
quoteMode QuoteMode
4533
quotePolicy QuotePolicy
4634
}
4735

48-
func newQuoter(dialect core.Dialect, quoteMode QuoteMode, quotePolicy QuotePolicy) Quoter {
36+
func newQuoter(dialect core.Dialect, quotePolicy QuotePolicy) Quoter {
4937
return &quoter{
5038
dialect: dialect,
51-
quoteMode: quoteMode,
5239
quotePolicy: quotePolicy,
5340
}
5441
}
@@ -62,10 +49,6 @@ func (q *quoter) QuotePolicy() QuotePolicy {
6249
return q.quotePolicy
6350
}
6451

65-
func (q *quoter) QuoteMode() QuoteMode {
66-
return q.quoteMode
67-
}
68-
6952
func (q *quoter) IsReserved(value string) bool {
7053
return q.dialect.IsReserved(value)
7154
}
@@ -77,21 +60,24 @@ func quoteColumns(quoter Quoter, columnStr string) string {
7760

7861
func quoteJoin(quoter Quoter, columns []string) string {
7962
for i := 0; i < len(columns); i++ {
80-
columns[i] = quote(quoter, columns[i], true)
63+
columns[i] = quote(quoter, columns[i])
8164
}
8265
return strings.Join(columns, ",")
8366
}
8467

8568
// quote Use QuoteStr quote the string sql
86-
func quote(quoter Quoter, value string, isColumn bool) string {
69+
func quote(quoter Quoter, value string) string {
8770
buf := strings.Builder{}
88-
quoteTo(quoter, &buf, value, isColumn)
71+
quoteTo(quoter, &buf, value)
8972
return buf.String()
9073
}
9174

9275
// Quote add quotes to the value
9376
func (engine *Engine) quote(value string, isColumn bool) string {
94-
return quote(engine, value, isColumn)
77+
if isColumn {
78+
return quote(engine.colQuoter, value)
79+
}
80+
return quote(engine.tableQuoter, value)
9581
}
9682

9783
// Quote add quotes to the value
@@ -105,53 +91,25 @@ func (engine *Engine) Quotes() (byte, byte) {
10591
return quotes[0], quotes[1]
10692
}
10793

108-
// QuoteMode returns quote mode
109-
func (engine *Engine) QuoteMode() QuoteMode {
110-
return engine.quoteMode
111-
}
112-
113-
// QuotePolicy returns quote policy
114-
func (engine *Engine) QuotePolicy() QuotePolicy {
115-
return engine.quotePolicy
116-
}
117-
11894
// IsReserved return true if the value is a reserved word of the database
11995
func (engine *Engine) IsReserved(value string) bool {
12096
return engine.dialect.IsReserved(value)
12197
}
12298

12399
// quoteTo quotes string and writes into the buffer
124-
func quoteTo(quoter Quoter, buf *strings.Builder, value string, isColumn bool) {
125-
if isColumn {
126-
if quoter.QuoteMode() == QuoteTableAndColumns ||
127-
quoter.QuoteMode() == QuoteColumnsOnly {
128-
if quoter.QuotePolicy() == QuoteAddAlways {
129-
realQuoteTo(quoter, buf, value)
130-
return
131-
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
132-
realQuoteTo(quoter, buf, value)
133-
return
134-
}
135-
}
136-
buf.WriteString(value)
100+
func quoteTo(quoter Quoter, buf *strings.Builder, value string) {
101+
left, right := quoter.Quotes()
102+
if quoter.QuotePolicy() == QuoteAddAlways {
103+
realQuoteTo(left, right, buf, value)
104+
return
105+
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
106+
realQuoteTo(left, right, buf, value)
137107
return
138-
}
139-
140-
if quoter.QuoteMode() == QuoteTableAndColumns ||
141-
quoter.QuoteMode() == QuoteTableOnly {
142-
if quoter.QuotePolicy() == QuoteAddAlways {
143-
realQuoteTo(quoter, buf, value)
144-
return
145-
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
146-
realQuoteTo(quoter, buf, value)
147-
return
148-
}
149108
}
150109
buf.WriteString(value)
151-
return
152110
}
153111

154-
func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) {
112+
func realQuoteTo(quoteLeft, quoteRight byte, buf *strings.Builder, value string) {
155113
if buf == nil {
156114
return
157115
}
@@ -164,8 +122,6 @@ func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) {
164122
return
165123
}
166124

167-
quoteLeft, quoteRight := quoter.Quotes()
168-
169125
if value[0] == '`' || value[0] == quoteLeft { // no quote
170126
_, _ = buf.WriteString(value)
171127
return

Diff for: session_find.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,15 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
141141
if session.statement.JoinStr == "" {
142142
if columnStr == "" {
143143
if session.statement.GroupByStr != "" {
144-
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
144+
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
145145
} else {
146146
columnStr = session.statement.genColumnStr()
147147
}
148148
}
149149
} else {
150150
if columnStr == "" {
151151
if session.statement.GroupByStr != "" {
152-
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
152+
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
153153
} else {
154154
columnStr = "*"
155155
}

Diff for: session_insert.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -249,15 +249,15 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
249249
if session.engine.dialect.DBType() == core.ORACLE {
250250
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
251251
session.engine.quote(tableName, false),
252-
quoteJoin(session.engine, colNames))
252+
quoteJoin(session.engine.colQuoter, colNames))
253253
sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
254254
session.engine.quote(tableName, false),
255-
quoteJoin(session.engine, colNames),
255+
quoteJoin(session.engine.colQuoter, colNames),
256256
strings.Join(colMultiPlaces, temp))
257257
} else {
258258
sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
259259
session.engine.quote(tableName, false),
260-
quoteJoin(session.engine, colNames),
260+
quoteJoin(session.engine.colQuoter, colNames),
261261
strings.Join(colMultiPlaces, "),("))
262262
}
263263
res, err := session.exec(sql, args...)
@@ -855,7 +855,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
855855

856856
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
857857
session.engine.quote(tableName, false),
858-
quoteJoin(session.engine, columns), qm)); err != nil {
858+
quoteJoin(session.engine.colQuoter, columns), qm)); err != nil {
859859
return 0, err
860860
}
861861
w.Append(args...)

Diff for: session_query.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa
3535
if session.statement.JoinStr == "" {
3636
if columnStr == "" {
3737
if session.statement.GroupByStr != "" {
38-
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
38+
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
3939
} else {
4040
columnStr = session.statement.genColumnStr()
4141
}
4242
}
4343
} else {
4444
if columnStr == "" {
4545
if session.statement.GroupByStr != "" {
46-
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
46+
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
4747
} else {
4848
columnStr = "*"
4949
}

Diff for: session_update.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
100100
for idx, kv := range kvs {
101101
sps := strings.SplitN(kv, "=", 2)
102102
sps2 := strings.Split(sps[0], ".")
103-
colName := unQuote(session.engine, sps2[len(sps2)-1])
103+
colName := unQuote(session.engine.colQuoter, sps2[len(sps2)-1])
104104

105105
if col := table.GetColumn(colName); col != nil {
106106
fieldValue, err := col.ValueOf(bean)

Diff for: statement.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
603603

604604
newColumns := statement.colmap2NewColsWithQuote()
605605

606-
statement.ColumnStr = quoteJoin(statement.Engine, newColumns)
606+
statement.ColumnStr = quoteJoin(statement.Engine.colQuoter, newColumns)
607607
return statement
608608
}
609609

@@ -638,7 +638,7 @@ func (statement *Statement) Omit(columns ...string) {
638638
for _, nc := range newColumns {
639639
statement.omitColumnMap = append(statement.omitColumnMap, nc)
640640
}
641-
statement.OmitStr = quoteJoin(statement.Engine, newColumns)
641+
statement.OmitStr = quoteJoin(statement.Engine.colQuoter, newColumns)
642642
}
643643

644644
// Nullable Update use only: update columns to null when value is nullable and zero-value
@@ -732,7 +732,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
732732
}
733733
tbs := strings.Split(tp.TableName(), ".")
734734

735-
var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1])
735+
var aliasName = unQuote(statement.Engine.tableQuoter, tbs[len(tbs)-1])
736736
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
737737
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
738738
case *builder.Builder:
@@ -743,7 +743,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
743743
}
744744
tbs := strings.Split(tp.TableName(), ".")
745745

746-
var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1])
746+
var aliasName = unQuote(statement.Engine.tableQuoter, tbs[len(tbs)-1])
747747
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
748748
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
749749
default:
@@ -809,7 +809,7 @@ func (statement *Statement) genColumnStr() string {
809809
buf.WriteString(".")
810810
}
811811

812-
quoteTo(statement.Engine, &buf, col.Name, true)
812+
quoteTo(statement.Engine.colQuoter, &buf, col.Name)
813813
}
814814

815815
return buf.String()
@@ -928,15 +928,15 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
928928
if len(statement.JoinStr) == 0 {
929929
if len(columnStr) == 0 {
930930
if len(statement.GroupByStr) > 0 {
931-
columnStr = quoteColumns(statement.Engine, statement.GroupByStr)
931+
columnStr = quoteColumns(statement.Engine.colQuoter, statement.GroupByStr)
932932
} else {
933933
columnStr = statement.genColumnStr()
934934
}
935935
}
936936
} else {
937937
if len(columnStr) == 0 {
938938
if len(statement.GroupByStr) > 0 {
939-
columnStr = quoteColumns(statement.Engine, statement.GroupByStr)
939+
columnStr = quoteColumns(statement.Engine.colQuoter, statement.GroupByStr)
940940
}
941941
}
942942
}

Diff for: statement_test.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,9 @@ func TestCol2NewColsWithQuote(t *testing.T) {
243243

244244
statement := createTestStatement()
245245

246-
quotedCols := quoteJoin(statement.Engine, cols)
247-
assert.EqualValues(t, []string{statement.Engine.Quote("f1", true), statement.Engine.Quote("f2", true), statement.Engine.Quote("t3.f3", true)}, quotedCols)
246+
quotedCols := quoteJoin(statement.Engine.colQuoter, cols)
247+
assert.EqualValues(t, statement.Engine.Quote("f1", true)+","+
248+
statement.Engine.Quote("f2", true)+","+
249+
statement.Engine.Quote("t3.f3", true),
250+
quotedCols)
248251
}

Diff for: xorm.go

+2
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
9595
tagHandlers: defaultTagHandlers,
9696
cachers: make(map[string]core.Cacher),
9797
defaultContext: context.Background(),
98+
colQuoter: newQuoter(dialect, QuoteAddAlways),
99+
tableQuoter: newQuoter(dialect, QuoteAddAlways),
98100
}
99101

100102
if uri.DbType == core.SQLITE {

0 commit comments

Comments
 (0)