Skip to content

Commit

Permalink
[wpiutil] Fix DynamicStruct string handling (#6253)
Browse files Browse the repository at this point in the history
Dynamic structs had a few major issues.

In C++, if the string was the last definition in the schema, attempting to set a string would trigger an assertion. This has been fixed

Setting a string value could truncate the string actually stored in the struct, if the definition was shorter than the string to set.
There was no way to detect if this case occurred. The set string function now returns a bool if the string was fully written or not.

Reading a string that had a value shorter than the schema definition would result in embedded trailing nulls in the string. This would make comparing string equality basically impossible, as those embedded nulls count for the length of the string.

The above truncating didn't take into account UTF8 code points. This means a truncation could happen in the middle of a unicode character. Depending on the language this had different behavior, but unpaired code points are problematic to detect in any case. On the decoding side, detect if a split UTF8 code point has occurred by the writer, and if so just ignore it and treat it as not part of the string. Doing this on the receive side means a newer receive side is all that is needed to fix this, which is generally a better option then requiring all senders to update.

Actual DynamicStruct instances have 0 units tests for them. Added a bunch of unit tests around strings to ensure things work properly.
  • Loading branch information
ThadHouse committed Jan 20, 2024
1 parent 4b15c73 commit 0e5eb3f
Show file tree
Hide file tree
Showing 6 changed files with 413 additions and 13 deletions.
3 changes: 3 additions & 0 deletions wpiutil/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,7 @@ if(WITH_TESTS)

wpilib_add_test(wpiutil src/test/native/cpp)
target_link_libraries(wpiutil_test wpiutil gmock_main wpiutil_testlib)
if(MSVC)
target_compile_options(wpiutil_test PRIVATE /utf-8)
endif()
endif()
57 changes: 55 additions & 2 deletions wpiutil/src/main/java/edu/wpi/first/util/struct/DynamicStruct.java
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ public void setDoubleField(StructFieldDescriptor field, double value) {
* @throws IllegalArgumentException if field is not a member of this struct
* @throws IllegalStateException if struct descriptor is invalid
*/
@SuppressWarnings({"PMD.CollapsibleIfStatements", "PMD.AvoidDeeplyNestedIfStmts"})
public String getStringField(StructFieldDescriptor field) {
if (field.getType() != StructFieldType.kChar) {
throw new UnsupportedOperationException("field is not char type");
Expand All @@ -390,19 +391,69 @@ public String getStringField(StructFieldDescriptor field) {
}
byte[] bytes = new byte[field.m_arraySize];
m_data.position(field.m_offset).get(bytes, 0, field.m_arraySize);
return new String(bytes, StandardCharsets.UTF_8);
// Find last non zero character
int stringLength = bytes.length;
for (; stringLength > 0; stringLength--) {
if (bytes[stringLength - 1] != 0) {
break;
}
}
// If string is all zeroes, its empty and return an empty string.
if (stringLength == 0) {
return "";
}
// Check if the end of the string is in the middle of a continuation byte or
// not.
if ((bytes[stringLength - 1] & 0x80) != 0) {
// This is a UTF8 continuation byte. Make sure its valid.
// Walk back until initial byte is found
int utf8StartByte = stringLength;
for (; utf8StartByte > 0; utf8StartByte--) {
if ((bytes[utf8StartByte - 1] & 0x40) != 0) {
// Having 2nd bit set means start byte
break;
}
}
if (utf8StartByte == 0) {
// This case means string only contains continuation bytes
return "";
}
utf8StartByte--;
// Check if its a 2, 3, or 4 byte
byte checkByte = bytes[utf8StartByte];
if ((checkByte & 0xE0) == 0xC0) {
// 2 byte, need 1 more byte
if (utf8StartByte != stringLength - 2) {
stringLength = utf8StartByte;
}
} else if ((checkByte & 0xF0) == 0xE0) {
// 3 byte, need 2 more bytes
if (utf8StartByte != stringLength - 3) {
stringLength = utf8StartByte;
}
} else if ((checkByte & 0xF8) == 0xF0) {
// 4 byte, need 3 more bytes
if (utf8StartByte != stringLength - 4) {
stringLength = utf8StartByte;
}
}
// If we get here, the string is either completely garbage or fine.
}

return new String(bytes, 0, stringLength, StandardCharsets.UTF_8);
}

/**
* Sets the value of a character or character array field.
*
* @param field field descriptor
* @param value field value
* @return true if the full value fit in the struct, false if truncated
* @throws UnsupportedOperationException if field is not char type
* @throws IllegalArgumentException if field is not a member of this struct
* @throws IllegalStateException if struct descriptor is invalid
*/
public void setStringField(StructFieldDescriptor field, String value) {
public boolean setStringField(StructFieldDescriptor field, String value) {
if (field.getType() != StructFieldType.kChar) {
throw new UnsupportedOperationException("field is not char type");
}
Expand All @@ -414,10 +465,12 @@ public void setStringField(StructFieldDescriptor field, String value) {
}
ByteBuffer bb = StandardCharsets.UTF_8.encode(value);
int len = Math.min(bb.remaining(), field.m_arraySize);
boolean copiedFull = len == bb.remaining();
m_data.position(field.m_offset).put(bb.limit(len));
for (int i = len; i < field.m_arraySize; i++) {
m_data.put((byte) 0);
}
return copiedFull;
}

/**
Expand Down
65 changes: 62 additions & 3 deletions wpiutil/src/main/native/cpp/struct/DynamicStruct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,16 +349,75 @@ void MutableDynamicStruct::SetData(std::span<const uint8_t> data) {
std::copy(data.begin(), data.begin() + m_desc->GetSize(), m_data.begin());
}

void MutableDynamicStruct::SetStringField(const StructFieldDescriptor* field,
std::string_view DynamicStruct::GetStringField(
const StructFieldDescriptor* field) const {
assert(field->m_type == StructFieldType::kChar);
assert(field->m_parent == m_desc);
assert(m_desc->IsValid());
// Find last non zero character
size_t stringLength;
for (stringLength = field->m_arraySize; stringLength > 0; stringLength--) {
if (m_data[field->m_offset + stringLength - 1] != 0) {
break;
}
}
// If string is all zeroes, its empty and return an empty string.
if (stringLength == 0) {
return "";
}
// Check if the end of the string is in the middle of a continuation byte or
// not.
if ((m_data[field->m_offset + stringLength - 1] & 0x80) != 0) {
// This is a UTF8 continuation byte. Make sure its valid.
// Walk back until initial byte is found
size_t utf8StartByte = stringLength;
for (; utf8StartByte > 0; utf8StartByte--) {
if ((m_data[field->m_offset + utf8StartByte - 1] & 0x40) != 0) {
// Having 2nd bit set means start byte
break;
}
}
if (utf8StartByte == 0) {
// This case means string only contains continuation bytes
return "";
}
utf8StartByte--;
// Check if its a 2, 3, or 4 byte
uint8_t checkByte = m_data[field->m_offset + utf8StartByte];
if ((checkByte & 0xE0) == 0xC0) {
// 2 byte, need 1 more byte
if (utf8StartByte != stringLength - 2) {
stringLength = utf8StartByte;
}
} else if ((checkByte & 0xF0) == 0xE0) {
// 3 byte, need 2 more bytes
if (utf8StartByte != stringLength - 3) {
stringLength = utf8StartByte;
}
} else if ((checkByte & 0xF8) == 0xF0) {
// 4 byte, need 3 more bytes
if (utf8StartByte != stringLength - 4) {
stringLength = utf8StartByte;
}
}
// If we get here, the string is either completely garbage or fine.
}
return {reinterpret_cast<const char*>(&m_data[field->m_offset]),
stringLength};
}

bool MutableDynamicStruct::SetStringField(const StructFieldDescriptor* field,
std::string_view value) {
assert(field->m_type == StructFieldType::kChar);
assert(field->m_parent == m_desc);
assert(m_desc->IsValid());
size_t len = (std::min)(field->m_arraySize, value.size());
bool copiedFull = len == value.size();
std::copy(value.begin(), value.begin() + len,
reinterpret_cast<char*>(&m_data[field->m_offset]));
std::fill(&m_data[field->m_offset + len],
&m_data[field->m_offset + field->m_arraySize], 0);
auto toFill = m_data.subspan(field->m_offset + len, field->m_arraySize - len);
std::fill(toFill.begin(), toFill.end(), 0);
return copiedFull;
}

void MutableDynamicStruct::SetStructField(const StructFieldDescriptor* field,
Expand Down
11 changes: 3 additions & 8 deletions wpiutil/src/main/native/include/wpi/struct/DynamicStruct.h
Original file line number Diff line number Diff line change
Expand Up @@ -472,13 +472,7 @@ class DynamicStruct {
* @param field field descriptor
* @return field value
*/
std::string_view GetStringField(const StructFieldDescriptor* field) const {
assert(field->m_type == StructFieldType::kChar);
assert(field->m_parent == m_desc);
assert(m_desc->IsValid());
return {reinterpret_cast<const char*>(&m_data[field->m_offset]),
field->m_arraySize};
}
std::string_view GetStringField(const StructFieldDescriptor* field) const;

/**
* Gets the value of a struct field.
Expand Down Expand Up @@ -610,8 +604,9 @@ class MutableDynamicStruct : public DynamicStruct {
*
* @param field field descriptor
* @param value field value
* @return true if the full value fit in the struct, false if truncated
*/
void SetStringField(const StructFieldDescriptor* field,
bool SetStringField(const StructFieldDescriptor* field,
std::string_view value);

/**
Expand Down
117 changes: 117 additions & 0 deletions wpiutil/src/test/java/edu/wpi/first/util/struct/DynamicStructTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

@SuppressWarnings("AvoidEscapedUnicodeCharacters")
class DynamicStructTest {
@SuppressWarnings("MemberName")
private StructDescriptorDatabase db;
Expand Down Expand Up @@ -387,4 +388,120 @@ void testStandardArray(
assertNotNull(field.getStruct());
}
}

@Test
void testStringAllZeros() {
var desc = assertDoesNotThrow(() -> db.add("test", "char a[32]"));
var dynamic = DynamicStruct.allocate(desc);
var field = desc.findFieldByName("a");
assertEquals("", dynamic.getStringField(field));
}

@Test
void testStringRoundTrip() {
var desc = assertDoesNotThrow(() -> db.add("test", "char a[32]"));
var dynamic = DynamicStruct.allocate(desc);
var field = desc.findFieldByName("a");
assertTrue(dynamic.setStringField(field, "abc"));
assertEquals("abc", dynamic.getStringField(field));
}

@Test
void testStringRoundTripEmbeddedNull() {
var desc = assertDoesNotThrow(() -> db.add("test", "char a[32]"));
var dynamic = DynamicStruct.allocate(desc);
var field = desc.findFieldByName("a");
assertTrue(dynamic.setStringField(field, "ab\0c"));
assertEquals("ab\0c", dynamic.getStringField(field));
}

@Test
void testStringRoundTripStringTooLong() {
var desc = assertDoesNotThrow(() -> db.add("test", "char a[2]"));
var dynamic = DynamicStruct.allocate(desc);
var field = desc.findFieldByName("a");
assertFalse(dynamic.setStringField(field, "abc"));
assertEquals("ab", dynamic.getStringField(field));
}

@Test
void testStringRoundTripPartial2ByteUtf8() {
var desc = assertDoesNotThrow(() -> db.add("test", "char a[2]"));
var dynamic = DynamicStruct.allocate(desc);
var field = desc.findFieldByName("a");
assertFalse(dynamic.setStringField(field, "a\u0234"));
assertEquals("a", dynamic.getStringField(field));
}

@Test
void testStringRoundTrip2ByteUtf8() {
var desc = assertDoesNotThrow(() -> db.add("test", "char a[3]"));
var dynamic = DynamicStruct.allocate(desc);
var field = desc.findFieldByName("a");
assertTrue(dynamic.setStringField(field, "a\u0234"));
assertEquals("a\u0234", dynamic.getStringField(field));
}

@Test
void testStringRoundTripPartial3ByteUtf8FirstByte() {
var desc = assertDoesNotThrow(() -> db.add("test", "char a[2]"));
var dynamic = DynamicStruct.allocate(desc);
var field = desc.findFieldByName("a");
assertFalse(dynamic.setStringField(field, "a\u1234"));
assertEquals("a", dynamic.getStringField(field));
}

@Test
void testStringRoundTripPartial3ByteUtf8SecondByte() {
var desc = assertDoesNotThrow(() -> db.add("test", "char a[3]"));
var dynamic = DynamicStruct.allocate(desc);
var field = desc.findFieldByName("a");
assertFalse(dynamic.setStringField(field, "a\u1234"));
assertEquals("a", dynamic.getStringField(field));
}

@Test
void testStringRoundTrip3ByteUtf8() {
var desc = assertDoesNotThrow(() -> db.add("test", "char a[4]"));
var dynamic = DynamicStruct.allocate(desc);
var field = desc.findFieldByName("a");
assertTrue(dynamic.setStringField(field, "a\u1234"));
assertEquals("a\u1234", dynamic.getStringField(field));
}

@Test
void testStringRoundTripPartial4ByteUtf8FirstByte() {
var desc = assertDoesNotThrow(() -> db.add("test", "char a[2]"));
var dynamic = DynamicStruct.allocate(desc);
var field = desc.findFieldByName("a");
assertFalse(dynamic.setStringField(field, "a\uD83D\uDC00"));
assertEquals("a", dynamic.getStringField(field));
}

@Test
void testStringRoundTripPartial4ByteUtf8SecondByte() {
var desc = assertDoesNotThrow(() -> db.add("test", "char a[3]"));
var dynamic = DynamicStruct.allocate(desc);
var field = desc.findFieldByName("a");
assertFalse(dynamic.setStringField(field, "a\uD83D\uDC00"));
assertEquals("a", dynamic.getStringField(field));
}

@Test
void testStringRoundTripPartial4ByteUtf8ThirdByte() {
var desc = assertDoesNotThrow(() -> db.add("test", "char a[4]"));
var dynamic = DynamicStruct.allocate(desc);
var field = desc.findFieldByName("a");
assertFalse(dynamic.setStringField(field, "a\uD83D\uDC00"));
assertEquals("a", dynamic.getStringField(field));
}

@Test
void testStringRoundTrip4ByteUtf8() {
var desc = assertDoesNotThrow(() -> db.add("test", "char a[5]"));
var dynamic = DynamicStruct.allocate(desc);
var field = desc.findFieldByName("a");
assertTrue(dynamic.setStringField(field, "a\uD83D\uDC00"));
assertEquals("a\uD83D\uDC00", dynamic.getStringField(field));
}
}
Loading

0 comments on commit 0e5eb3f

Please sign in to comment.