diff --git a/fitparse/base.py b/fitparse/base.py index 0e6abbe..f5a6234 100644 --- a/fitparse/base.py +++ b/fitparse/base.py @@ -1,14 +1,6 @@ -import io import os import struct -# Python 2 compat -try: - num_types = (int, float, long) - str = basestring -except NameError: - num_types = (int, float) - from fitparse.processors import FitFileDataProcessor from fitparse.profile import FIELD_TYPE_TIMESTAMP, MESSAGE_TYPES from fitparse.records import ( @@ -94,6 +86,7 @@ def _parse_file_header(self): # Initialize data self._accumulators = {} + self.data_size = 0 self._bytes_left = -1 self._complete = False self._compressed_ts_accumulator = 0 @@ -106,7 +99,7 @@ def _parse_file_header(self): raise FitHeaderError("Invalid .FIT File Header") # Larger fields are explicitly little endian from SDK - header_size, protocol_ver_enc, profile_ver_enc, data_size = self._read_struct('2BHI4x', data=header_data) + header_size, protocol_ver_enc, profile_ver_enc, self.data_size = self._read_struct('2BHI4x', data=header_data) # Decode the same way the SDK does self.protocol_version = float("%d.%d" % (protocol_ver_enc >> 4, protocol_ver_enc & ((1 << 4) - 1))) @@ -127,7 +120,7 @@ def _parse_file_header(self): self._read(extra_header_size - 2) # After we've consumed the header, set the bytes left to be read - self._bytes_left = data_size + self._bytes_left = self.data_size def _parse_message(self): # When done, calculate the CRC and return None @@ -239,7 +232,7 @@ def _parse_definition_message(self, header): def _parse_raw_values_from_data_message(self, def_mesg): # Go through mesg's field defs and read them raw_values = [] - for field_def in def_mesg.field_defs + def_mesg.dev_field_defs: + for field_def in def_mesg.all_field_defs(): base_type = field_def.base_type is_byte = base_type.name == 'byte' # Struct to read n base types (field def size / base type size) @@ -277,18 +270,6 @@ def _resolve_subfield(field, def_mesg, raw_values): return sub_field, field return field, None - def 
_apply_scale_offset(self, field, raw_value): - # Apply numeric transformations (scale+offset) - if isinstance(raw_value, tuple): - # Contains multiple values, apply transformations to all of them - return tuple(self._apply_scale_offset(field, x) for x in raw_value) - elif isinstance(raw_value, num_types): - if field.scale: - raw_value = float(raw_value) / field.scale - if field.offset: - raw_value = raw_value - field.offset - return raw_value - @staticmethod def _apply_compressed_accumulation(raw_value, accumulation, num_bits): max_value = (1 << num_bits) @@ -311,7 +292,7 @@ def _parse_data_message(self, header): # TODO: Maybe refactor this and make it simpler (or at least broken # up into sub-functions) - for field_def, raw_value in zip(def_mesg.field_defs + def_mesg.dev_field_defs, raw_values): + for field_def, raw_value in zip(def_mesg.all_field_defs(), raw_values): field, parent_field = field_def.field, None if field: field, parent_field = self._resolve_subfield(field, def_mesg, raw_values) @@ -332,7 +313,7 @@ def _parse_data_message(self, header): # Apply scale and offset from component, not from the dynamic field # as they may differ - cmp_raw_value = self._apply_scale_offset(component, cmp_raw_value) + cmp_raw_value = component.apply_scale_offset(cmp_raw_value) # Extract the component's dynamic field from def_mesg cmp_field = def_mesg.mesg_type.fields[component.def_num] @@ -354,7 +335,8 @@ def _parse_data_message(self, header): # TODO: Do we care about a base_type and a resolved field mismatch? 
# My hunch is we don't - value = self._apply_scale_offset(field, field.render(raw_value)) + value = field.render(raw_value) + value = field.apply_scale_offset(value) else: value = raw_value diff --git a/fitparse/encoder.py b/fitparse/encoder.py new file mode 100644 index 0000000..90cd981 --- /dev/null +++ b/fitparse/encoder.py @@ -0,0 +1,436 @@ +import os +import re +import struct + +import six + +from fitparse import FitFileDataProcessor, profile, utils +from .records import Crc, DataMessage, DefinitionMessage, MessageHeader, FieldData, FieldDefinition +from .utils import fileish_open, FitParseError + + +class FitFileEncoder(object): + def __init__(self, fileish, + protocol_version=1.0, profile_version=20.33, + data_processor=None): + """ + Create FIT encoder. + + :param fileish: file-ish object, + :param protocol_version: protocol version, change to 2.0 if you use developer fields. + :param profile_version: profile version. + :param data_processor: custom data processor. + """ + self.protocol_version = float(protocol_version) + self.profile_version = float(profile_version) + + self._processor = data_processor or FitFileDataProcessor() + self._file = fileish_open(fileish, 'wb') + self._byte_start = 0 + self._bytes_written = 0 + self._compressed_ts_accumulator = 0 + self._local_mesgs = {} + self.data_size = 0 + self.completed = False + self._crc = Crc() + + self._write_file_header_place() + + def __del__(self): + self.close() + + def __enter__(self): + return self + + def __exit__(self, *_): + self.close() + + def close(self): + self.finish() + if hasattr(self, "_file") and self._file and hasattr(self._file, "close"): + self._file.close() + self._file = None + + ########## + # Private low-level utility methods for writing of fit file + + def _write(self, data): + if not data: + return + self._file.write(data) + self._bytes_written += len(data) + self._crc.update(data) + + def _write_struct(self, data, fmt, endian='<'): + fmt_with_endian = endian + fmt + size = 
struct.calcsize(fmt_with_endian) + if size <= 0: + raise FitParseError('Invalid struct format: {}'.format(fmt_with_endian)) + # handle non iterable and iterable data + if utils.is_iterable(data): + packed = struct.pack(fmt_with_endian, *data) + else: + packed = struct.pack(fmt_with_endian, data) + self._write(packed) + return packed + + @staticmethod + def _check_number_bits(n, bits, errmsg): + if n & ~ bits != 0: + raise FitParseError('{}: too large: {}'.format(errmsg, n)) + + ########## + # Private data unparsing methods + @staticmethod + def _is_ts_field(field): + return field and field.def_num == profile.FIELD_TYPE_TIMESTAMP.def_num + + def _write_file_header_place(self): + """Write zeroes instead of header.""" + self._byte_start = self._file.tell() + self._write(b'\0' * 14) + self._bytes_written = 0 + + def _write_file_header(self): + # encode versions + protocol_major, protocol_minor = re.match(r'([\d]+)\.(\d+)', str(self.protocol_version)).groups() + protocol_ver_enc = (int(protocol_major) << 4) | int(protocol_minor) + profile_ver_enc = int(round(self.profile_version * 100)) + self.data_size = self._bytes_written + + self._file.seek(self._byte_start, os.SEEK_SET) + data = self._write_struct((14, protocol_ver_enc, profile_ver_enc, self.data_size, b'.FIT'), '2BHI4s') + crc = Crc(byte_arr=data) + self._write_struct(crc.value, Crc.FMT) + self._file.seek(0, os.SEEK_END) + + def _write_message_header(self, header): + data = 0 + if header.time_offset is not None: # compressed timestamp + self._check_number_bits(header.local_mesg_num, 0x3, 'Message header local_mesg_num') + self._check_number_bits(header.time_offset, 0x1f, 'Message header time_offset') + data = 0x80 # bit 7 + data |= header.local_mesg_num << 5 # bits 5-6 + data |= header.time_offset # bits 0-4 + else: + self._check_number_bits(header.local_mesg_num, 0xf, 'Message header local_mesg_num') + if header.is_definition: + data |= 0x40 # bit 6 + if header.is_developer_data: + data |= 0x20 # bit 5 + data |= 
header.local_mesg_num # bits 0 - 3 + self._write_struct(data, 'B') + + def _write_definition_message(self, def_mesg): + if not self._local_mesgs and def_mesg.name != 'file_id': + raise FitParseError('First message must be file_id') + + self._write_message_header(def_mesg.header) + # reserved and architecture bytes + endian = def_mesg.endian + data = int(endian == '>') + self._write_struct(data, 'xB') + # rest of header with endian awareness + data = (def_mesg.mesg_num, len(def_mesg.field_defs)) + self._write_struct(data, 'HB', endian=endian) + for field_def in def_mesg.field_defs: + data = (field_def.def_num, field_def.size, field_def.base_type.identifier) + self._write_struct(data, '3B', endian=endian) + if def_mesg.header.is_developer_data: + data = len(def_mesg.dev_field_defs) + self._write_struct(data, 'B', endian=endian) + for field_def in def_mesg.dev_field_defs: + data = (field_def.def_num, field_def.size, field_def.dev_data_index) + self._write_struct(data, '3B', endian=endian) + self._local_mesgs[def_mesg.header.local_mesg_num] = def_mesg + + @staticmethod + def _unapply_compressed_accumulation(raw_value, accumulation, num_bits, errmsg): + max_value = (1 << num_bits) - 1 + max_mask = max_value - 1 + + diff = raw_value - accumulation + if diff < 0 or diff > max_value: + raise FitParseError('{}: too large: {}'.format(errmsg, raw_value)) + + return raw_value & max_mask + + def _prepare_compressed_ts(self, mesg): + """Apply timestamp to header.""" + field_datas = [f for f in mesg.fields if self._is_ts_field(f)] + if len(field_datas) > 1: + raise FitParseError('Too many timestamp fields. 
Do not mix raw timestamp and header timestamp.') + if len(field_datas) <= 0: + return + field_data = field_datas[0] + raw_value = field_data.raw_value + if raw_value is None: + return + if not field_data.field_def: + # header timestamp + mesg.header.time_offset = self._unapply_compressed_accumulation(raw_value, + self._compressed_ts_accumulator, + 5, + 'Message header time_offset') + # raw and header timestamp field + self._compressed_ts_accumulator = raw_value + + def _write_raw_values_from_data_message(self, mesg): + field_datas = mesg.fields + def_mesg = mesg.def_mesg + for field_def in def_mesg.all_field_defs(): + base_type = field_def.base_type + is_byte = base_type.name == 'byte' + field_data = next((f for f in field_datas if f.field_def == field_def), None) + raw_value = field_data.raw_value if field_data else None + + # If the field returns with a tuple of values it's definitely an + # oddball, but we'll write it on a per-value basis. + # If it's a byte type, treat the tuple as a single value + if not is_byte and isinstance(raw_value, tuple): + raw_value = tuple(base_type.in_range(base_type.unparse(rv)) for rv in raw_value) + else: + # Otherwise, just scrub the singular value + raw_value = base_type.in_range(base_type.unparse(raw_value)) + size = field_def.size + if not size: + raise FitParseError('FieldDefinition has no size: {}'.format(field_def.name)) + + # Struct to write n base types (field def size / base type size) + struct_fmt = '%d%s' % ( + size / base_type.size, + base_type.fmt, + ) + try: + self._write_struct(raw_value, struct_fmt, endian=def_mesg.endian) + except struct.error as ex: + six.raise_from(FitParseError('struct.error: Wrong value or fmt for: {}, fmt: {}, value: {}'.format(field_def.name, struct_fmt, raw_value)), ex) + + def _write_data_message(self, mesg): + """Compute raw_value and size.""" + self._processor.unparse_message(mesg) + for field_data in mesg.fields: + # clear possible mess from DataMessageCreator reuse +
field_data.raw_value = None + # Apply processor + self._processor.unparse_type(field_data) + self._processor.unparse_field(field_data) + self._processor.unparse_unit(field_data) + + field = field_data.field + # Sometimes raw_value is set by processor, otherwise take value. + # It's a design flaw of the library data structures. + raw_value = field_data.raw_value + if raw_value is None: + raw_value = field_data.value + + if field: + raw_value = field.unrender(raw_value) + raw_value = field.unapply_scale_offset(raw_value) + field_data.raw_value = raw_value + + self._prepare_compressed_ts(mesg) + self._write_message_header(mesg.header) + self._write_raw_values_from_data_message(mesg) + + def finish(self): + """Write header and CRC.""" + if self.completed: + return + crc = self._crc.value + self._write_file_header() + self._write_struct(crc, Crc.FMT) + self.completed = True + + def write(self, mesg): + """ + Write message. + + :param mesg: message to write + :type mesg: Union[DefinitionMessage,DataMessage] + """ + if isinstance(mesg, DataMessageCreator): + mesg = mesg.mesg + elif mesg.type == 'definition': + self._write_definition_message(mesg) + return + + def_mesg = mesg.def_mesg + if not def_mesg: + raise ValueError('mesg does not have def_mesg') + old_def_mesg = self._local_mesgs.get(mesg.header.local_mesg_num) + if old_def_mesg != def_mesg: + self._write_definition_message(def_mesg) + self._write_data_message(mesg) + + +class DataMessageCreator(object): + + def __init__(self, type_name, local_mesg_num=0, endian='<'): + """ + DataMessage creator to simplify message creation for the Encoder. + Use freeze() if you want to reuse the DefinitionMessage and set values again.
+ + :param Union[str,int] type_name: message type name or number, see profile.MESSAGE_TYPES + :param int local_mesg_num: local message number + :param str endian: character '<' or '>' + """ + self.endian = endian + self.frozen = False + self.def_mesg = self._create_definition_message(type_name, local_mesg_num=local_mesg_num) + self.mesg = self._create_data_message(self.def_mesg) + + def set_value(self, name, value, size=None): + """ + Set value of given field. + + :param str name: field name + :param value: field value + :param int or None size: size of value, None for autoguess + :rtype None: + """ + field_data = self._get_or_create_field_data(name) + field_data.value = value + base_type = field_data.base_type + if size is None: + if base_type.name == 'byte': + size = len(value) if value is not None else 1 + elif base_type.name == 'string': + size = len(value) + 1 # 0x00 in the end + elif utils.is_iterable(value): + size = len(value) + else: + size = 1 + size *= base_type.size + field_def = field_data.field_def + if not self.frozen: + field_def.size = size + else: + if field_def.size != size: + raise ValueError('Frozen: cannot change field size: {}'.format(name)) + + def set_values(self, values): + """Set values. + :param Iterable[str, Any] values: iterable values in tuples (name, value). Better to use iterables with predictable order of items.""" + if values is None: + return + for name, value in values: + self.set_value(name, value) + + def set_header_timestamp(self, value): + """Set value for the compressed header timestamp (time_offset). + + :param Union[datetime.datetime,int] value: date time or number of sec (see FIT doc) + """ + field_data = self.mesg.get(profile.FIELD_TYPE_TIMESTAMP.name) + if field_data and field_data.field_def: + raise ValueError('Raw timestamp already set. 
Do not mix raw timestamp and header timestamp.') + if not field_data: + field_data = FieldData( + field_def=None, + field=profile.FIELD_TYPE_TIMESTAMP, + parent_field=None, + units='s' + ) + self.mesg.fields.append(field_data) + field_data.value = value + + def freeze(self): + """Freeze fields, so as the DefinitionMessage cannot change.""" + self.frozen = True + + def _create_definition_message(self, type_name, local_mesg_num=0): + """Create skeleton of new definition message. + :param Union[str,int] type_name: message type name or number, see profile.MESSAGE_TYPES + :param local_mesg_num: local message number. + :rtype DefinitionMessage + """ + if not type_name: + raise ValueError('no type_name') + mesg_type = profile.MESSAGE_TYPES.get(type_name) + if not mesg_type: + mesg_type = next((m for m in profile.MESSAGE_TYPES.values() if m.name == type_name), None) + if not mesg_type: + raise FitParseError('Message type not found: {}'.format(type_name)) + header = MessageHeader( + is_definition=True, + is_developer_data=False, + local_mesg_num=local_mesg_num + ) + + return DefinitionMessage( + header=header, + endian=self.endian, + mesg_type=mesg_type, + mesg_num=mesg_type.mesg_num, + field_defs=[], + dev_field_defs=[] + ) + + def _create_data_message(self, def_msg): + """ + Create empty data message. 
+ + :rtype DataMessage""" + if not def_msg: + raise ValueError('No def_msg.') + + header = MessageHeader( + is_definition=False, + is_developer_data=def_msg.header.is_developer_data, + local_mesg_num=def_msg.header.local_mesg_num + ) + msg = DataMessage( + header=header, + def_mesg=def_msg, + fields=[] + ) + return msg + + def _get_or_create_field_data(self, name): + """ + + :param str name: field name + :rtype FieldData: + """ + field_data = self.mesg.get(name) + if field_data: + return field_data + if self.frozen: + raise ValueError('Frozen: cannot create FieldData: {}'.format(name)) + field_def, subfield = self._get_or_create_field_definition(name) + field = field_def.field + parent_field = None + if subfield: + parent_field = field + field = subfield + + field_data = FieldData( + field_def=field_def, + field=field, + parent_field=parent_field, + units=None + ) + self.mesg.fields.append(field_data) + return field_data + + def _get_or_create_field_definition(self, name): + """ + + :param str name: + :rtype tuple(FieldDefinition, SubField): + """ + field_def = self.def_mesg.get_field_def(name) + if field_def: + return field_def + field, subfield = self.def_mesg.mesg_type.get_field_and_subfield(name) + if not field: + raise ValueError( + 'No field: {} in the message: {} (#{})'.format(name, self.def_mesg.name, self.def_mesg.mesg_num)) + field_def = FieldDefinition( + field=field, + def_num=field.def_num, + base_type=field.base_type + ) + self.def_mesg.field_defs.append(field_def) + return (field_def, subfield) diff --git a/fitparse/processors.py index 810233f..53ab07a 100644 --- a/fitparse/processors.py +++ b/fitparse/processors.py @@ -1,11 +1,36 @@ import datetime -from fitparse.utils import scrub_method_name -# Datetimes (uint32) represent seconds since this UTC_REFERENCE -UTC_REFERENCE = 631065600 # timestamp for UTC 00:00 Dec 31 1989 +from fitparse.utils import scrub_method_name, fit_from_datetime, fit_to_datetime, fit_semicircles_to_deg
-class FitFileDataProcessor(object): +class DataProcessorBase(object): + """Empty, no-op fit file data processor.""" + def run_type_processor(self, field_data): + pass + + def unparse_type(self, field_data): + pass + + def run_field_processor(self, field_data): + pass + + def unparse_field(self, field_data): + pass + + def run_unit_processor(self, field_data): + pass + + def unparse_unit(self, field_data): + pass + + def run_message_processor(self, data_message): + pass + + def unparse_message(self, data_message): + pass + + +class FitFileDataProcessor(DataProcessorBase): # TODO: Document API # Functions that will be called to do the processing: #def run_type_processor(field_data) @@ -44,19 +69,36 @@ def run_type_processor(self, field_data): self._run_processor(self._scrub_method_name( 'process_type_%s' % field_data.type.name), field_data) + def unparse_type(self, field_data): + self._run_processor(self._scrub_method_name( + 'unparse_type_%s' % field_data.type.name), field_data) + def run_field_processor(self, field_data): self._run_processor(self._scrub_method_name( 'process_field_%s' % field_data.name), field_data) + def unparse_field(self, field_data): + self._run_processor(self._scrub_method_name( + 'unparse_field_%s' % field_data.name), field_data) + def run_unit_processor(self, field_data): if field_data.units: self._run_processor(self._scrub_method_name( 'process_units_%s' % field_data.units), field_data) + def unparse_unit(self, field_data): + if field_data.units: + self._run_processor(self._scrub_method_name( + 'unparse_units_%s' % field_data.units), field_data) + def run_message_processor(self, data_message): self._run_processor(self._scrub_method_name( 'process_message_%s' % data_message.def_mesg.name), data_message) + def unparse_message(self, data_message): + self._run_processor(self._scrub_method_name( + 'unparse_message_%s' % data_message.def_mesg.name), data_message) + def _run_processor(self, processor_name, data): try: getattr(self, 
processor_name)(data) @@ -67,27 +109,48 @@ def process_type_bool(self, field_data): if field_data.value is not None: field_data.value = bool(field_data.value) + def unparse_type_bool(self, field_data): + if field_data.value is not None: + field_data.raw_value = int(field_data.value) + def process_type_date_time(self, field_data): value = field_data.value if value is not None and value >= 0x10000000: - field_data.value = datetime.datetime.utcfromtimestamp(UTC_REFERENCE + value) + field_data.value = fit_to_datetime(value) field_data.units = None # Units were 's', set to None + def unparse_type_date_time(self, field_data): + value = field_data.value + if value is not None and isinstance(value, datetime.datetime): + field_data.raw_value = fit_from_datetime(value) + field_data.units = 's' + def process_type_local_date_time(self, field_data): - if field_data.value is not None: + value = field_data.value + if value is not None: # NOTE: This value was created on the device using it's local timezone. # Unless we know that timezone, this value won't be correct. However, if we # assume UTC, at least it'll be consistent. 
- field_data.value = datetime.datetime.utcfromtimestamp(UTC_REFERENCE + field_data.value) + field_data.value = fit_to_datetime(value) field_data.units = None + def unparse_type_local_date_time(self, field_data): + self.unparse_type_date_time(field_data) + def process_type_localtime_into_day(self, field_data): - if field_data.value is not None: - m, s = divmod(field_data.value, 60) + value = field_data.value + if value is not None: + m, s = divmod(value, 60) h, m = divmod(m, 60) field_data.value = datetime.time(h, m, s) field_data.units = None + def unparse_type_localtime_into_day(self, field_data): + value = field_data.value + if value is not None and isinstance(value, datetime.time): + field_data.raw_value = value.hour * 3600 + value.minute * 60 + value.second + field_data.units = 's' + class StandardUnitsDataProcessor(FitFileDataProcessor): def run_field_processor(self, field_data): @@ -112,5 +175,5 @@ def process_field_speed(self, field_data): def process_units_semicircles(self, field_data): if field_data.value is not None: - field_data.value *= 180.0 / (2 ** 31) + field_data.value = fit_semicircles_to_deg(field_data.value) field_data.units = 'deg' diff --git a/fitparse/records.py b/fitparse/records.py index 71edc93..fa6cba3 100644 --- a/fitparse/records.py +++ b/fitparse/records.py @@ -1,12 +1,20 @@ +import itertools import math import struct -# Python 2 compat try: + # Python 2 int_types = (int, long,) + num_types = (int, float, long) + int_type = long + math_nan = float('nan') byte_iter = bytearray except NameError: + # Python 3 int_types = (int,) + num_types = (int, float) + int_type = int + math_nan = math.nan byte_iter = lambda x: x try: @@ -55,7 +63,7 @@ def name(self): return self.mesg_type.name if self.mesg_type else 'unknown_%d' % self.mesg_num def __repr__(self): - return '' % ( + return '' % ( self.name, self.mesg_num, self.header.local_mesg_num, @@ -63,6 +71,17 @@ def __repr__(self): ', '.join([fd.name for fd in self.dev_field_defs]), ) + def 
all_field_defs(self): + if not self.dev_field_defs: + return self.field_defs + return itertools.chain(self.field_defs, self.dev_field_defs) + + def get_field_def(self, name): + for field_def in self.all_field_defs(): + if field_def.is_named(name): + return field_def + return None + class FieldDefinition(RecordBase): __slots__ = ('field', 'def_num', 'base_type', 'size') @@ -76,13 +95,17 @@ def type(self): return self.field.type if self.field else self.base_type def __repr__(self): - return '' % ( + return '' % ( self.name, self.def_num, self.type.name, self.base_type.name, self.size, 's' if self.size != 1 else '', ) + def is_named(self, name): + return self.field.is_named(name) + + class DevFieldDefinition(RecordBase): __slots__ = ('field', 'dev_data_index', 'base_type', 'def_num', 'size') @@ -156,7 +179,7 @@ def __iter__(self): return iter(sorted(self.fields, key=lambda fd: (int(fd.field is None), fd.name))) def __repr__(self): - return '' % ( + return '' % ( self.name, self.mesg_num, self.header.local_mesg_num, ', '.join(["%s: %s" % (fd.name, fd.value) for fd in self.fields]), ) @@ -237,13 +260,25 @@ def __str__(self): ) -class BaseType(RecordBase): - __slots__ = ('name', 'identifier', 'fmt', 'parse') +class BaseType(object): + __slots__ = ('name', 'identifier', 'fmt', 'invalid_value', 'parse', 'unparse', 'in_range', '_size') values = None # In case we're treated as a FieldType + def __init__(self, name, identifier, fmt, invalid_value=None, parse=None, unparse=None, in_range=None): + self.name = name + self.identifier = identifier + self.fmt = fmt + self.invalid_value = invalid_value + self.parse = parse or self._parse + self.unparse = unparse or self._unparse + self.in_range = in_range or self._in_range + self._size = None + @property def size(self): - return struct.calcsize(self.fmt) + if self._size is None: + self._size = struct.calcsize(self.fmt) + return self._size @property def type_num(self): @@ -254,6 +289,17 @@ def __repr__(self): self.name, 
self.type_num, self.identifier, ) + def _parse(self, x): + return None if x == self.invalid_value else x + + def _unparse(self, x): + return self.invalid_value if x is None else x + + def _in_range(self, x): + # basic implementation for int types + return self.invalid_value if x.bit_length() > self.size * 8 else x + + class FieldType(RecordBase): __slots__ = ('name', 'base_type', 'values') @@ -268,8 +314,51 @@ class MessageType(RecordBase): def __repr__(self): return '' % (self.name, self.mesg_num) + def get_field_and_subfield(self, name): + """ + Get field by name. + :rtype tuple(Field, SubField) or tuple(Field, None) or (None, None) + """ + for field in self.fields.values(): + if field.is_named(name): + return (field, None) + if field.subfields: + subfield = next((f for f in field.subfields if f.is_named(name)), None) + if subfield: + return (field, subfield) + + return (None, None) + + +class ScaleOffsetMixin(object): + """Common methods for classes with scale and offset.""" + + def apply_scale_offset(self, raw_value): + if isinstance(raw_value, tuple): + # Contains multiple values, apply transformations to all of them + return tuple(self.apply_scale_offset(x) for x in raw_value) + elif isinstance(raw_value, num_types): + if self.scale: + raw_value = float(raw_value) / self.scale + if self.offset: + raw_value = raw_value - self.offset + return raw_value -class FieldAndSubFieldBase(RecordBase): + def unapply_scale_offset(self, value): + if isinstance(value, tuple): + # Contains multiple values, apply transformations to all of them + return tuple(self.unapply_scale_offset(x) for x in value) + elif isinstance(value, num_types): + if self.offset: + value = value + self.offset + if self.scale: + value = float(value) * self.scale + if isinstance(value, float): + value = int_type(round(value)) + return value + + +class FieldAndSubFieldBase(RecordBase, ScaleOffsetMixin): __slots__ = () @property @@ -280,9 +369,26 @@ def base_type(self): def is_base_type(self): return 
isinstance(self.type, BaseType) + def __repr__(self): + return '<%s: %s (#%s) -- type: %s (%s)>' % ( + self.__class__.__name__, + self.name, + self.def_num, + self.type.name, + self.base_type + ) + + def is_named(self, name): + return self.name == name or self.def_num == name + def render(self, raw_value): - if self.type.values and (raw_value in self.type.values): - return self.type.values[raw_value] + if self.type.values: + return self.type.values.get(raw_value, raw_value) + return raw_value + + def unrender(self, raw_value): + if self.type.values: + return next((k for k, v in self.type.values.items() if v == raw_value), raw_value) return raw_value @@ -307,7 +413,7 @@ class ReferenceField(RecordBase): __slots__ = ('name', 'def_num', 'value', 'raw_value') -class ComponentField(RecordBase): +class ComponentField(RecordBase, ScaleOffsetMixin): __slots__ = ('name', 'def_num', 'scale', 'offset', 'units', 'accumulate', 'bits', 'bit_offset') field_type = 'component' @@ -382,28 +488,48 @@ def calculate(cls, byte_arr, crc=0): def parse_string(string): try: end = string.index(0x00) - except TypeError: # Python 2 compat + except TypeError: # Python 2 compat end = string.index('\x00') return string[:end].decode('utf-8', errors='replace') or None + +def unparse_string(string): + if string is None: + string = '' + sbytes = string.encode('utf-8', errors='replace') + b'\0' + return sbytes + + +_FLOAT32_INVALID_VALUE = struct.unpack('f', bytes(b'\xff' * 4))[0] +_FLOAT32_MIN = -3.4028235e+38 +_FLOAT32_MAX = 3.4028235e+38 +_FLOAT64_INVALID_VALUE = struct.unpack('d', bytes(b'\xff' * 8))[0] + # The default base type -BASE_TYPE_BYTE = BaseType(name='byte', identifier=0x0D, fmt='B', parse=lambda x: None if all(b == 0xFF for b in x) else x) +BASE_TYPE_BYTE = BaseType(name='byte', identifier=0x0D, fmt='B', + parse=lambda x: None if all(b == 0xFF for b in x) else x, + unparse=lambda x: b'\xFF' if x is None else x, + in_range=lambda x: x) BASE_TYPES = { - 0x00: BaseType(name='enum', 
identifier=0x00, fmt='B', parse=lambda x: None if x == 0xFF else x), - 0x01: BaseType(name='sint8', identifier=0x01, fmt='b', parse=lambda x: None if x == 0x7F else x), - 0x02: BaseType(name='uint8', identifier=0x02, fmt='B', parse=lambda x: None if x == 0xFF else x), - 0x83: BaseType(name='sint16', identifier=0x83, fmt='h', parse=lambda x: None if x == 0x7FFF else x), - 0x84: BaseType(name='uint16', identifier=0x84, fmt='H', parse=lambda x: None if x == 0xFFFF else x), - 0x85: BaseType(name='sint32', identifier=0x85, fmt='i', parse=lambda x: None if x == 0x7FFFFFFF else x), - 0x86: BaseType(name='uint32', identifier=0x86, fmt='I', parse=lambda x: None if x == 0xFFFFFFFF else x), - 0x07: BaseType(name='string', identifier=0x07, fmt='s', parse=parse_string), - 0x88: BaseType(name='float32', identifier=0x88, fmt='f', parse=lambda x: None if math.isnan(x) else x), - 0x89: BaseType(name='float64', identifier=0x89, fmt='d', parse=lambda x: None if math.isnan(x) else x), - 0x0A: BaseType(name='uint8z', identifier=0x0A, fmt='B', parse=lambda x: None if x == 0x0 else x), - 0x8B: BaseType(name='uint16z', identifier=0x8B, fmt='H', parse=lambda x: None if x == 0x0 else x), - 0x8C: BaseType(name='uint32z', identifier=0x8C, fmt='I', parse=lambda x: None if x == 0x0 else x), + 0x00: BaseType(name='enum', identifier=0x00, fmt='B', invalid_value=0xFF), + 0x01: BaseType(name='sint8', identifier=0x01, fmt='b', invalid_value=0x7F), + 0x02: BaseType(name='uint8', identifier=0x02, fmt='B', invalid_value=0xFF), + 0x83: BaseType(name='sint16', identifier=0x83, fmt='h', invalid_value=0x7FFF), + 0x84: BaseType(name='uint16', identifier=0x84, fmt='H', invalid_value=0xFFFF), + 0x85: BaseType(name='sint32', identifier=0x85, fmt='i', invalid_value=0x7FFFFFFF), + 0x86: BaseType(name='uint32', identifier=0x86, fmt='I', invalid_value=0xFFFFFFFF), + 0x07: BaseType(name='string', identifier=0x07, fmt='s', parse=parse_string, unparse=unparse_string, in_range=lambda x: x), + 0x88: 
BaseType(name='float32', identifier=0x88, fmt='f', invalid_value=_FLOAT32_INVALID_VALUE, + parse=lambda x: None if math.isnan(x) else x, + in_range=lambda x: x if _FLOAT32_MIN < x < _FLOAT32_MAX else _FLOAT32_INVALID_VALUE), + 0x89: BaseType(name='float64', identifier=0x89, fmt='d', invalid_value=_FLOAT64_INVALID_VALUE, + parse=lambda x: None if math.isnan(x) else x, + in_range=lambda x: x), + 0x0A: BaseType(name='uint8z', identifier=0x0A, fmt='B', invalid_value=0x0), + 0x8B: BaseType(name='uint16z', identifier=0x8B, fmt='H', invalid_value=0x0), + 0x8C: BaseType(name='uint32z', identifier=0x8C, fmt='I', invalid_value=0x0), 0x0D: BASE_TYPE_BYTE, } diff --git a/fitparse/utils.py b/fitparse/utils.py index 9f4a367..f6796e4 100644 --- a/fitparse/utils.py +++ b/fitparse/utils.py @@ -1,6 +1,7 @@ -import re - +import datetime import io +import re +from collections import Iterable class FitParseError(ValueError): @@ -16,6 +17,7 @@ class FitHeaderError(FitParseError): pass +UTC_REFERENCE = datetime.datetime(1989, 12, 31) # timestamp for UTC 00:00 Dec 31 1989 METHOD_NAME_SCRUBBER = re.compile(r'\W|^(?=\d)') UNIT_NAME_TO_FUNC_REPLACEMENTS = ( ('/', ' per '), @@ -23,6 +25,27 @@ class FitHeaderError(FitParseError): ('*', ' times '), ) + +def fit_to_datetime(sec): + """Convert FIT seconds to datetime.""" + return UTC_REFERENCE + datetime.timedelta(seconds=sec) + + +def fit_from_datetime(dt): + """Convert datetime to FIT seconds.""" + return int((dt - UTC_REFERENCE).total_seconds()) + + +def fit_semicircles_to_deg(sc): + """Convert FIT semicircles to deg (for the GPS lat, long).""" + return sc * 180.0 / (2 ** 31) + + +def fit_deg_to_semicircles(deg): + """Convert deg to FIT semicircles (for the GPS lat, long).""" + return int(deg / 180.0 * (2 ** 31)) + + def scrub_method_name(method_name, convert_units=False): if convert_units: for replace_from, replace_to in UNIT_NAME_TO_FUNC_REPLACEMENTS: @@ -56,3 +79,10 @@ def fileish_open(fileish, mode): else: # Python 3 - file contents return 
io.BytesIO(fileish) + + +def is_iterable(obj): + """Check, if the obj is iterable but not string or bytes. + :rtype bool""" + # Speed: do not use iter() although it's more robust, see also https://stackoverflow.com/questions/1952464/ + return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)) diff --git a/tests/test.py b/tests/test.py index 44f5ea2..8225ff6 100755 --- a/tests/test.py +++ b/tests/test.py @@ -7,7 +7,7 @@ import sys from fitparse import FitFile -from fitparse.processors import UTC_REFERENCE, StandardUnitsDataProcessor +from fitparse.processors import fit_to_datetime, StandardUnitsDataProcessor from fitparse.records import BASE_TYPES, Crc from fitparse.utils import FitEOFError, FitCRCError, FitHeaderError @@ -68,10 +68,6 @@ def generate_fitfile(data=None, endian='<'): return file_data + pack('<' + Crc.FMT, Crc.calculate(file_data)) -def secs_to_dt(secs): - return datetime.datetime.utcfromtimestamp(secs + UTC_REFERENCE) - - def testfile(filename): return os.path.join(os.path.dirname(os.path.realpath(__file__)), 'files', filename) @@ -99,7 +95,7 @@ def test_basic_file_with_one_record(self, endian='<'): for field in ('serial_number', 3): self.assertEqual(file_id.get_value(field), 558069241) for field in ('time_created', 4): - self.assertEqual(file_id.get_value(field), secs_to_dt(723842606)) + self.assertEqual(file_id.get_value(field), fit_to_datetime(723842606)) self.assertEqual(file_id.get(field).raw_value, 723842606) for field in ('number', 5): self.assertEqual(file_id.get_value(field), None) diff --git a/tests/test_encoder.py b/tests/test_encoder.py new file mode 100644 index 0000000..713ff1f --- /dev/null +++ b/tests/test_encoder.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python +import copy +import datetime +import io +import os +import sys + +from fitparse import FitFile +from fitparse.encoder import FitFileEncoder, DataMessageCreator + +if sys.version_info >= (2, 7): + import unittest +else: + import unittest2 as unittest + + +def 
testfile(filename): + return os.path.join(os.path.dirname(os.path.realpath(__file__)), 'files', filename) + + +class FitFileEncoderTestCase(unittest.TestCase): + + def test_header(self): + file = io.BytesIO() + with FitFileEncoder(file) as fwrite: + fwrite.finish() + buff = file.getvalue() + pass + self.assertTrue(fwrite.completed) + self.assertEqual(16, len(buff)) + + with FitFile(buff) as fread: + self.assertEqual(0, len(fread.messages)) + self.assertEqual(fwrite.protocol_version, fread.protocol_version) + self.assertEqual(fwrite.profile_version, fread.profile_version) + self.assertEqual(fwrite.data_size, fread.data_size) + + def test_basic_activity_create(self): + file = io.BytesIO() + # copy of written messages + messages = [] + time_created = datetime.datetime(2017, 12, 13, 14, 15, 16) + with FitFileEncoder(file) as fwrite: + def write(mesg): + fwrite.write(mesg) + messages.append(copy.deepcopy(mesg.mesg)) + + mesg = DataMessageCreator('file_id') + mesg.set_values(( + ('serial_number', 123456), + ('manufacturer', 'dynastream'), + ('garmin_product', 'hrm1'), # test subfield + ('type', 'activity'), + ('time_created', time_created) + )) + write(mesg) + + mesg = DataMessageCreator('device_info') + mesg.set_values(( + ('manufacturer', 284), + ('product', 1), + ('product_name', 'unit test') # test string + )) + write(mesg) + + rec_mesg = DataMessageCreator('record', local_mesg_num=1) + rec_mesg.set_values(( + ('timestamp', time_created), + ('altitude', 100), + ('distance', 0) + )) + write(rec_mesg) + + rec_mesg2 = DataMessageCreator('record', local_mesg_num=2) + rec_mesg2.set_values(( + ('altitude', 102), + ('distance', 2) + )) + rec_mesg2.set_header_timestamp(time_created + datetime.timedelta(seconds=2)) + write(rec_mesg2) + + rec_mesg2.set_values(( + ('altitude', 40000), # out of sint16 range + ('distance', 4) + )) + rec_mesg2.set_header_timestamp(time_created + datetime.timedelta(seconds=4)) + write(rec_mesg2) + messages[-1].get('altitude').value = None # to 
conform the assert + + mesg = DataMessageCreator('session') + mesg.set_values(( + ('start_time', time_created), + ('timestamp', time_created), + ('total_distance', 20.5), + ('total_ascent', 1234), + ('total_descent', 654), + ('total_elapsed_time', 3661.5), + ('avg_altitude', 821), + ('sport', 'cycling'), + ('event', 'session'), + ('event_type', 'start') + )) + write(mesg) + + fwrite.finish() + buff = file.getvalue() + + with FitFile(buff) as fread: + rmessages = fread.messages + + self._assert_messages(messages, rmessages) + + def test_basic_activity_read_write(self): + # note: 'Activity.fit' has some useless definition messages + with FitFile(testfile('Activity.fit')) as fread: + messages = fread.messages + + file = io.BytesIO() + with FitFileEncoder(file) as fwrite: + for m in messages: + # current encoder can do just basic fields + m.fields = [f for f in m.fields if f.field_def or FitFileEncoder._is_ts_field(f)] + # need to unset raw_value + for field_data in m.fields: + field_data.raw_value = None + fwrite.write(m) + fwrite.finish() + buff = file.getvalue() + + with FitFile(buff) as fread: + messages_buff = fread.messages + + self._assert_messages(messages, messages_buff) + + def _assert_messages(self, expected, actual): + self.assertEqual(len(expected), len(actual), msg='#messages') + for emsg, amsg in zip(expected, actual): + self.assertEqual(emsg.name, amsg.name) + self._assert_message_headers(emsg.header, amsg.header) + self.assertEqual(self._get_header_ts(emsg.fields), self._get_header_ts(amsg.fields), msg='message: {} header timestamp'.format(emsg.name)) + efields = self._filter_fields_for_test(emsg.fields) + afields = self._filter_fields_for_test(amsg.fields) + self.assertEqual(len(efields), len(afields), msg='message: {} #fields'.format(emsg.name)) + for efield, afield in zip(efields, afields): + self.assertEqual(efield.name, afield.name, msg='message: {} field names'.format(emsg.name)) + self.assertEqual(efield.value, afield.value, + msg='message: {}, 
field: {} values'.format(emsg.name, efield.name)) + + def _assert_message_headers(self, expected, actual): + self.assertEqual(expected.is_definition, actual.is_definition) + self.assertEqual(expected.is_developer_data, actual.is_developer_data) + self.assertEqual(expected.local_mesg_num, actual.local_mesg_num) + self.assertEqual(expected.time_offset, actual.time_offset) + + @staticmethod + def _filter_fields_for_test(fields): + """Take only base field for the test.""" + return [f for f in fields if f.field_def] + + @staticmethod + def _get_header_ts(fields): + """Get timestamp related to the compressed header.""" + field_data = next((f for f in fields if f.field_def is None and FitFileEncoder._is_ts_field(f)), None) + return field_data.value if field_data else None + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_processors.py b/tests/test_processors.py new file mode 100644 index 0000000..bd035cf --- /dev/null +++ b/tests/test_processors.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +import datetime +import sys + +from fitparse import FitFileDataProcessor +from fitparse.profile import FIELD_TYPE_TIMESTAMP +from fitparse.records import FieldData + +if sys.version_info >= (2, 7): + import unittest +else: + import unittest2 as unittest + + +class ProcessorsTestCase(unittest.TestCase): + + def test_fitfiledataprocessor(self): + raw_value = 3600 + 60 + 1 + fd = FieldData( + field_def=None, + field=FIELD_TYPE_TIMESTAMP, + parent_field=None, + value=raw_value, + raw_value=raw_value, + ) + pr = FitFileDataProcessor() + # local_date_time + pr.process_type_local_date_time(fd) + self.assertEqual(datetime.datetime(1989, 12, 31, 1, 1, 1), fd.value) + pr.unparse_type_local_date_time(fd) + self.assertEqual(raw_value, fd.raw_value) + # localtime_into_day + fd.value = raw_value + fd.raw_value = None + pr.process_type_localtime_into_day(fd) + self.assertEqual(datetime.time(1, 1, 1), fd.value) + pr.unparse_type_localtime_into_day(fd) + 
self.assertEqual(raw_value, fd.raw_value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_records.py b/tests/test_records.py index 5e3b823..60b8d55 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -2,6 +2,7 @@ import sys +from fitparse import records from fitparse.records import Crc if sys.version_info >= (2, 7): @@ -11,6 +12,13 @@ class RecordsTestCase(unittest.TestCase): + + def test_string_parse(self): + sb = b'Test string\0' + s = records.parse_string(sb) + self.assertEqual('Test string', s) + self.assertEqual(sb, records.unparse_string(s)) + def test_crc(self): crc = Crc() self.assertEqual(0, crc.value) diff --git a/tests/test_utils.py b/tests/test_utils.py index e0b6547..8d9ede2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,12 @@ #!/usr/bin/env python - +import datetime import io import os import sys import tempfile -from fitparse.utils import fileish_open +from fitparse import utils +from fitparse.utils import fileish_open, is_iterable if sys.version_info >= (2, 7): import unittest @@ -19,6 +20,22 @@ def testfile(filename): class UtilsTestCase(unittest.TestCase): + def test_fit_to_datetime(self): + sec = 3600 + 60 + 1 + dt = datetime.datetime(1989, 12, 31, 1, 1, 1) + self.assertEqual(dt, utils.fit_to_datetime(sec)) + self.assertEqual(sec, utils.fit_from_datetime(dt)) + + def test_fit_semicircles_to_deg(self): + sc = 495280430 + deg = 41.513926070183516 + self.assertEqual(deg, utils.fit_semicircles_to_deg(sc)) + self.assertEqual(sc, utils.fit_deg_to_semicircles(deg)) + # test rounding errors + for i in range(100): + sc += 1 + self.assertEqual(sc, utils.fit_deg_to_semicircles(utils.fit_semicircles_to_deg(sc))) + def test_fileish_open_read(self): """Test the constructor does the right thing when given different types (specifically, test files with 8 characters, followed by an uppercase.FIT @@ -61,6 +78,16 @@ def test_fopen(fileish): except OSError: pass + + def test_is_iterable(self): + 
self.assertFalse(is_iterable(None)) + self.assertFalse(is_iterable(1)) + self.assertFalse(is_iterable('1')) + self.assertFalse(is_iterable(b'1')) + + self.assertTrue(is_iterable((1, 2))) + self.assertTrue(is_iterable([1, 2])) + self.assertTrue(is_iterable(range(2))) + if __name__ == '__main__': unittest.main()