Skip to content

Commit

Permalink
feat: optionally use recordclass if available
Browse files Browse the repository at this point in the history
  • Loading branch information
cheahjs committed Feb 12, 2024
1 parent 8d84fea commit cc1d752
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 80 deletions.
16 changes: 14 additions & 2 deletions .github/workflows/premerge.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12', 'pypy3.9', 'pypy3.10']
os: [ubuntu-latest, windows-latest]
force-stdlib: [true, false]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
Expand All @@ -27,7 +28,18 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dev dependencies
run: |
pip install -r requirements-dev.txt
- name: Run tests
pip install '.[tests]'
- name: Install performance dependencies
if: matrix.force-stdlib == 'false'
run: |
pip install '.[performance]'
- name: Run tests (with performance dependencies)
if: matrix.force-stdlib == 'false'
run: |
python -m unittest -v
- name: Run tests (stdlib only)
if: matrix.force-stdlib == 'true'
env:
FORCE_STDLIB_ONLY: 1
run: |
python -m unittest -v
232 changes: 168 additions & 64 deletions palworld_save_tools/archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,81 +2,185 @@
import math
import os
import struct
import sys
import uuid
from typing import Any, Callable, Optional, Sequence, Union
from recordclass import dataobject, as_dataclass

# Alias stdlib types to avoid name conflicts
_float = float
_bytes = bytes

@as_dataclass(hashable=True, fast_new=True)
class UUID:
raw_bytes: bytes
"""Wrapper around uuid.UUID to delay evaluation of UUIDs until necessary"""

