Source code for cantools.database.can.database

import logging
from collections import OrderedDict
from typing import (
    Any,
    TextIO,
)

from ...typechecking import DecodeResultType, EncodeInputType, StringPathLike
from ..errors import DecodeError
from ..utils import (
    SORT_SIGNALS_DEFAULT,
    sort_signals_by_start_bit,
    type_sort_attributes,
    type_sort_choices,
    type_sort_signals,
)
from .bus import Bus
from .formats import arxml, dbc, kcd, sym
from .formats.arxml import AutosarDatabaseSpecifics
from .formats.dbc import DbcSpecifics
from .internal_database import InternalDatabase
from .message import Message
from .node import Node

LOGGER = logging.getLogger(__name__)


[docs] class Database: """This class contains all messages, signals and definitions of a CAN network. The factory functions :func:`load()<cantools.database.load()>`, :func:`load_file()<cantools.database.load_file()>` and :func:`load_string()<cantools.database.load_string()>` returns instances of this class. If `strict` is ``True`` an exception is raised if any signals are overlapping or if they don't fit in their message. By default signals are sorted by their start bit when their Message object is created. If you don't want them to be sorted pass `sort_signals = None`. If you want the signals to be sorted in another way pass something like `sort_signals = lambda signals: list(sorted(signals, key=lambda sig: sig.name))` """ def __init__(self, messages: list[Message] | None = None, nodes: list[Node] | None = None, buses: list[Bus] | None = None, version: str | None = None, dbc_specifics: DbcSpecifics | None = None, autosar_specifics: AutosarDatabaseSpecifics | None = None, frame_id_mask: int | None = None, strict: bool = True, sort_signals: type_sort_signals = sort_signals_by_start_bit, ) -> None: self._messages = messages or [] self._nodes = nodes or [] self._buses = buses or [] self._name_to_message: dict[str, Message] = {} self._frame_id_to_message: dict[int, Message] = {} self._version = version self._dbc = dbc_specifics self._autosar = autosar_specifics if frame_id_mask is None: frame_id_mask = 0xffffffff self._frame_id_mask = frame_id_mask self._strict = strict self._sort_signals = sort_signals self.refresh() @property def messages(self) -> list[Message]: """A list of messages in the database. Use :meth:`.get_message_by_frame_id()` or :meth:`.get_message_by_name()` to find a message by its frame id or name. """ return self._messages @property def nodes(self) -> list[Node]: """A list of nodes in the database. """ return self._nodes @property def buses(self) -> list[Bus]: """A list of CAN buses in the database. """ return self._buses @property def version(self) -> str | None: """The database version, or ``None`` if unavailable. """ return self._version @version.setter def version(self, value: str | None) -> None: self._version = value @property def dbc(self) -> DbcSpecifics | None: """An object containing dbc specific properties like e.g. attributes. """ return self._dbc @dbc.setter def dbc(self, value: DbcSpecifics | None) -> None: self._dbc = value @property def autosar(self) -> AutosarDatabaseSpecifics | None: """An object containing AUTOSAR specific properties like e.g. attributes. """ return self._autosar @autosar.setter def autosar(self, value: AutosarDatabaseSpecifics | None) -> None: self._autosar = value
[docs] def is_similar(self, other: "Database", *, tolerance: float = 1e-12, include_format_specifics: bool = True) -> bool: """Compare two database objects inexactly This means that small discrepanceies stemming from e.g. rounding errors are ignored. """ return not self._differences(other, tolerance=tolerance, include_format_specifics=include_format_specifics)
def _differences(self, other: "Database", *, tolerance: float = 1e-12, include_format_specifics: bool = True) -> list[dict[str, str]]: """Return a list of differences between this Database and `other`. Each difference is a dict with keys: ``path``, ``self``, ``other``, ``reason``. """ diffs: list[dict[str, str]] = [] seen: set[tuple[int, int]] = set() def path_to_str(path_parts: list[str]) -> str: if not path_parts: return '<root>' return ''.join(path_parts) def add_diff(path_parts: list[str], a_val: Any, b_val: Any, reason: str) -> None: diffs.append({ 'path': path_to_str(path_parts), 'self': str(a_val), 'other': str(b_val), 'reason': reason, }) def compare(a: Any, b: Any, path_parts: list[str]) -> None: pair_id = (id(a), id(b)) if pair_id in seen: return seen.add(pair_id) if type(a) is not type(b): add_diff(path_parts, type(a), type(b), 'type-mismatch') return if callable(a) or a is None: # nothing to compare return if isinstance(a, (int, str, set, bool)): if a != b: add_diff(path_parts, a, b, 'value-mismatch') return if isinstance(a, float): if abs(a) > 1: if abs(1.0 - b / a) > tolerance: add_diff(path_parts, a, b, 'float-rel-diff') elif abs(b - a) > tolerance: add_diff(path_parts, a, b, 'float-abs-diff') return if isinstance(a, (list, tuple)): if len(a) != len(b): add_diff(path_parts, len(a), len(b), 'length-mismatch') min_len = min(len(a), len(b)) for i in range(min_len): compare(a[i], b[i], [*path_parts, f'[{i}]']) return if isinstance(a, (dict, OrderedDict)): keys_a = set(a.keys()) keys_b = set(b.keys()) for k in sorted(keys_a - keys_b, key=str): add_diff([*path_parts, f'[{k!r}]'], a[k], None, 'key-only-in-a') for k in sorted(keys_b - keys_a, key=str): add_diff([*path_parts, f'[{k!r}]'], None, b[k], 'key-only-in-b') for k in sorted(keys_a & keys_b, key=str): compare(a[k], b[k], [*path_parts, f'[{k!r}]']) return # get attributes a_names = dir(a) b_names = dir(b) if not include_format_specifics: for x in ('dbc', 'autosar'): if x in a_names: a_names.remove(x) if x in b_names: b_names.remove(x) if a_names != b_names: add_diff(path_parts, sorted(a_names), sorted(b_names), 'attrib-names-mismatch') return for name in a_names: if name == 'messages' and hasattr(a, '_frame_id_to_message'): # compare messages independent of order compare(a._frame_id_to_message, b._frame_id_to_message, [*path_parts, "._frame_id_to_message"]) continue if name == 'signals' and hasattr(a, '_signal_dict'): # compare messages independent of order compare(a._signal_dict, b._signal_dict, [*path_parts, "._signal_dict"]) continue if name.startswith('_'): # skip private attributes continue a_attr = getattr(a, name) b_attr = getattr(b, name) compare(a_attr, b_attr, [*path_parts, f'.{name}']) # compare root compare(self, other, []) return diffs
[docs] def add_arxml(self, fp: TextIO) -> None: """Read and parse ARXML data from given file-like object and add the parsed data to the database. """ self.add_arxml_string(fp.read())
[docs] def add_arxml_file(self, filename: StringPathLike, encoding: str = 'utf-8') -> None: """Open, read and parse ARXML data from given file and add the parsed data to the database. `encoding` specifies the file encoding. """ with open(filename, encoding=encoding, errors='replace') as fin: self.add_arxml(fin)
[docs] def add_arxml_string(self, string: str) -> None: """Parse given ARXML data string and add the parsed data to the database. """ database = arxml.load_string(string, self._strict, sort_signals=self._sort_signals) self._messages += database.messages self._nodes = database.nodes self._buses = database.buses self._version = database.version self._dbc = database.dbc self._autosar = database.autosar self.refresh()
[docs] def add_dbc(self, fp: TextIO) -> None: """Read and parse DBC data from given file-like object and add the parsed data to the database. >>> db = cantools.database.Database() >>> with open ('foo.dbc', 'r') as fin: ... db.add_dbc(fin) """ self.add_dbc_string(fp.read())
[docs] def add_dbc_file(self, filename: StringPathLike, encoding: str = 'cp1252') -> None: """Open, read and parse DBC data from given file and add the parsed data to the database. `encoding` specifies the file encoding. >>> db = cantools.database.Database() >>> db.add_dbc_file('foo.dbc') """ with open(filename, encoding=encoding, errors='replace') as fin: self.add_dbc(fin)
[docs] def add_dbc_string(self, string: str) -> None: """Parse given DBC data string and add the parsed data to the database. >>> db = cantools.database.Database() >>> with open ('foo.dbc', 'r') as fin: ... db.add_dbc_string(fin.read()) """ database = dbc.load_string(string, self._strict, sort_signals=self._sort_signals) self._messages += database.messages self._nodes = database.nodes self._buses = database.buses self._version = database.version self._dbc = database.dbc self.refresh()
[docs] def add_kcd(self, fp: TextIO) -> None: """Read and parse KCD data from given file-like object and add the parsed data to the database. """ self.add_kcd_string(fp.read())
[docs] def add_kcd_file(self, filename: StringPathLike, encoding: str = 'utf-8') -> None: """Open, read and parse KCD data from given file and add the parsed data to the database. `encoding` specifies the file encoding. """ with open(filename, encoding=encoding, errors='replace') as fin: self.add_kcd(fin)
[docs] def add_kcd_string(self, string: str) -> None: """Parse given KCD data string and add the parsed data to the database. """ database = kcd.load_string(string, self._strict, sort_signals=self._sort_signals) self._messages += database.messages self._nodes = database.nodes self._buses = database.buses self._version = database.version self._dbc = database.dbc self.refresh()
[docs] def add_sym(self, fp: TextIO) -> None: """Read and parse SYM data from given file-like object and add the parsed data to the database. """ self.add_sym_string(fp.read())
[docs] def add_sym_file(self, filename: StringPathLike, encoding: str = 'utf-8') -> None: """Open, read and parse SYM data from given file and add the parsed data to the database. `encoding` specifies the file encoding. """ with open(filename, encoding=encoding, errors='replace') as fin: self.add_sym(fin)
[docs] def add_sym_string(self, string: str) -> None: """Parse given SYM data string and add the parsed data to the database. """ database = sym.load_string(string, self._strict, sort_signals=self._sort_signals) self._messages += database.messages self._nodes = database.nodes self._buses = database.buses self._version = database.version self._dbc = database.dbc self.refresh()
def _add_message(self, message: Message) -> None: """Add given message to the database. """ if message.name in self._name_to_message: LOGGER.warning("Overwriting message '%s' with '%s' in the " "name to message dictionary.", self._name_to_message[message.name].name, message.name) masked_frame_id = (message.frame_id & self._frame_id_mask) if message.is_extended_frame: masked_frame_id |= 0x80000000 if masked_frame_id in self._frame_id_to_message: LOGGER.warning( "Overwriting message '%s' with '%s' in the frame id to message " "dictionary because they have identical masked frame ids 0x%x.", self._frame_id_to_message[masked_frame_id].name, message.name, masked_frame_id) self._name_to_message[message.name] = message self._frame_id_to_message[masked_frame_id] = message
[docs] def as_dbc_string(self, *, sort_signals:type_sort_signals=SORT_SIGNALS_DEFAULT, sort_attribute_signals:type_sort_signals=SORT_SIGNALS_DEFAULT, sort_attributes:type_sort_attributes=None, sort_choices:type_sort_choices=None, shorten_long_names:bool=True) -> str: """Return the database as a string formatted as a DBC file. sort_signals defines how to sort signals in message definitions sort_attribute_signals defines how to sort signals in metadata - comments, value table definitions and attributes """ if not self._sort_signals and sort_signals == SORT_SIGNALS_DEFAULT: sort_signals = None return dbc.dump_string(InternalDatabase(self._messages, self._nodes, self._buses, self._version, self._dbc), sort_signals=sort_signals, sort_attribute_signals=sort_attribute_signals, sort_attributes=sort_attributes, sort_choices=sort_choices, shorten_long_names=shorten_long_names)
[docs] def as_kcd_string(self, *, sort_signals:type_sort_signals=SORT_SIGNALS_DEFAULT) -> str: """Return the database as a string formatted as a KCD file. """ if not self._sort_signals and sort_signals == SORT_SIGNALS_DEFAULT: sort_signals = None return kcd.dump_string(InternalDatabase(self._messages, self._nodes, self._buses, self._version, self._dbc), sort_signals=sort_signals)
[docs] def as_sym_string(self, *, sort_signals:type_sort_signals=SORT_SIGNALS_DEFAULT) -> str: """Return the database as a string formatted as a SYM file. """ if not self._sort_signals and sort_signals == SORT_SIGNALS_DEFAULT: sort_signals = None return sym.dump_string(InternalDatabase(self._messages, self._nodes, self._buses, self._version, self._dbc), sort_signals=sort_signals)
[docs] def get_message_by_name(self, name: str) -> Message: """Find the message object for given name `name`. """ return self._name_to_message[name]
[docs] def get_message_by_frame_id(self, frame_id: int, force_extended_id: bool = False) -> Message: """Find the message object for given frame id `frame_id`. """ if force_extended_id or frame_id > 0x7FF: frame_id |= 0x80000000 return self._frame_id_to_message[frame_id & (0x80000000 | self._frame_id_mask)]
[docs] def get_node_by_name(self, name: str) -> Node: """Find the node object for given name `name`. """ for node in self._nodes: if node.name == name: return node raise KeyError(name)
[docs] def get_bus_by_name(self, name: str) -> Bus: """Find the bus object for given name `name`. """ for bus in self._buses: if bus.name == name: return bus raise KeyError(name)
[docs] def encode_message(self, frame_id_or_name: int | str, data: EncodeInputType, scaling: bool = True, padding: bool = False, strict: bool = True, force_extended_id: bool = False, ) -> bytes: """Encode given signal data `data` as a message of given frame id or name `frame_id_or_name`. For regular Messages, `data` is a dictionary of signal name-value entries, for container messages it is a list of (ContainedMessageOrMessageName, ContainedMessageSignals) tuples. If `scaling` is ``False`` no scaling of signals is performed. If `padding` is ``True`` unused bits are encoded as 1. If `strict` is ``True`` all signal values must be within their allowed ranges, or an exception is raised. >>> db.encode_message(158, {'Bar': 1, 'Fum': 5.0}) b'\\x01\\x45\\x23\\x00\\x11' >>> db.encode_message('Foo', {'Bar': 1, 'Fum': 5.0}) b'\\x01\\x45\\x23\\x00\\x11' """ if isinstance(frame_id_or_name, int): if force_extended_id or frame_id_or_name > 0x7FF: frame_id_or_name |= 0x80000000 message = self._frame_id_to_message[frame_id_or_name] elif isinstance(frame_id_or_name, str): message = self._name_to_message[frame_id_or_name] else: raise ValueError(f"Invalid frame_id_or_name '{frame_id_or_name}'") return message.encode(data, scaling, padding, strict)
[docs] def decode_message(self, frame_id_or_name: int | str, data: bytes, decode_choices: bool = True, scaling: bool = True, decode_containers: bool = False, allow_truncated: bool = False, force_extended_id: bool = False, ) \ -> DecodeResultType: """Decode given signal data `data` as a message of given frame id or name `frame_id_or_name`. Returns a dictionary of signal name-value entries. If `decode_choices` is ``False`` scaled values are not converted to choice strings (if available). If `scaling` is ``False`` no scaling of signals is performed. >>> db.decode_message(158, b'\\x01\\x45\\x23\\x00\\x11') {'Bar': 1, 'Fum': 5.0} >>> db.decode_message('Foo', b'\\x01\\x45\\x23\\x00\\x11') {'Bar': 1, 'Fum': 5.0} If `decode_containers` is ``True``, container frames are decoded. The reason why this needs to be explicitly enabled is that decoding container frames returns a list of ``(Message, SignalsDict)`` tuples which will cause code that does not expect this to misbehave. Trying to decode a container message with `decode_containers` set to ``False`` will raise a `DecodeError`. """ if isinstance(frame_id_or_name, int): if force_extended_id or frame_id_or_name > 0x7FF: frame_id_or_name |= 0x80000000 message = self._frame_id_to_message[frame_id_or_name] elif isinstance(frame_id_or_name, str): message = self._name_to_message[frame_id_or_name] else: raise ValueError(f"Invalid frame_id_or_name '{frame_id_or_name}'") if message.is_container: if decode_containers: return message.decode(data, decode_choices, scaling, decode_containers=True, allow_truncated=allow_truncated) else: raise DecodeError(f'Message "{message.name}" is a container ' f'message, but decoding such messages has ' f'not been enabled!') return message.decode(data, decode_choices, scaling, allow_truncated=allow_truncated)
[docs] def refresh(self) -> None: """Refresh the internal database state. This method must be called after modifying any message in the database to refresh the internal lookup tables used when encoding and decoding messages. """ self._name_to_message = {} self._frame_id_to_message = {} for message in self._messages: message.refresh(self._strict) self._add_message(message)
def __repr__(self) -> str: lines = [f"version('{self._version}')", ''] if self._nodes: for node in self._nodes: lines.append(repr(node)) lines.append('') for message in self._messages: lines.append(repr(message)) for signal in message.signals: lines.append(' ' + repr(signal)) lines.append('') return '\n'.join(lines)