Skip to content

Commit 8d64666

Browse files
committed
feat: add support for table column mapping
1 parent 1728137 commit 8d64666

File tree

5 files changed

+51
-9
lines changed

5 files changed

+51
-9
lines changed

expr.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func (e *comparisonExpr) isComplete() bool {
7777

7878
// defaultValidateConvert will validate the comparison expr value, and then convert the
7979
// expr to its SQL equivalence.
80-
func defaultValidateConvert(columnName string, comparisonOp ComparisonOp, columnValue *string, validator validator, opt ...Option) (*WhereClause, error) {
80+
func defaultValidateConvert(columnName string, comparisonOp ComparisonOp, columnValue *string, validator validator, opts options) (*WhereClause, error) {
8181
const op = "mql.(comparisonExpr).convertToSql"
8282
switch {
8383
case columnName == "":
@@ -103,6 +103,12 @@ func defaultValidateConvert(columnName string, comparisonOp ComparisonOp, column
103103
if err != nil {
104104
return nil, fmt.Errorf("%s: %q in %s: %w", op, *e.value, e.String(), ErrInvalidParameter)
105105
}
106+
newCol, ok := opts.withTableColumnMap[columnName]
107+
if ok {
108+
// override our column name with the mapped column name
109+
columnName = newCol
110+
}
111+
106112
if validator.typ == "time" {
107113
columnName = fmt.Sprintf("%s::date", columnName)
108114
}

expr_test.go

+15-5
Original file line numberDiff line numberDiff line change
@@ -93,39 +93,49 @@ func Test_defaultValidateConvert(t *testing.T) {
9393
t.Parallel()
9494
fValidators, err := fieldValidators(reflect.ValueOf(testModel{}))
9595
require.NoError(t, err)
96+
opts := getDefaultOptions()
9697
t.Run("missing-column", func(t *testing.T) {
97-
e, err := defaultValidateConvert("", EqualOp, pointer("alice"), fValidators["name"])
98+
e, err := defaultValidateConvert("", EqualOp, pointer("alice"), fValidators["name"], opts)
9899
require.Error(t, err)
99100
assert.Empty(t, e)
100101
assert.ErrorIs(t, err, ErrMissingColumn)
101102
assert.ErrorContains(t, err, "missing column")
102103
})
103104
t.Run("missing-comparison-op", func(t *testing.T) {
104-
e, err := defaultValidateConvert("name", "", pointer("alice"), fValidators["name"])
105+
e, err := defaultValidateConvert("name", "", pointer("alice"), fValidators["name"], opts)
105106
require.Error(t, err)
106107
assert.Empty(t, e)
107108
assert.ErrorIs(t, err, ErrMissingComparisonOp)
108109
assert.ErrorContains(t, err, "missing comparison operator")
109110
})
110111
t.Run("missing-value", func(t *testing.T) {
111-
e, err := defaultValidateConvert("name", EqualOp, nil, fValidators["name"])
112+
e, err := defaultValidateConvert("name", EqualOp, nil, fValidators["name"], opts)
112113
require.Error(t, err)
113114
assert.Empty(t, e)
114115
assert.ErrorIs(t, err, ErrMissingComparisonValue)
115116
assert.ErrorContains(t, err, "missing comparison value")
116117
})
117118
t.Run("missing-validator-func", func(t *testing.T) {
118-
e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{typ: "string"})
119+
e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{typ: "string"}, opts)
119120
require.Error(t, err)
120121
assert.Empty(t, e)
121122
assert.ErrorIs(t, err, ErrInvalidParameter)
122123
assert.ErrorContains(t, err, "missing validator function")
123124
})
124125
t.Run("missing-validator-typ", func(t *testing.T) {
125-
e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{fn: fValidators["name"].fn})
126+
e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{fn: fValidators["name"].fn}, opts)
126127
require.Error(t, err)
127128
assert.Empty(t, e)
128129
assert.ErrorIs(t, err, ErrInvalidParameter)
129130
assert.ErrorContains(t, err, "missing validator type")
130131
})
132+
t.Run("success-with-table-override", func(t *testing.T) {
133+
opts.withTableColumnMap["name"] = "users.name"
134+
e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{fn: fValidators["name"].fn, typ: "default"}, opts)
135+
assert.Empty(t, err)
136+
assert.NotEmpty(t, e)
137+
assert.Equal(t, "users.name=?", e.Condition, "condition")
138+
assert.Len(t, e.Args, 1, "args")
139+
assert.Equal(t, "alice", e.Args[0], "args[0]")
140+
})
131141
}

mql.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func exprToWhereClause(e expr, fValidators map[string]validator, opt ...Option)
8787
}
8888
return nil, fmt.Errorf("%s: %w %q %s", op, ErrInvalidColumn, columnName, cols)
8989
}
90-
w, err := defaultValidateConvert(columnName, v.comparisonOp, v.value, validator, opt...)
90+
w, err := defaultValidateConvert(columnName, v.comparisonOp, v.value, validator, opts)
9191
if err != nil {
9292
return nil, fmt.Errorf("%s: %w", op, err)
9393
}

mql_test.go

+13
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,19 @@ func TestParse(t *testing.T) {
306306
wantErrIs: mql.ErrInvalidParameter,
307307
wantErrContains: "missing ConvertToSqlFunc: invalid parameter",
308308
},
309+
{
310+
name: "success-with-table-column-map",
311+
query: "custom_name=\"alice\"",
312+
model: testModel{},
313+
opts: []mql.Option{
314+
mql.WithColumnMap(map[string]string{"custom_name": "name"}),
315+
mql.WithTableColumnMap(map[string]string{"name": "users.custom->>'name'"}),
316+
},
317+
want: &mql.WhereClause{
318+
Condition: "users.custom->>'name'=?",
319+
Args: []any{"alice"},
320+
},
321+
},
309322
}
310323
for _, tc := range tests {
311324
tc := tc

options.go

+15-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ type options struct {
1313
withValidateConvertFns map[string]ValidateConvertFunc
1414
withIgnoredFields []string
1515
withPgPlaceholder bool
16+
withTableColumnMap map[string]string // map of model field names to their table.column name
1617
}
1718

1819
// Option - how options are passed as args
@@ -22,6 +23,7 @@ func getDefaultOptions() options {
2223
return options{
2324
withColumnMap: make(map[string]string),
2425
withValidateConvertFns: make(map[string]ValidateConvertFunc),
26+
withTableColumnMap: make(map[string]string),
2527
}
2628
}
2729

@@ -44,8 +46,8 @@ func withSkipWhitespace() Option {
4446
}
4547
}
4648

47-
// WithColumnMap provides an optional map of columns from a column in the user
48-
// provided query to a column in the database model
49+
// WithColumnMap provides an optional map of columns from the user
50+
// provided query to a field in the given model
4951
func WithColumnMap(m map[string]string) Option {
5052
return func(o *options) error {
5153
if !isNil(m) {
@@ -100,3 +102,14 @@ func WithPgPlaceholders() Option {
100102
return nil
101103
}
102104
}
105+
106+
// WithTableColumnMap provides an optional map of columns from the
107+
// model to the table.column name in the generated where clause
108+
func WithTableColumnMap(m map[string]string) Option {
109+
return func(o *options) error {
110+
if !isNil(m) {
111+
o.withTableColumnMap = m
112+
}
113+
return nil
114+
}
115+
}

0 commit comments

Comments
 (0)