@staticmethod
def from_str(s: str) -> "UUID":
b = uuid.UUID(s).bytes
return UUID(
bytes(
[
b[0x3],
b[0x2],
b[0x1],
b[0x0],
b[0x7],
b[0x6],
b[0x5],
b[0x4],
b[0xB],
b[0xA],
b[0x9],
b[0x8],
b[0xF],
b[0xE],
b[0xD],
b[0xC],
]
try:
from recordclass import as_dataclass
except ImportError:
pass

if os.getenv("FORCE_STDLIB_ONLY") or "recordclass" not in sys.modules:
if os.getenv("DEBUG"):
print("Using stdlib-compatible UUID class")

class UUID:
"""Wrapper around uuid.UUID to delay evaluation of UUIDs until necessary"""

__slots__ = ("raw_bytes", "parsed_uuid", "parsed_str")
raw_bytes: bytes
parsed_uuid: Optional[uuid.UUID]
parsed_str: Optional[str]

def __init__(self, raw_bytes: bytes) -> None:
self.raw_bytes = raw_bytes
self.parsed_uuid = None
self.parsed_str = None

@staticmethod
def from_str(s: str) -> "UUID":
b = uuid.UUID(s).bytes
return UUID(
bytes(
[
b[0x3],
b[0x2],
b[0x1],
b[0x0],
b[0x7],
b[0x6],
b[0x5],
b[0x4],
b[0xB],
b[0xA],
b[0x9],
b[0x8],
b[0xF],
b[0xE],
b[0xD],
b[0xC],
]
)
)
)

def __str__(self) -> str:
b = self.raw_bytes
return "%08x-%04x-%04x-%04x-%04x%08x" % ((b[3] << 24) | (b[2] << 16) | (b[1] << 8) | (b[0]),
(b[7] << 8) | (b[6]), (b[5] << 8) | (b[4]), (b[0xB] << 8) | (b[0xA]),
(b[9] << 8) | (b[8]),
(b[0xF] << 24) | (b[0xE] << 16) | (b[0xD] << 8) | (b[0xC]))

def UUID(self) -> uuid.UUID:
b = self.raw_bytes
uuid_int = (
b[0xC]
+ (b[0xD] << 8)
+ (b[0xE] << 16)
+ (b[0xF] << 24)
+ (b[0x8] << 32)
+ (b[0x9] << 40)
+ (b[0xA] << 48)
+ (b[0xB] << 56)
+ (b[0x4] << 64)
+ (b[0x5] << 72)
+ (b[0x6] << 80)
+ (b[0x7] << 88)
+ (b[0x0] << 96)
+ (b[0x1] << 104)
+ (b[0x2] << 112)
+ (b[0x3] << 120)
)
return uuid.UUID(int=uuid_int)
def __str__(self) -> str:
if not self.parsed_str:
b = self.raw_bytes
self.parsed_str = "%08x-%04x-%04x-%04x-%04x%08x" % (
(b[3] << 24) | (b[2] << 16) | (b[1] << 8) | (b[0]),
(b[7] << 8) | (b[6]),
(b[5] << 8) | (b[4]),
(b[0xB] << 8) | (b[0xA]),
(b[9] << 8) | (b[8]),
(b[0xF] << 24) | (b[0xE] << 16) | (b[0xD] << 8) | (b[0xC]),
)
return self.parsed_str

def UUID(self) -> uuid.UUID:
if not self.parsed_uuid:
b = self.raw_bytes
uuid_int = (
b[0xC]
+ (b[0xD] << 8)
+ (b[0xE] << 16)
+ (b[0xF] << 24)
+ (b[0x8] << 32)
+ (b[0x9] << 40)
+ (b[0xA] << 48)
+ (b[0xB] << 56)
+ (b[0x4] << 64)
+ (b[0x5] << 72)
+ (b[0x6] << 80)
+ (b[0x7] << 88)
+ (b[0x0] << 96)
+ (b[0x1] << 104)
+ (b[0x2] << 112)
+ (b[0x3] << 120)
)
self.parsed_uuid = uuid.UUID(int=uuid_int)
return self.parsed_uuid

def __eq__(self, __value: object) -> bool:
if isinstance(__value, UUID):
return self.raw_bytes == __value.raw_bytes
return str(self) == str(__value)

def __repr__(self) -> str:
return "%s.UUID('%s')" % (self.__module__, str(self))

def __hash__(self) -> int:
return hash(str(self))

else:
if os.getenv("DEBUG"):
print("Using recordclass-based UUID class")

@as_dataclass(hashable=True, fast_new=True)
class UUID: # type: ignore[no-redef]
raw_bytes: bytes
"""Wrapper around uuid.UUID to delay evaluation of UUIDs until necessary"""

@staticmethod
def from_str(s: str) -> "UUID":
b = uuid.UUID(s).bytes
return UUID(
bytes(
[
b[0x3],
b[0x2],
b[0x1],
b[0x0],
b[0x7],
b[0x6],
b[0x5],
b[0x4],
b[0xB],
b[0xA],
b[0x9],
b[0x8],
b[0xF],
b[0xE],
b[0xD],
b[0xC],
]
)
)

def __str__(self) -> str:
b = self.raw_bytes
return "%08x-%04x-%04x-%04x-%04x%08x" % (
(b[3] << 24) | (b[2] << 16) | (b[1] << 8) | (b[0]),
(b[7] << 8) | (b[6]),
(b[5] << 8) | (b[4]),
(b[0xB] << 8) | (b[0xA]),
(b[9] << 8) | (b[8]),
(b[0xF] << 24) | (b[0xE] << 16) | (b[0xD] << 8) | (b[0xC]),
)

def UUID(self) -> uuid.UUID:
b = self.raw_bytes
uuid_int = (
b[0xC]
+ (b[0xD] << 8)
+ (b[0xE] << 16)
+ (b[0xF] << 24)
+ (b[0x8] << 32)
+ (b[0x9] << 40)
+ (b[0xA] << 48)
+ (b[0xB] << 56)
+ (b[0x4] << 64)
+ (b[0x5] << 72)
+ (b[0x6] << 80)
+ (b[0x7] << 88)
+ (b[0x0] << 96)
+ (b[0x1] << 104)
+ (b[0x2] << 112)
+ (b[0x3] << 120)
)
return uuid.UUID(int=uuid_int)

def __eq__(self, __value: object) -> bool:
if isinstance(__value, UUID):
return self.raw_bytes == __value.raw_bytes
return str(self) == str(__value)
def __eq__(self, __value: object) -> bool:
if isinstance(__value, UUID):
return self.raw_bytes == __value.raw_bytes
return str(self) == str(__value)

def __repr__(self) -> str:
return "%s.UUID('%s')" % (self.__module__, str(self))
def __repr__(self) -> str:
return "%s.UUID('%s')" % (self.__module__, str(self))


# Specify a type for JSON-serializable objects
Expand Down
27 changes: 18 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,34 @@ build-backend = "hatchling.build"

[project]
name = "palworld-save-tools"
authors = [
{ name="Jun Siang Cheah", email="me@jscheah.me" },
]
authors = [{ name = "Jun Siang Cheah", email = "me@jscheah.me" }]
description = "Tools for converting Palworld .sav files to JSON and back"
readme = "README.md"
requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dynamic = ["version"]
dependencies = [
"recordclass"
]

[project.urls]
Homepage = "https://github.com/cheahjs/palworld-save-tools"
Issues = "https://github.com/cheahjs/palworld-save-tools/issues"

[project.scripts]
palworld-save-tools = "palworld_save_tools.commands.convert:main"

[project.optional-dependencies]
# These are dependencies only for tests
# Default usage of the library must not rely on any external dependencies!
tests = [
"parameterized==0.9.0",
"mypy==1.8.0"
]
# Additional dependencies to provide more performant implementations
performance = ["recordclass"]

[[tool.mypy.overrides]]
module = ["recordclass", "parameterized"]
ignore_missing_imports = true
4 changes: 0 additions & 4 deletions requirements-dev.txt

This file was deleted.

33 changes: 32 additions & 1 deletion tests/test_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_packed_vector_roundtrip(self, x, y, z):
self.assertEqual(y, y_e)
self.assertEqual(z, z_e)

def test_uuid_wrapper(self):
def test_uuid_wrapper_matches_stdlib(self):
test_uuid = "c1b41f12-90d3-491f-be71-b34e8e0deb5a"
expected = uuid.UUID(test_uuid)
b = expected.bytes
Expand All @@ -62,3 +62,34 @@ def test_uuid_wrapper(self):
)
wrapper = UUID(ue_bytes)
self.assertEqual(str(expected), str(wrapper))

