From 24a411f9044536f8d4d8c64e17ddf2ee8bd031ed Mon Sep 17 00:00:00 2001 From: Georgijs <48869301+gvilums@users.noreply.github.com> Date: Mon, 15 Apr 2024 14:02:28 -0700 Subject: [PATCH] Correctly handle duplicate column names in sqlite joins (#10285) * add tests * working * cleanup * fix compile * fix naming and comment * fix lints in test * apply suggested fixes --- packages/bun-usockets/src/crypto/openssl.c | 2 - src/bun.js/bindings/sqlite/JSSQLStatement.cpp | 109 ++++++++++++++++-- test/js/bun/sqlite/sqlite.test.js | 81 ++++++++++++- 3 files changed, 176 insertions(+), 16 deletions(-) diff --git a/packages/bun-usockets/src/crypto/openssl.c b/packages/bun-usockets/src/crypto/openssl.c index e5514680b73f5..9f502ecb6044b 100644 --- a/packages/bun-usockets/src/crypto/openssl.c +++ b/packages/bun-usockets/src/crypto/openssl.c @@ -257,8 +257,6 @@ void us_internal_on_ssl_handshake( struct us_internal_ssl_socket_t * us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s, int code, void *reason) { - struct us_internal_ssl_socket_context_t *context = - (struct us_internal_ssl_socket_context_t *)us_socket_context(0, &s->s); if (s->handshake_state != HANDSHAKE_COMPLETED) { // if we have some pending handshake we cancel it and try to check the diff --git a/src/bun.js/bindings/sqlite/JSSQLStatement.cpp b/src/bun.js/bindings/sqlite/JSSQLStatement.cpp index 4fb9f9d950bac..6c8912fdc7c2c 100644 --- a/src/bun.js/bindings/sqlite/JSSQLStatement.cpp +++ b/src/bun.js/bindings/sqlite/JSSQLStatement.cpp @@ -34,6 +34,8 @@ #include #include "BunBuiltinNames.h" #include "sqlite3_error_codes.h" +#include "wtf/BitVector.h" +#include "wtf/Vector.h" #include /* ******************************************************************************** */ @@ -322,6 +324,9 @@ class JSSQLStatement : public JSC::JSDestructibleObject { VersionSqlite3* version_db; uint64_t version; bool hasExecuted = false; + // Tracks which columns are valid in the current result set. Used to handle duplicate column names. + // The bit at index i is set if the column at index i is valid. + WTF::BitVector validColumns; std::unique_ptr columnNames; mutable JSC::WriteBarrier _prototype; mutable JSC::WriteBarrier _structure; @@ -462,6 +467,7 @@ static void initializeColumnNames(JSC::JSGlobalObject* lexicalGlobalObject, JSSQ castedThis->columnNames->propertyNameMode(), castedThis->columnNames->privateSymbolMode())); } + castedThis->validColumns.clearAll(); castedThis->update_version(); JSC::VM& vm = lexicalGlobalObject->vm(); @@ -484,7 +490,7 @@ static void initializeColumnNames(JSC::JSGlobalObject* lexicalGlobalObject, JSSQ auto columnNames = castedThis->columnNames.get(); bool anyHoles = false; - for (int i = 0; i < count; i++) { + for (int i = count - 1; i >= 0; i--) { const char* name = sqlite3_column_name(stmt, i); if (name == nullptr) { @@ -498,7 +504,18 @@ static void initializeColumnNames(JSC::JSGlobalObject* lexicalGlobalObject, JSSQ break; } - columnNames->add(Identifier::fromString(vm, WTF::String::fromUTF8({ name, len }))); + // When joining multiple tables, the same column names can appear multiple times + // columnNames de-dupes property names internally + // We can't have two properties with the same name, so we use validColumns to track this. + auto preCount = columnNames->size(); + columnNames->add( + Identifier::fromString(vm, WTF::String::fromUTF8({name, len})) + ); + auto curCount = columnNames->size(); + + if (preCount != curCount) { + castedThis->validColumns.set(i); + } } if (LIKELY(!anyHoles)) { @@ -506,6 +523,10 @@ static void initializeColumnNames(JSC::JSGlobalObject* lexicalGlobalObject, JSSQ Structure* structure = globalObject.structureCache().emptyObjectStructureForPrototype(&globalObject, globalObject.objectPrototype(), columnNames->size()); vm.writeBarrier(castedThis, structure); + // We iterated over the columns in reverse order so we need to reverse the columnNames here + // Importantly we reverse before adding the properties to the structure to ensure that index accesses + // later refer to the correct property. + columnNames->data()->propertyNameVector().reverse(); for (const auto& propertyName : *columnNames) { structure = Structure::addPropertyTransition(vm, structure, propertyName, 0, offset); } @@ -520,6 +541,7 @@ static void initializeColumnNames(JSC::JSGlobalObject* lexicalGlobalObject, JSSQ castedThis->columnNames->vm(), castedThis->columnNames->propertyNameMode(), castedThis->columnNames->privateSymbolMode())); + castedThis->validColumns.clearAll(); } } @@ -531,7 +553,7 @@ static void initializeColumnNames(JSC::JSGlobalObject* lexicalGlobalObject, JSSQ // see https://github.com/oven-sh/bun/issues/987 JSC::JSObject* object = JSC::constructEmptyObject(lexicalGlobalObject, lexicalGlobalObject->objectPrototype(), std::min(static_cast(count), JSFinalObject::maxInlineCapacity)); - for (int i = 0; i < count; i++) { + for (int i = count - 1; i >= 0; i--) { const char* name = sqlite3_column_name(stmt, i); if (name == nullptr) @@ -562,9 +584,18 @@ static void initializeColumnNames(JSC::JSGlobalObject* lexicalGlobalObject, JSSQ } } - object->putDirect(vm, key, primitive, 0); + auto preCount = castedThis->columnNames->size(); castedThis->columnNames->add(key); + auto curCount = castedThis->columnNames->size(); + + // only put the property if it's not a duplicate + if (preCount != curCount) { + castedThis->validColumns.set(i); + object->putDirect(vm, key, primitive, 0); + } } + // We iterated over the columns in reverse order so we need to reverse the columnNames here + castedThis->columnNames->data()->propertyNameVector().reverse(); castedThis->_prototype.set(vm, castedThis, object); } @@ -1495,8 +1526,15 @@ static inline JSC::JSValue constructResultObject(JSC::JSGlobalObject* lexicalGlo if (auto* structure = castedThis->_structure.get()) { result = JSC::constructEmptyObject(vm, structure); - for (unsigned int i = 0; i < count; i++) { - result->putDirectOffset(vm, i, toJS(vm, lexicalGlobalObject, stmt, i)); + // i: the index of columns returned from SQLite + // j: the index of object property + for (int i = 0, j = 0; j < count; i++, j++) { + if (!castedThis->validColumns.get(i)) { + // this column is duplicate, skip + j -= 1; + continue; + } + result->putDirectOffset(vm, j, toJS(vm, lexicalGlobalObject, stmt, i)); } } else { @@ -1506,9 +1544,56 @@ static inline JSC::JSValue constructResultObject(JSC::JSGlobalObject* lexicalGlo result = JSC::JSFinalObject::create(vm, JSC::JSFinalObject::createStructure(vm, lexicalGlobalObject, lexicalGlobalObject->objectPrototype(), JSFinalObject::maxInlineCapacity)); } - for (int i = 0; i < count; i++) { - const auto& name = columnNames[i]; + for (int i = 0, j = 0; j < count; i++, j++) { + if (!castedThis->validColumns.get(i)) { + j -= 1; + continue; + } + auto name = columnNames[j]; result->putDirect(vm, name, toJS(vm, lexicalGlobalObject, stmt, i), 0); + + switch (sqlite3_column_type(stmt, i)) { + case SQLITE_INTEGER: { + // https://github.com/oven-sh/bun/issues/1536 + result->putDirect(vm, name, jsNumberFromSQLite(stmt, i), 0); + break; + } + case SQLITE_FLOAT: { + result->putDirect(vm, name, jsDoubleNumber(sqlite3_column_double(stmt, i)), 0); + break; + } + // > Note that the SQLITE_TEXT constant was also used in SQLite version + // > 2 for a completely different meaning. Software that links against + // > both SQLite version 2 and SQLite version 3 should use SQLITE3_TEXT, + // > not SQLITE_TEXT. + case SQLITE3_TEXT: { + size_t len = sqlite3_column_bytes(stmt, i); + const unsigned char* text = len > 0 ? sqlite3_column_text(stmt, i) : nullptr; + + if (len > 64) { + result->putDirect(vm, name, JSC::JSValue::decode(Bun__encoding__toStringUTF8(text, len, lexicalGlobalObject)), 0); + continue; + } + + result->putDirect(vm, name, jsString(vm, WTF::String::fromUTF8({ text, len })), 0); + break; + } + case SQLITE_BLOB: { + size_t len = sqlite3_column_bytes(stmt, i); + const void* blob = len > 0 ? sqlite3_column_blob(stmt, i) : nullptr; + JSC::JSUint8Array* array = JSC::JSUint8Array::createUninitialized(lexicalGlobalObject, lexicalGlobalObject->m_typedArrayUint8.get(lexicalGlobalObject), len); + + if (LIKELY(blob && len)) + memcpy(array->vector(), blob, len); + + result->putDirect(vm, name, array, 0); + break; + } + default: { + result->putDirect(vm, name, jsNull(), 0); + break; + } + } } } @@ -1524,10 +1609,14 @@ static inline JSC::JSArray* constructResultRow(JSC::JSGlobalObject* lexicalGloba JSC::JSArray* result = JSArray::create(vm, lexicalGlobalObject->arrayStructureForIndexingTypeDuringAllocation(ArrayWithContiguous), count); auto* stmt = castedThis->stmt; - for (int i = 0; i < count; i++) { + for (int i = 0, j = 0; j < count; i++, j++) { + if (!castedThis->validColumns.get(i)) { + j -= 1; + continue; + } JSValue value = toJS(vm, lexicalGlobalObject, stmt, i); RETURN_IF_EXCEPTION(throwScope, nullptr); - result->putDirectIndex(lexicalGlobalObject, i, value); + result->putDirectIndex(lexicalGlobalObject, j, value); } return result; diff --git a/test/js/bun/sqlite/sqlite.test.js b/test/js/bun/sqlite/sqlite.test.js index f75c2f03153c6..60a9d01b1cdee 100644 --- a/test/js/bun/sqlite/sqlite.test.js +++ b/test/js/bun/sqlite/sqlite.test.js @@ -778,6 +778,77 @@ it.skipIf( expect(db.prepare("SELECT TAN(0.25)").all()).toEqual([{ "TAN(0.25)": 0.25534192122103627 }]); }); +it("issue#6597", () => { + // better-sqlite3 returns the last value of duplicate fields + const db = new Database(":memory:"); + db.run("CREATE TABLE Users (Id INTEGER PRIMARY KEY, Name VARCHAR(255), CreatedAt TIMESTAMP)"); + db.run( + "CREATE TABLE Cars (Id INTEGER PRIMARY KEY, Driver INTEGER, CreatedAt TIMESTAMP, FOREIGN KEY (Driver) REFERENCES Users(Id))", + ); + db.run('INSERT INTO Users (Id, Name, CreatedAt) VALUES (1, "Alice", "2022-01-01");'); + db.run('INSERT INTO Cars (Id, Driver, CreatedAt) VALUES (2, 1, "2023-01-01");'); + const result = db.prepare("SELECT * FROM Cars JOIN Users ON Driver=Users.Id").get(); + expect(result).toStrictEqual({ + Id: 1, + Driver: 1, + CreatedAt: "2022-01-01", + Name: "Alice", + }); + db.close(); +}); + +it("issue#6597 with many columns", () => { + // better-sqlite3 returns the last value of duplicate fields + const db = new Database(":memory:"); + const count = 100; + const columns = Array.from({ length: count }, (_, i) => `col${i}`); + const values_foo = Array.from({ length: count }, (_, i) => `'foo${i}'`); + const values_bar = Array.from({ length: count }, (_, i) => `'bar${i}'`); + values_bar[0] = values_foo[0]; + db.run(`CREATE TABLE foo (${columns.join(",")})`); + db.run(`CREATE TABLE bar (${columns.join(",")})`); + db.run(`INSERT INTO foo (${columns.join(",")}) VALUES (${values_foo.join(",")})`); + db.run(`INSERT INTO bar (${columns.join(",")}) VALUES (${values_bar.join(",")})`); + const result = db.prepare("SELECT * FROM foo JOIN bar ON foo.col0 = bar.col0").get(); + expect(result.col0).toBe("foo0"); + for (let i = 1; i < count; i++) { + expect(result[`col${i}`]).toBe(`bar${i}`); + } + db.close(); +}); + +it("issue#7147", () => { + const db = new Database(":memory:"); + db.exec("CREATE TABLE foos (foo_id INTEGER NOT NULL PRIMARY KEY, foo_a TEXT, foo_b TEXT)"); + db.exec( + "CREATE TABLE bars (bar_id INTEGER NOT NULL PRIMARY KEY, foo_id INTEGER NOT NULL, bar_a INTEGER, bar_b INTEGER, FOREIGN KEY (foo_id) REFERENCES foos (foo_id))", + ); + db.exec("INSERT INTO foos VALUES (1, 'foo_1', 'foo_2')"); + db.exec("INSERT INTO bars VALUES (1, 1, 'bar_1', 'bar_2')"); + db.exec("INSERT INTO bars VALUES (2, 1, 'baz_3', 'baz_4')"); + const query = db.query("SELECT f.*, b.* FROM foos f JOIN bars b ON b.foo_id = f.foo_id"); + const result = query.all(); + expect(result).toStrictEqual([ + { + foo_id: 1, + foo_a: "foo_1", + foo_b: "foo_2", + bar_id: 1, + bar_a: "bar_1", + bar_b: "bar_2", + }, + { + foo_id: 1, + foo_a: "foo_1", + foo_b: "foo_2", + bar_id: 2, + bar_a: "baz_3", + bar_b: "baz_4", + }, + ]); + db.close(); +}); + it("should close with WAL enabled", () => { const dir = tempDirWithFiles("sqlite-wal-test", { "empty.txt": "" }); const file = path.join(dir, "my.db"); @@ -812,11 +883,12 @@ it("close() should NOT throw an error if the database is in use", () => { it("should dispose AND throw an error if the database is in use", () => { expect(() => { + let prepared; { using db = new Database(":memory:"); db.exec("CREATE TABLE foo (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT)"); db.exec("INSERT INTO foo (name) VALUES ('foo')"); - var prepared = db.prepare("SELECT * FROM foo"); + prepared = db.prepare("SELECT * FROM foo"); } }).toThrow("database is locked"); }); @@ -832,8 +904,8 @@ it("should dispose", () => { }); it("can continue to use existing statements after database has been GC'd", async () => { - var called = false; - var registry = new FinalizationRegistry(() => { + let called = false; + const registry = new FinalizationRegistry(() => { called = true; }); function leakTheStatement() { @@ -868,10 +940,11 @@ it("statements should be disposable", () => { it("query should work if the cached statement was finalized", () => { { + let prevQuery; using db = new Database("mydb.sqlite"); { using query = db.query("select 'Hello world' as message;"); - var prevQuery = query; + prevQuery = query; query.get(); } {