Skip to content

Commit

Permalink
Fix #151.
Browse files Browse the repository at this point in the history
  • Loading branch information
ncruces committed Sep 16, 2024
1 parent 9638976 commit 06eaf41
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
28 changes: 27 additions & 1 deletion driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ type rows struct {
*stmt
names []string
types []string
nulls []bool
}

func (r *rows) Close() error {
Expand All @@ -596,6 +597,22 @@ func (r *rows) Columns() []string {
return r.names
}

func (r *rows) loadTypes() {
if r.nulls == nil {
count := r.Stmt.ColumnCount()
r.nulls = make([]bool, count)
r.types = make([]string, count)
for i := range r.nulls {
if col := r.Stmt.ColumnOriginName(i); col != "" {
r.types[i], _, r.nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata(
r.Stmt.ColumnDatabaseName(i),
r.Stmt.ColumnTableName(i),
col)
}
}
}
}

func (r *rows) declType(index int) string {
if r.types == nil {
count := r.Stmt.ColumnCount()
Expand All @@ -608,7 +625,8 @@ func (r *rows) declType(index int) string {
}

func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
decltype := r.declType(index)
r.loadTypes()
decltype := r.types[index]
if len := len(decltype); len > 0 && decltype[len-1] == ')' {
if i := strings.LastIndexByte(decltype, '('); i >= 0 {
decltype = decltype[:i]
Expand All @@ -617,6 +635,14 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
return strings.TrimSpace(decltype)
}

func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
r.loadTypes()
if r.nulls[index] {
return false, true
}
return true, false
}

func (r *rows) Next(dest []driver.Value) error {
old := r.Stmt.Conn().SetInterrupt(r.ctx)
defer r.Stmt.Conn().SetInterrupt(old)
Expand Down
14 changes: 10 additions & 4 deletions tests/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestDriver(t *testing.T) {
defer conn.Close()

res, err := conn.ExecContext(ctx,
`CREATE TABLE users (id INT, name VARCHAR(10))`)
`CREATE TABLE users (id INTEGER PRIMARY KEY NOT NULL, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -82,11 +82,17 @@ func TestDriver(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if got := typs[0].DatabaseTypeName(); got != "INT" {
t.Errorf("got %s, want INT", got)
if got := typs[0].DatabaseTypeName(); got != "INTEGER" {
t.Errorf("got %s, want INTEGER", got)
}
if got := typs[1].DatabaseTypeName(); got != "VARCHAR" {
t.Errorf("got %s, want INT", got)
t.Errorf("got %s, want VARCHAR", got)
}
if got, ok := typs[0].Nullable(); got || !ok {
t.Errorf("got %v/%v, want false/true", got, ok)
}
if got, ok := typs[1].Nullable(); !got || ok {
t.Errorf("got %v/%v, want true/false", got, ok)
}

row := 0
Expand Down

0 comments on commit 06eaf41

Please sign in to comment.