def test_uuid_wrapper_can_be_used_as_dict_key(self):
test_uuid = "c1b41f12-90d3-491f-be71-b34e8e0deb5a"
wrapper = UUID.from_str(test_uuid)
d = {wrapper: "test"}
self.assertEqual("test", d[wrapper])

def test_uuid_wrapper_can_be_used_as_set_member(self):
test_uuid = "c1b41f12-90d3-491f-be71-b34e8e0deb5a"
wrapper = UUID.from_str(test_uuid)
s = {wrapper}
self.assertEqual(1, len(s))
self.assertTrue(wrapper in s)

def test_uuid_wrapper_equality(self):
test_uuid = "c1b41f12-90d3-491f-be71-b34e8e0deb5a"
wrapper = UUID.from_str(test_uuid)
wrapper2 = UUID.from_str(test_uuid)
self.assertEqual(wrapper, wrapper2)

def test_uuid_wrapper_inequality(self):
test_uuid = "c1b41f12-90d3-491f-be71-b34e8e0deb5a"
wrapper = UUID.from_str(test_uuid)
wrapper2 = UUID.from_str("c1b41f12-90d3-491f-be71-b34e8e0deb5b")
self.assertNotEqual(wrapper, wrapper2)

def test_uuid_wrapper_hash(self):
test_uuid = "c1b41f12-90d3-491f-be71-b34e8e0deb5a"
wrapper = UUID.from_str(test_uuid)
wrapper2 = UUID.from_str(test_uuid)
self.assertEqual(hash(wrapper), hash(wrapper2))

0 comments on commit cc1d752

Please sign in to comment.