Skip to content

Commit

Permalink
Correctly handle duplicate column names in sqlite joins (#10285)
Browse files Browse the repository at this point in the history
* add tests

* working

* cleanup

* fix compile

* fix naming and comment

* fix lints in test

* apply suggested fixes
  • Loading branch information
gvilums committed Apr 15, 2024
1 parent 3f10d52 commit 24a411f
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 16 deletions.
2 changes: 0 additions & 2 deletions packages/bun-usockets/src/crypto/openssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,6 @@ void us_internal_on_ssl_handshake(
struct us_internal_ssl_socket_t *
us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s, int code,
void *reason) {
struct us_internal_ssl_socket_context_t *context =
(struct us_internal_ssl_socket_context_t *)us_socket_context(0, &s->s);

if (s->handshake_state != HANDSHAKE_COMPLETED) {
// if we have some pending handshake we cancel it and try to check the
Expand Down
109 changes: 99 additions & 10 deletions src/bun.js/bindings/sqlite/JSSQLStatement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
#include <JavaScriptCore/ObjectPrototype.h>
#include "BunBuiltinNames.h"
#include "sqlite3_error_codes.h"
#include "wtf/BitVector.h"
#include "wtf/Vector.h"
#include <atomic>

/* ******************************************************************************** */
Expand Down Expand Up @@ -322,6 +324,9 @@ class JSSQLStatement : public JSC::JSDestructibleObject {
VersionSqlite3* version_db;
uint64_t version;
bool hasExecuted = false;
// Tracks which columns are valid in the current result set. Used to handle duplicate column names.
// The bit at index i is set if the column at index i is valid.
WTF::BitVector validColumns;
std::unique_ptr<PropertyNameArray> columnNames;
mutable JSC::WriteBarrier<JSC::JSObject> _prototype;
mutable JSC::WriteBarrier<JSC::Structure> _structure;
Expand Down Expand Up @@ -462,6 +467,7 @@ static void initializeColumnNames(JSC::JSGlobalObject* lexicalGlobalObject, JSSQ
castedThis->columnNames->propertyNameMode(),
castedThis->columnNames->privateSymbolMode()));
}
castedThis->validColumns.clearAll();
castedThis->update_version();

JSC::VM& vm = lexicalGlobalObject->vm();
Expand All @@ -484,7 +490,7 @@ static void initializeColumnNames(JSC::JSGlobalObject* lexicalGlobalObject, JSSQ

auto columnNames = castedThis->columnNames.get();
bool anyHoles = false;
for (int i = 0; i < count; i++) {
for (int i = count - 1; i >= 0; i--) {
const char* name = sqlite3_column_name(stmt, i);

if (name == nullptr) {
Expand All @@ -498,14 +504,29 @@ static void initializeColumnNames(JSC::JSGlobalObject* lexicalGlobalObject, JSSQ
break;
}

columnNames->add(Identifier::fromString(vm, WTF::String::fromUTF8({ name, len })));
// When joining multiple tables, the same column name can appear more than once.
// columnNames de-dupes property names internally, and an object can't have two
// properties with the same name, so validColumns records which column indices were actually added.
auto preCount = columnNames->size();
columnNames->add(
Identifier::fromString(vm, WTF::String::fromUTF8({name, len}))
);
auto curCount = columnNames->size();

if (preCount != curCount) {
castedThis->validColumns.set(i);
}
}

if (LIKELY(!anyHoles)) {
PropertyOffset offset;
Structure* structure = globalObject.structureCache().emptyObjectStructureForPrototype(&globalObject, globalObject.objectPrototype(), columnNames->size());
vm.writeBarrier(castedThis, structure);

// We iterated over the columns in reverse order so we need to reverse the columnNames here
// Importantly we reverse before adding the properties to the structure to ensure that index accesses
// later refer to the correct property.
columnNames->data()->propertyNameVector().reverse();
for (const auto& propertyName : *columnNames) {
structure = Structure::addPropertyTransition(vm, structure, propertyName, 0, offset);
}
Expand All @@ -520,6 +541,7 @@ static void initializeColumnNames(JSC::JSGlobalObject* lexicalGlobalObject, JSSQ
castedThis->columnNames->vm(),
castedThis->columnNames->propertyNameMode(),
castedThis->columnNames->privateSymbolMode()));
castedThis->validColumns.clearAll();
}
}

Expand All @@ -531,7 +553,7 @@ static void initializeColumnNames(JSC::JSGlobalObject* lexicalGlobalObject, JSSQ
// see https://github.com/oven-sh/bun/issues/987
JSC::JSObject* object = JSC::constructEmptyObject(lexicalGlobalObject, lexicalGlobalObject->objectPrototype(), std::min(static_cast<unsigned>(count), JSFinalObject::maxInlineCapacity));

for (int i = 0; i < count; i++) {
for (int i = count - 1; i >= 0; i--) {
const char* name = sqlite3_column_name(stmt, i);

if (name == nullptr)
Expand Down Expand Up @@ -562,9 +584,18 @@ static void initializeColumnNames(JSC::JSGlobalObject* lexicalGlobalObject, JSSQ
}
}

object->putDirect(vm, key, primitive, 0);
auto preCount = castedThis->columnNames->size();
castedThis->columnNames->add(key);
auto curCount = castedThis->columnNames->size();

// only put the property if it's not a duplicate
if (preCount != curCount) {
castedThis->validColumns.set(i);
object->putDirect(vm, key, primitive, 0);
}
}
// We iterated over the columns in reverse order so we need to reverse the columnNames here
castedThis->columnNames->data()->propertyNameVector().reverse();
castedThis->_prototype.set(vm, castedThis, object);
}

Expand Down Expand Up @@ -1495,8 +1526,15 @@ static inline JSC::JSValue constructResultObject(JSC::JSGlobalObject* lexicalGlo
if (auto* structure = castedThis->_structure.get()) {
result = JSC::constructEmptyObject(vm, structure);

for (unsigned int i = 0; i < count; i++) {
result->putDirectOffset(vm, i, toJS(vm, lexicalGlobalObject, stmt, i));
// i: the index of columns returned from SQLite
// j: the index of object property
for (int i = 0, j = 0; j < count; i++, j++) {
if (!castedThis->validColumns.get(i)) {
// this column is duplicate, skip
j -= 1;
continue;
}
result->putDirectOffset(vm, j, toJS(vm, lexicalGlobalObject, stmt, i));
}

} else {
Expand All @@ -1506,9 +1544,56 @@ static inline JSC::JSValue constructResultObject(JSC::JSGlobalObject* lexicalGlo
result = JSC::JSFinalObject::create(vm, JSC::JSFinalObject::createStructure(vm, lexicalGlobalObject, lexicalGlobalObject->objectPrototype(), JSFinalObject::maxInlineCapacity));
}

for (int i = 0; i < count; i++) {
const auto& name = columnNames[i];
for (int i = 0, j = 0; j < count; i++, j++) {
if (!castedThis->validColumns.get(i)) {
j -= 1;
continue;
}
auto name = columnNames[j];
result->putDirect(vm, name, toJS(vm, lexicalGlobalObject, stmt, i), 0);

switch (sqlite3_column_type(stmt, i)) {
case SQLITE_INTEGER: {
// https://github.com/oven-sh/bun/issues/1536
result->putDirect(vm, name, jsNumberFromSQLite(stmt, i), 0);
break;
}
case SQLITE_FLOAT: {
result->putDirect(vm, name, jsDoubleNumber(sqlite3_column_double(stmt, i)), 0);
break;
}
// > Note that the SQLITE_TEXT constant was also used in SQLite version
// > 2 for a completely different meaning. Software that links against
// > both SQLite version 2 and SQLite version 3 should use SQLITE3_TEXT,
// > not SQLITE_TEXT.
case SQLITE3_TEXT: {
size_t len = sqlite3_column_bytes(stmt, i);
const unsigned char* text = len > 0 ? sqlite3_column_text(stmt, i) : nullptr;

if (len > 64) {
result->putDirect(vm, name, JSC::JSValue::decode(Bun__encoding__toStringUTF8(text, len, lexicalGlobalObject)), 0);
continue;
}

result->putDirect(vm, name, jsString(vm, WTF::String::fromUTF8({ text, len })), 0);
break;
}
case SQLITE_BLOB: {
size_t len = sqlite3_column_bytes(stmt, i);
const void* blob = len > 0 ? sqlite3_column_blob(stmt, i) : nullptr;
JSC::JSUint8Array* array = JSC::JSUint8Array::createUninitialized(lexicalGlobalObject, lexicalGlobalObject->m_typedArrayUint8.get(lexicalGlobalObject), len);

if (LIKELY(blob && len))
memcpy(array->vector(), blob, len);

result->putDirect(vm, name, array, 0);
break;
}
default: {
result->putDirect(vm, name, jsNull(), 0);
break;
}
}
}
}

Expand All @@ -1524,10 +1609,14 @@ static inline JSC::JSArray* constructResultRow(JSC::JSGlobalObject* lexicalGloba
JSC::JSArray* result = JSArray::create(vm, lexicalGlobalObject->arrayStructureForIndexingTypeDuringAllocation(ArrayWithContiguous), count);
auto* stmt = castedThis->stmt;

for (int i = 0; i < count; i++) {
for (int i = 0, j = 0; j < count; i++, j++) {
if (!castedThis->validColumns.get(i)) {
j -= 1;
continue;
}
JSValue value = toJS(vm, lexicalGlobalObject, stmt, i);
RETURN_IF_EXCEPTION(throwScope, nullptr);
result->putDirectIndex(lexicalGlobalObject, i, value);
result->putDirectIndex(lexicalGlobalObject, j, value);
}

return result;
Expand Down
81 changes: 77 additions & 4 deletions test/js/bun/sqlite/sqlite.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,77 @@ it.skipIf(
expect(db.prepare("SELECT TAN(0.25)").all()).toEqual([{ "TAN(0.25)": 0.25534192122103627 }]);
});

it("issue#6597", () => {
  // Regression test for duplicate column names in a join: Cars and Users both
  // have `Id` and `CreatedAt`. Like better-sqlite3, the result object must keep
  // the last value of each duplicated field.
  const db = new Database(":memory:");
  db.run("CREATE TABLE Users (Id INTEGER PRIMARY KEY, Name VARCHAR(255), CreatedAt TIMESTAMP)");
  db.run(
    "CREATE TABLE Cars (Id INTEGER PRIMARY KEY, Driver INTEGER, CreatedAt TIMESTAMP, FOREIGN KEY (Driver) REFERENCES Users(Id))",
  );
  // SQL string literals use single quotes. Double-quoted tokens are identifiers
  // in SQL and are only treated as strings by SQLite's legacy fallback, which
  // can be disabled at build time.
  db.run("INSERT INTO Users (Id, Name, CreatedAt) VALUES (1, 'Alice', '2022-01-01');");
  db.run("INSERT INTO Cars (Id, Driver, CreatedAt) VALUES (2, 1, '2023-01-01');");
  const result = db.prepare("SELECT * FROM Cars JOIN Users ON Driver=Users.Id").get();
  // `Id` and `CreatedAt` carry the Users values (the later columns of the row),
  // not the Cars values (2 / "2023-01-01").
  expect(result).toStrictEqual({
    Id: 1,
    Driver: 1,
    CreatedAt: "2022-01-01",
    Name: "Alice",
  });
  db.close();
});

it("issue#6597 with many columns", () => {
  // Two tables with 100 identically named columns are joined on col0.
  // better-sqlite3 returns the last value of duplicate fields, so every
  // duplicated column is expected to carry the value from `bar`.
  const columnCount = 100;
  const indices = [...Array(columnCount).keys()];
  const columnList = indices.map(i => `col${i}`).join(",");
  const fooValues = indices.map(i => `'foo${i}'`);
  // col0 must match in both tables so the join condition yields a row.
  const barValues = indices.map(i => (i === 0 ? fooValues[0] : `'bar${i}'`));

  const db = new Database(":memory:");
  db.run(`CREATE TABLE foo (${columnList})`);
  db.run(`CREATE TABLE bar (${columnList})`);
  db.run(`INSERT INTO foo (${columnList}) VALUES (${fooValues.join(",")})`);
  db.run(`INSERT INTO bar (${columnList}) VALUES (${barValues.join(",")})`);

  const row = db.prepare("SELECT * FROM foo JOIN bar ON foo.col0 = bar.col0").get();
  expect(row.col0).toBe("foo0");
  indices.slice(1).forEach(i => {
    expect(row[`col${i}`]).toBe(`bar${i}`);
  });
  db.close();
});

it("issue#7147", () => {
  // Regression test: `f.*, b.*` yields `foo_id` twice (once per table). Each row
  // object must contain the column once, with all other columns intact.
  const db = new Database(":memory:");
  db.exec("CREATE TABLE foos (foo_id INTEGER NOT NULL PRIMARY KEY, foo_a TEXT, foo_b TEXT)");
  db.exec(
    "CREATE TABLE bars (bar_id INTEGER NOT NULL PRIMARY KEY, foo_id INTEGER NOT NULL, bar_a INTEGER, bar_b INTEGER, FOREIGN KEY (foo_id) REFERENCES foos (foo_id))",
  );
  db.exec("INSERT INTO foos VALUES (1, 'foo_1', 'foo_2')");
  db.exec("INSERT INTO bars VALUES (1, 1, 'bar_1', 'bar_2')");
  db.exec("INSERT INTO bars VALUES (2, 1, 'baz_3', 'baz_4')");
  // ORDER BY pins the row order: without it SQL leaves the order of a
  // multi-row result undefined, and the assertion below depends on it.
  const query = db.query("SELECT f.*, b.* FROM foos f JOIN bars b ON b.foo_id = f.foo_id ORDER BY b.bar_id");
  const result = query.all();
  expect(result).toStrictEqual([
    {
      foo_id: 1,
      foo_a: "foo_1",
      foo_b: "foo_2",
      bar_id: 1,
      bar_a: "bar_1",
      bar_b: "bar_2",
    },
    {
      foo_id: 1,
      foo_a: "foo_1",
      foo_b: "foo_2",
      bar_id: 2,
      bar_a: "baz_3",
      bar_b: "baz_4",
    },
  ]);
  db.close();
});

it("should close with WAL enabled", () => {
const dir = tempDirWithFiles("sqlite-wal-test", { "empty.txt": "" });
const file = path.join(dir, "my.db");
Expand Down Expand Up @@ -812,11 +883,12 @@ it("close() should NOT throw an error if the database is in use", () => {

it("should dispose AND throw an error if the database is in use", () => {
expect(() => {
let prepared;
{
using db = new Database(":memory:");
db.exec("CREATE TABLE foo (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT)");
db.exec("INSERT INTO foo (name) VALUES ('foo')");
var prepared = db.prepare("SELECT * FROM foo");
prepared = db.prepare("SELECT * FROM foo");
}
}).toThrow("database is locked");
});
Expand All @@ -832,8 +904,8 @@ it("should dispose", () => {
});

it("can continue to use existing statements after database has been GC'd", async () => {
var called = false;
var registry = new FinalizationRegistry(() => {
let called = false;
const registry = new FinalizationRegistry(() => {
called = true;
});
function leakTheStatement() {
Expand Down Expand Up @@ -868,10 +940,11 @@ it("statements should be disposable", () => {

it("query should work if the cached statement was finalized", () => {
{
let prevQuery;
using db = new Database("mydb.sqlite");
{
using query = db.query("select 'Hello world' as message;");
var prevQuery = query;
prevQuery = query;
query.get();
}
{
Expand Down

0 comments on commit 24a411f

Please sign in to comment.