This repository has been archived by the owner on Feb 26, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
column_list.go
193 lines (176 loc) · 4.72 KB
/
column_list.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
package sqlr
import (
"bytes"
"fmt"
"strings"
"github.com/jjeffery/sqlr/private/scanner"
)
// columnList represents a list of columns for use in an SQL clause.
//
// Each columnList represents a subset of the available columns.
// For example a column list for the WHERE clause in a row update
// statement will only contain the columns for the primary key.
//
// A columnList is handled as a value (methods use value receivers).
// Copying one copies the exclude map reference, not the map itself.
type columnList struct {
	// allColumns holds every candidate column; filtered() walks it in order,
	// so this is also the order in which columns are rendered.
	allColumns []*Column

	// filter decides which columns are included; nil includes every column.
	filter func(col *Column) bool

	// clause selects how String renders the columns (names, placeholders,
	// "name = placeholder" pairs) and which separator joins them.
	clause sqlClause

	// alias, when non-empty, is written as "alias." before each quoted
	// column name in the select, order-by, update and delete clauses.
	alias string

	// exclude holds column names (compared against Column.columnName) that
	// are omitted even when filter accepts them; nil means none excluded.
	exclude map[string]struct{}
}
// newColumns returns a column list over allColumns that renders as the
// column list of a SELECT, with no filter, alias or exclusions set.
func newColumns(allColumns []*Column) columnList {
	var list columnList
	list.allColumns = allColumns
	list.clause = clauseSelectColumns
	return list
}
// Parse parses the text inside the curly braces to obtain more information
// about how to render the column list. It returns a modified copy of cols;
// the receiver is left unchanged.
//
// The returned list renders for clause and starts with clause.defaultFilter()
// as its filter. The text is then scanned (ignoring whitespace) for these
// keywords:
//
//	"alias n"      => use alias "n" for each column in the list
//	"pk"           => primary key columns only
//	"all"          => all columns
//	"exclude a, b" => omit the columns named a and b
//
// Any other token is skipped without error. An error is returned if no
// further token can be scanned after "alias" or "exclude", or if the
// scanner reports an error.
func (cols columnList) Parse(clause sqlClause, text string) (columnList, error) {
	cols2 := cols
	cols2.clause = clause
	cols2.filter = clause.defaultFilter()

	// cols2 is a shallow copy of the receiver, so cols2.exclude initially
	// refers to the receiver's map (if any). Before the first insertion it
	// is replaced with a private copy so that Parse never mutates cols.
	ownExclude := false

	scan := scanner.New(strings.NewReader(text))
	scan.AddKeywords("alias", "all", "pk", "exclude")
	scan.IgnoreWhiteSpace = true
	if scan.Scan() {
		// needScan is true when the current token has been consumed and the
		// next one must be scanned. It stays false when the current token
		// still needs to be examined (e.g. the token ending an exclude list).
		needScan := false
		for {
			if needScan {
				needScan = false
				if !scan.Scan() {
					break
				}
			}
			tok, lit := scan.Token(), scan.Text()
			if tok == scanner.EOF {
				break
			}
			if tok != scanner.KEYWORD {
				// The parser is deliberately minimal: anything that is not a
				// recognised keyword is skipped without error.
				needScan = true
				continue
			}
			switch strings.ToLower(lit) {
			case "alias":
				if !scan.Scan() {
					return columnList{}, fmt.Errorf("missing ident after 'alias'")
				}
				cols2.alias = scan.Text()
				needScan = true
			case "all":
				cols2.filter = columnFilterAll
				needScan = true
			case "pk":
				cols2.filter = columnFilterPK
				needScan = true
			case "exclude":
				if !scan.Scan() {
					return columnList{}, fmt.Errorf("missing column after 'exclude'")
				}
				if !ownExclude {
					excl := make(map[string]struct{}, len(cols2.exclude)+1)
					for name := range cols2.exclude {
						excl[name] = struct{}{}
					}
					cols2.exclude = excl
					ownExclude = true
				}
				cols2.exclude[scan.Text()] = struct{}{}
				// Consume any further ", name" pairs. The loop stops on the
				// first token that is not a comma (or when scanning fails).
				// needScan stays false so the outer loop examines that token.
				for scan.Scan() && scan.Text() == "," && scan.Scan() {
					cols2.exclude[scan.Text()] = struct{}{}
				}
			default:
				// Guards against an infinite loop should a keyword be
				// registered above without a matching case here.
				needScan = true
			}
		}
	}
	if err := scan.Err(); err != nil {
		return columnList{}, err
	}
	return cols2, nil
}
// String renders the filtered columns as SQL text for the list's clause,
// using dialect to quote column names and to produce placeholders.
//
// Entries are joined with " and " in the WHERE clauses of select, update
// and delete statements, and with ", " everywhere else. counter is invoked
// once per placeholder written, in column order, and its result is handed
// to dialect.Placeholder.
func (cols columnList) String(dialect Dialect, counter func() int) string {
	selected := cols.filtered()
	if len(selected) == 0 {
		return ""
	}

	separator := ", "
	if cols.clause.matchAny(clauseUpdateWhere, clauseDeleteWhere, clauseSelectWhere) {
		separator = " and "
	}

	var out bytes.Buffer

	// writeName emits the (optionally alias-qualified) quoted column name.
	writeName := func(c *Column, qualify bool) {
		if qualify && cols.alias != "" {
			out.WriteString(cols.alias + ".")
		}
		out.WriteString(dialect.Quote(c.Name()))
	}

	for idx, c := range selected {
		if idx != 0 {
			out.WriteString(separator)
		}
		switch cols.clause {
		case clauseSelectColumns, clauseSelectOrderBy:
			writeName(c, true)
		case clauseInsertColumns:
			writeName(c, false)
		case clauseInsertValues:
			out.WriteString(dialect.Placeholder(counter()))
		case clauseUpdateSet, clauseUpdateWhere, clauseDeleteWhere, clauseSelectWhere:
			writeName(c, true)
			out.WriteString(" = ")
			out.WriteString(dialect.Placeholder(counter()))
		}
	}
	return out.String()
}
// filtered returns, in their original order, the columns that pass the
// filter (a nil filter passes everything) and are not named in exclude.
func (cols columnList) filtered() []*Column {
	keep := func(c *Column) bool {
		if cols.filter != nil && !cols.filter(c) {
			return false
		}
		_, excluded := cols.exclude[c.columnName]
		return !excluded
	}

	result := make([]*Column, 0, len(cols.allColumns))
	for i := range cols.allColumns {
		if c := cols.allColumns[i]; keep(c) {
			result = append(result, c)
		}
	}
	return result
}
// columnFilterAll is the filter that accepts every column.
func columnFilterAll(col *Column) bool {
	return true
}
// columnFilterPK accepts only the columns that are part of the primary key.
func columnFilterPK(col *Column) bool {
	isKey := col.PrimaryKey()
	return isKey
}
// columnFilterInsertable accepts every column except an auto-increment
// column, if the table has one.
func columnFilterInsertable(col *Column) bool {
	autoInc := col.AutoIncrement()
	return !autoInc
}
// columnFilterUpdateable is the filter for all columns that are neither part
// of the primary key nor auto-increment.
func columnFilterUpdateable(col *Column) bool {
	return !col.PrimaryKey() && !col.AutoIncrement()
}