transport.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678
  1. #encoding=utf-8
  2. import time
  3. import math
  4. import ctypes
  5. import struct
  6. import logging
  7. import binascii
  8. import functools
  9. import collections
  10. import uavcan
  11. import uavcan.dsdl as dsdl
  12. import uavcan.dsdl.common as common
  def bits_from_bytes(s):
      """Return the concatenated 8-digit binary form of every byte in *s*.

      Each element is written most significant bit first, so a bytearray of
      length n yields a string of exactly 8 * n "0"/"1" characters.
      """
      return "".join(["{0:08b}".format(octet) for octet in s])
  def bytes_from_bits(s):
      """Pack a string of "0"/"1" characters into a bytearray.

      The string is consumed eight characters at a time, most significant bit
      first. A shorter final group is read as a plain binary number (it is not
      left-aligned), so pad to a multiple of eight first if that matters.
      """
      packed = bytearray()
      for start in range(0, len(s), 8):
          packed.append(int(s[start:start + 8], 2))
      return packed
  def be_from_le_bits(s, bitlen):
      """Take the first *bitlen* bits of *s* and reverse their byte order.

      *s* is a "0"/"1" string. Its leading *bitlen* characters are cut into
      8-character groups from the left (the last group may be shorter), and
      the groups are emitted in reverse order. ``le_from_be_bits`` performs
      the inverse transformation.

      :raises ValueError: if *s* holds fewer than *bitlen* bits.
      """
      available = len(s)
      if available < bitlen:
          raise ValueError("Not enough bits; need {0} but got {1}".format(
              bitlen, available))
      window = s[0:bitlen] if available > bitlen else s
      groups = [window[start:start + 8]
                for start in range(0, len(window), 8)]
      groups.reverse()
      return "".join(groups)
  def le_from_be_bits(s, bitlen):
      """Take the last *bitlen* bits of *s* and reverse their byte order.

      The trailing *bitlen* characters of the "0"/"1" string are cut into
      8-character groups working from the right (the leftmost group may be
      shorter) and emitted right-to-left. This is the inverse of
      ``be_from_le_bits``.

      :raises ValueError: if *s* holds fewer than *bitlen* bits.
      """
      if len(s) < bitlen:
          raise ValueError("Not enough bits; need {0} but got {1}".format(
              bitlen, len(s)))
      if len(s) > bitlen:
          s = s[len(s) - bitlen:]
      pieces = []
      end = len(s)
      while end > 0:
          pieces.append(s[max(0, end - 8):end])
          end -= 8
      return "".join(pieces)
  def format_bits(s):
      """Return *s* with its characters grouped in eights, separated by spaces.

      The last group may be shorter; no leading or trailing space is added.
      """
      groups = []
      for start in range(0, len(s), 8):
          groups.append(s[start:start + 8])
      return " ".join(groups)
  def union_tag_len(x):
      """Return the bit width of the tag that selects one member of union *x*.

      *x* is any sized collection of the union's members. The tag has to
      encode the indexes 0 .. len(x) - 1, so its width is the bit length of
      the largest index, with a minimum of one bit.

      Exact integer ``bit_length`` is used rather than ``ceil(log(n, 2))``:
      the floating-point logarithm is not guaranteed to be exact and can
      round a power of two up to the next width.

      :raises ValueError: if *x* is empty.
      """
      count = len(x)
      if count < 1:
          raise ValueError("A union must have at least one member")
      return max(1, (count - 1).bit_length())
  # http://davidejones.com/blog/1413-python-precision-floating-point/
  def f16_from_f32(float32):
      """Convert a Python float to the bit pattern of an IEEE 754 half float.

      The value is first narrowed to single precision, then its sign,
      exponent and mantissa are repacked into 16 bits:

      - mantissa bits that do not fit are truncated (no rounding);
      - magnitudes too large for a half, including doubles beyond the
        single-precision range, become signed infinity;
      - magnitudes below the smallest normal half become signed zero
        (subnormal halves are never produced);
      - NaN stays NaN.

      :param float32: the value to convert.
      :returns: the 16-bit encoding as a non-negative int.
      """
      F16_EXPONENT_BITS = 0x1F
      F16_EXPONENT_SHIFT = 10
      F16_EXPONENT_BIAS = 15
      F16_MANTISSA_SHIFT = (23 - F16_EXPONENT_SHIFT)
      F16_MAX_EXPONENT = (F16_EXPONENT_BITS << F16_EXPONENT_SHIFT)

      try:
          f32 = struct.unpack(">I", struct.pack(">f", float32))[0]
      except OverflowError:
          # Finite doubles outside the single-precision range cannot be packed
          # with "f"; they are far outside the half range as well.
          return (0x8000 if float32 < 0 else 0) | F16_MAX_EXPONENT

      sign = (f32 >> 16) & 0x8000
      exponent = ((f32 >> 23) & 0xff) - 127
      mantissa = f32 & 0x007fffff

      if exponent == 128:
          # Infinity or NaN. Keep the most significant mantissa bits so that a
          # NaN stays a NaN; if none of them are set, force one on.
          f16 = sign | F16_MAX_EXPONENT
          if mantissa:
              f16 |= (mantissa >> F16_MANTISSA_SHIFT) or 1
      elif exponent > 15:
          # Too large for a half: saturate to infinity.
          f16 = sign | F16_MAX_EXPONENT
      elif exponent > -15:
          # Normal half: rebias the exponent and drop the low mantissa bits.
          f16 = (sign |
                 (exponent + F16_EXPONENT_BIAS) << F16_EXPONENT_SHIFT |
                 mantissa >> F16_MANTISSA_SHIFT)
      else:
          # Too small for a normal half: flush to signed zero.
          f16 = sign
      return f16
  # http://davidejones.com/blog/1413-python-precision-floating-point/
  def f32_from_f16(float16):
      """Decode a 16-bit IEEE 754 half-precision bit pattern into a float.

      Zeros (including -0.0), subnormals, normal numbers, infinities and NaN
      are all handled. Every half value is exactly representable in single
      and double precision, so the result is exact.

      :param float16: the half-float bit pattern; only the low 16 bits are
          used.
      :returns: the decoded value as a Python float.
      """
      negative = bool(float16 & 0x8000)
      exponent = (float16 >> 10) & 0x1F
      mantissa = float16 & 0x3FF
      if exponent == 0x1F:
          # All-ones exponent: infinity, or NaN when any mantissa bit is set.
          magnitude = float("nan") if mantissa else float("inf")
      elif exponent == 0:
          # Zero or subnormal: no implicit leading one, scale 2**-24.
          magnitude = math.ldexp(mantissa, -24)
      else:
          # Normal: implicit leading one, 10 fraction bits, exponent bias 15.
          magnitude = math.ldexp(mantissa | 0x400, exponent - 25)
      return -magnitude if negative else magnitude
  def cast(value, dtype):
      """Coerce *value* according to the cast mode of primitive type *dtype*.

      - Saturated mode clamps *value* to ``dtype.value_range``, where index 0
        is the lower bound and index 1 the upper bound.
      - Truncated mode on a float type replaces values beyond that range with
        signed infinity; NaN is passed through unchanged.
      - Truncated mode on any other type keeps only the low ``dtype.bitlen``
        bits of the integer (two's complement for negative values).

      :raises ValueError: if ``dtype.cast_mode`` is not one of the above.
      """
      primitive = dsdl.parser.PrimitiveType
      if dtype.cast_mode == primitive.CAST_MODE_SATURATED:
          if value > dtype.value_range[1]:
              value = dtype.value_range[1]
          elif value < dtype.value_range[0]:
              value = dtype.value_range[0]
          return value
      elif (dtype.cast_mode == primitive.CAST_MODE_TRUNCATED and
            dtype.kind == primitive.KIND_FLOAT):
          # NaN must be left alone rather than mapped to an infinity.
          if not math.isnan(value) and value > dtype.value_range[1]:
              value = float("+inf")
          elif not math.isnan(value) and value < dtype.value_range[0]:
              value = float("-inf")
          return value
      elif dtype.cast_mode == primitive.CAST_MODE_TRUNCATED:
          return value & ((1 << dtype.bitlen) - 1)
      else:
          raise ValueError("Invalid cast_mode: " + repr(dtype))
  class Void(object):
      """A DSDL void (padding) field occupying a fixed number of bits."""

      def __init__(self, bitlen):
          self.bitlen = bitlen

      def unpack(self, stream):
          """Skip this field's bits and return the rest of *stream*."""
          width = self.bitlen
          return stream[width:]

      def pack(self):
          """Emit this field as a run of "0" characters."""
          return self.bitlen * "0"
  class BaseValue(object):
      """Common bit-string (de)serialization for fixed-width DSDL values.

      ``type`` is the DSDL type object and ``_bits`` holds the value as a
      "0"/"1" string (None until set or unpacked). ``unpack`` reorders the
      leading bits of a stream with ``be_from_le_bits``; ``pack`` reverses
      that with ``le_from_be_bits``.
      """

      def __init__(self, uavcan_type, *args, **kwargs):
          self.type = uavcan_type
          self._bits = None

      def unpack(self, stream):
          """Consume ``type.bitlen`` bits of *stream*; return the remainder.

          When ``type.bitlen`` is falsy (e.g. 0 or None) nothing is consumed.
          """
          width = self.type.bitlen
          if not width:
              return stream
          self._bits = be_from_le_bits(stream, width)
          return stream[width:]

      def pack(self):
          """Return this value's ``type.bitlen`` bits ("0"s when unset)."""
          if not self._bits:
              return "0" * self.type.bitlen
          return le_from_be_bits(self._bits, self.type.bitlen)
  class PrimitiveValue(BaseValue):
      """A DSDL boolean, integer or float field.

      The value is held in ``_bits`` as a "0"/"1" string, most significant
      bit first, ``type.bitlen`` characters wide. The ``value`` property
      converts between that string and Python numbers.
      """

      def __repr__(self):
          return repr(self.value)

      def _store(self, int_value):
          # Zero-padded binary digits of a non-negative int, bitlen wide.
          self._bits = format(int_value, "0" + str(self.type.bitlen) + "b")

      @property
      def value(self):
          """The decoded value, or None if no bits have been set or unpacked.

          Booleans are returned as 0/1 ints, signed integers are read as two's
          complement, and 16- and 32-bit floats are decoded from IEEE 754 bits.

          :raises ValueError: for float types other than 16 or 32 bits.
          """
          if not self._bits:
              return None
          int_value = int(self._bits, 2)
          kind = self.type.kind
          if kind == dsdl.parser.PrimitiveType.KIND_BOOLEAN:
              return int_value
          elif kind == dsdl.parser.PrimitiveType.KIND_UNSIGNED_INT:
              return int_value
          elif kind == dsdl.parser.PrimitiveType.KIND_SIGNED_INT:
              if int_value >= (1 << (self.type.bitlen - 1)):
                  int_value = -((1 << self.type.bitlen) - int_value)
              return int_value
          elif kind == dsdl.parser.PrimitiveType.KIND_FLOAT:
              if self.type.bitlen == 16:
                  return f32_from_f16(int_value)
              elif self.type.bitlen == 32:
                  return struct.unpack("<f", struct.pack("<L", int_value))[0]
              else:
                  raise ValueError("Only 16- or 32-bit floats are supported")

      @value.setter
      def value(self, new_value):
          """Encode *new_value* into ``_bits`` after applying the cast mode.

          :raises ValueError: if *new_value* is None, or for float types other
              than 16 or 32 bits.
          """
          if new_value is None:
              raise ValueError("Can't serialize a None value")
          kind = self.type.kind
          if kind == dsdl.parser.PrimitiveType.KIND_BOOLEAN:
              self._bits = "1" if new_value else "0"
          elif kind == dsdl.parser.PrimitiveType.KIND_UNSIGNED_INT:
              self._store(cast(new_value, self.type))
          elif kind == dsdl.parser.PrimitiveType.KIND_SIGNED_INT:
              new_value = cast(new_value, self.type)
              # Store the two's complement so a negative number occupies
              # exactly bitlen bits; format() on a negative int would put a
              # "-" into the bit string and corrupt pack().
              self._store(new_value & ((1 << self.type.bitlen) - 1))
          elif kind == dsdl.parser.PrimitiveType.KIND_FLOAT:
              new_value = cast(new_value, self.type)
              if self.type.bitlen == 16:
                  int_value = f16_from_f32(new_value)
              elif self.type.bitlen == 32:
                  int_value = \
                      struct.unpack("<L", struct.pack("<f", new_value))[0]
              else:
                  raise ValueError("Only 16- or 32-bit floats are supported")
              self._store(int_value)
  class ArrayValue(BaseValue, collections.MutableSequence):
      """A DSDL array of primitive, array or compound items.

      Static arrays always hold ``type.max_size`` items. Dynamic arrays hold up
      to ``type.max_size`` items and are serialized with a length prefix,
      unless tail array optimization (TAO) applies, in which case the items
      simply run to the end of the stream. Indexing a primitive array yields
      decoded values (0 for unset items); other arrays yield the nested value
      objects.
      """

      def __init__(self, uavcan_type, tao=False, *args, **kwargs):
          super(ArrayValue, self).__init__(uavcan_type, *args, **kwargs)
          value_type = self.type.value_type
          value_bitlen = getattr(value_type, "bitlen", None)
          # TAO is only used for items at least one byte wide; an item type
          # without a bitlen never uses it (None is not compared with an int).
          self._tao = tao if (value_bitlen is not None and
                              value_bitlen >= 8) else False
          if isinstance(value_type, dsdl.parser.PrimitiveType):
              self.__item_ctor = functools.partial(PrimitiveValue, value_type)
          elif isinstance(value_type, dsdl.parser.ArrayType):
              self.__item_ctor = functools.partial(ArrayValue, value_type)
          elif isinstance(value_type, dsdl.parser.CompoundType):
              self.__item_ctor = functools.partial(CompoundValue, value_type)
          if self.type.mode == dsdl.parser.ArrayType.MODE_STATIC:
              self.__items = [self.__item_ctor()
                              for _ in range(self.type.max_size)]
          else:
              self.__items = []

      def __repr__(self):
          return "ArrayValue(type={0!r}, tao={1!r}, items={2!r})".format(
              self.type, self._tao, self.__items)

      def __str__(self):
          return self.__repr__()

      def __getitem__(self, idx):
          item = self.__items[idx]
          if isinstance(item, PrimitiveValue):
              # Unset primitive items read as 0.
              return item.value if item._bits else 0
          return item

      def __setitem__(self, idx, value):
          if idx >= self.type.max_size:
              raise IndexError(("Index {0} too large (max size " +
                                "{1})").format(idx, self.type.max_size))
          if isinstance(self.type.value_type, dsdl.parser.PrimitiveType):
              self.__items[idx].value = value
          else:
              self.__items[idx] = value

      def __delitem__(self, idx):
          del self.__items[idx]

      def __len__(self):
          return len(self.__items)

      def new_item(self):
          """Return a fresh, unset item of this array's item type."""
          return self.__item_ctor()

      def insert(self, idx, value):
          """Insert *value* before *idx*; primitive values are wrapped first.

          :raises IndexError: if *idx* >= max_size or the array is full.
          """
          if idx >= self.type.max_size:
              raise IndexError(("Index {0} too large (max size " +
                                "{1})").format(idx, self.type.max_size))
          elif len(self) == self.type.max_size:
              raise IndexError(("Array already full (max size "
                                "{0})").format(self.type.max_size))
          if isinstance(self.type.value_type, dsdl.parser.PrimitiveType):
              new_item = self.__item_ctor()
              new_item.value = value
              self.__items.insert(idx, new_item)
          else:
              self.__items.insert(idx, value)

      def _count_width(self):
          """Bit width of a dynamic array's length prefix.

          The prefix has to represent every count from 0 to max_size
          inclusive, which takes ``max_size.bit_length()`` bits (at least 1).
          ceil(log2(max_size)) is one bit short whenever max_size is a power
          of two: a full 8-item array needs 4 bits, not 3.
          """
          return max(1, self.type.max_size.bit_length())

      def unpack(self, stream):
          """Decode this array from the front of bit string *stream*.

          :returns: the unconsumed remainder of *stream*.
          :raises ValueError: if a dynamic array's length prefix exceeds
              ``type.max_size``.
          """
          if self.type.mode == dsdl.parser.ArrayType.MODE_STATIC:
              for i in range(self.type.max_size):
                  stream = self.__items[i].unpack(stream)
          elif self._tao:
              # No length prefix: decode items until fewer than 8 bits remain,
              # then treat the leftover bits as consumed.
              del self[:]
              while len(stream) >= 8:
                  new_item = self.__item_ctor()
                  stream = new_item.unpack(stream)
                  self.__items.append(new_item)
              stream = ""
          else:
              del self[:]
              count_width = self._count_width()
              count = int(stream[0:count_width], 2)
              if count > self.type.max_size:
                  raise ValueError(("Array length {0} exceeds max size "
                                    "{1}").format(count, self.type.max_size))
              stream = stream[count_width:]
              for _ in range(count):
                  new_item = self.__item_ctor()
                  stream = new_item.unpack(stream)
                  self.__items.append(new_item)
          return stream

      def pack(self):
          """Encode this array as a bit string.

          Static arrays are padded with default items up to max_size; TAO
          arrays have no length prefix; other dynamic arrays are prefixed with
          their item count in ``_count_width()`` bits.
          """
          if self.type.mode == dsdl.parser.ArrayType.MODE_STATIC:
              items = "".join(i.pack() for i in self.__items)
              missing = self.type.max_size - len(self)
              if missing > 0:
                  items += self.__item_ctor().pack() * missing
              return items
          elif self._tao:
              return "".join(i.pack() for i in self.__items)
          else:
              count = format(len(self),
                             "0{0:1d}b".format(self._count_width()))
              return count + "".join(i.pack() for i in self.__items)

      def from_bytes(self, value):
          """Replace the contents with the bytes of *value*, one item each."""
          del self[:]
          for byte in bytearray(value):
              self.append(byte)

      def to_bytes(self):
          """Return the values of the set items as a byte string."""
          return bytes(bytearray(item.value for item in self.__items
                                 if item._bits))

      def encode(self, value):
          """Replace the contents with the UTF-8 encoding of *value*."""
          del self[:]
          value = bytearray(value, encoding="utf-8")
          for byte in value:
              self.append(byte)

      def decode(self, encoding="utf-8"):
          """Decode the values of the set items as text using *encoding*."""
          return bytearray(item.value for item in self.__items
                           if item._bits).decode(encoding)
  class CompoundValue(BaseValue):
      """A DSDL structure: a message, a service request/response, or a nested type.

      Fields are accessed as attributes. Primitive fields read and write their
      decoded value; other fields return the underlying value object and
      cannot be assigned. Constants are read-only attributes. For a union, the
      first member read or written (or the member selected by ``unpack``)
      becomes active, and other members then raise AttributeError.
      """

      def __init__(self, uavcan_type, mode=None, tao=False, *args, **kwargs):
          # Written into __dict__ directly because __getattr__/__setattr__
          # consult these two dicts on attribute access.
          self.__dict__["fields"] = collections.OrderedDict()
          self.__dict__["constants"] = {}
          super(CompoundValue, self).__init__(uavcan_type, *args, **kwargs)
          self.mode = mode
          self.data_type_id = self.type.default_dtid
          self.crc_base = ""

          if self.type.kind == dsdl.parser.CompoundType.KIND_SERVICE:
              if self.mode == "request":
                  source_fields = self.type.request_fields
                  source_constants = self.type.request_constants
                  is_union = self.type.request_union
              elif self.mode == "response":
                  source_fields = self.type.response_fields
                  source_constants = self.type.response_constants
                  is_union = self.type.response_union
              else:
                  raise ValueError("mode must be either 'request' or " +
                                   "'response' for service types")
          else:
              source_fields = self.type.fields
              source_constants = self.type.constants
              is_union = self.type.union
          self.is_union = is_union
          self.union_field = None

          for constant in source_constants:
              self.constants[constant.name] = constant.value
          for idx, field in enumerate(source_fields):
              # Only the last field may use tail array optimization.
              atao = field is source_fields[-1] and tao
              if isinstance(field.type, dsdl.parser.VoidType):
                  self.fields["_void_{0}".format(idx)] = Void(field.type.bitlen)
              elif isinstance(field.type, dsdl.parser.PrimitiveType):
                  self.fields[field.name] = PrimitiveValue(field.type)
              elif isinstance(field.type, dsdl.parser.ArrayType):
                  self.fields[field.name] = ArrayValue(field.type, tao=atao)
              elif isinstance(field.type, dsdl.parser.CompoundType):
                  self.fields[field.name] = CompoundValue(field.type, tao=atao)

      def __repr__(self):
          if self.is_union:
              field = self.union_field or list(self.fields)[0]
              fields = "{0}={1!r}".format(field, self.fields[field])
          else:
              fields = ", ".join("{0}={1!r}".format(f, v)
                                 for f, v in self.fields.items()
                                 if not f.startswith("_void_"))
          return "{0}({1})".format(self.type.full_name, fields)

      def __getattr__(self, attr):
          # Only reached when normal lookup fails. The bookkeeping dicts are
          # read from __dict__ so that an instance whose __init__ never ran
          # (e.g. one being rebuilt by copy or pickle) raises AttributeError
          # instead of re-entering __getattr__ for "constants" forever.
          constants = self.__dict__.get("constants", {})
          fields = self.__dict__.get("fields", {})
          if attr in constants:
              return constants[attr]
          if attr not in fields:
              raise AttributeError(attr)
          if self.is_union:
              if self.union_field and self.union_field != attr:
                  raise AttributeError(attr)
              self.union_field = attr
          field_value = fields[attr]
          if isinstance(field_value, PrimitiveValue):
              return field_value.value
          return field_value

      def __setattr__(self, attr, value):
          constants = self.__dict__.get("constants", {})
          fields = self.__dict__.get("fields", {})
          if attr in constants:
              raise AttributeError(attr + " is read-only")
          elif attr in fields:
              if self.is_union:
                  if self.union_field and self.union_field != attr:
                      raise AttributeError(attr)
                  self.union_field = attr
              if isinstance(fields[attr].type, dsdl.parser.PrimitiveType):
                  fields[attr].value = value
              else:
                  raise AttributeError(attr + " cannot be set directly")
          else:
              super(CompoundValue, self).__setattr__(attr, value)

      def unpack(self, stream):
          """Decode this structure from *stream*; return the remaining bits.

          A union reads its tag first and decodes only the selected member.
          """
          if self.is_union:
              tag_len = union_tag_len(self.fields)
              names = list(self.fields)
              self.union_field = names[int(stream[0:tag_len], 2)]
              stream = self.fields[self.union_field].unpack(stream[tag_len:])
          else:
              for field in self.fields.values():
                  stream = field.unpack(stream)
          return stream

      def pack(self):
          """Encode this structure as a bit string.

          A union emits its tag followed by the active member (the first
          member if none has been selected).
          """
          if self.is_union:
              names = list(self.fields)
              field = self.union_field or names[0]
              tag = names.index(field)
              return format(tag, "0" + str(union_tag_len(self.fields)) + "b") +\
                  self.fields[field].pack()
          else:
              return "".join(field.pack() for field in self.fields.values())
  class Frame(object):
      """One CAN frame: its message (CAN) ID and its data bytes.

      The last data byte is the tail byte: bit 7 start-of-transfer, bit 6
      end-of-transfer, bit 5 toggle, bits 0-4 transfer ID. The accessors
      return a neutral value when the frame carries no data.
      """

      def __init__(self, message_id, bytes):
          self.message_id = message_id
          self.bytes = bytearray(bytes)

      @property
      def transfer_key(self):
          """(message_id, transfer ID from the tail byte or None if empty)."""
          # A transfer is identified by the CAN ID together with the 5-bit
          # transfer ID in the tail byte.
          tid = None
          if self.bytes:
              tid = self.bytes[-1] & 0x1F
          return (self.message_id, tid)

      @property
      def toggle(self):
          """Toggle bit of the tail byte as a bool (int 0 for an empty frame)."""
          if not self.bytes:
              return 0
          return bool(self.bytes[-1] & 0x20)

      @property
      def end_of_transfer(self):
          """True if the tail byte has the end-of-transfer bit (0x40) set."""
          if not self.bytes:
              return False
          return bool(self.bytes[-1] & 0x40)

      @property
      def start_of_transfer(self):
          """True if the tail byte has the start-of-transfer bit (0x80) set."""
          if not self.bytes:
              return False
          return bool(self.bytes[-1] & 0x80)
  class Transfer(object):
      """A UAVCAN transfer: a serialized payload plus its CAN ID fields.

      Outgoing transfers are built from a payload value and split with
      ``to_frames``; incoming ones are rebuilt from frames with
      ``from_frames``.
      """

      def __init__(self, transfer_id=0, source_node_id=0, data_type_id=0,
                   dest_node_id=None, payload=0, transfer_priority=31,
                   request_not_response=False, service_not_message=False,
                   discriminator=None):
          self.transfer_priority = transfer_priority
          self.transfer_id = transfer_id
          self.source_node_id = source_node_id
          self.data_type_id = data_type_id
          self.dest_node_id = dest_node_id
          # Encoded instead of the full data type ID by anonymous message
          # frames (source_node_id == 0); see message_id.
          self.discriminator = discriminator
          self.request_not_response = request_not_response
          self.service_not_message = service_not_message
          if payload:
              payload_bits = payload.pack()
              # Zero-pad the bit string to a whole number of bytes.
              if len(payload_bits) & 7:
                  payload_bits += "0" * (8 - (len(payload_bits) & 7))
              self.payload = bytes_from_bits(payload_bits)
              # The payload's type determines the data type fields.
              self.data_type_id = payload.type.default_dtid
              self.data_type_signature = payload.type.get_data_type_signature()
              self.data_type_crc = payload.type.base_crc
          else:
              # Without a payload the caller's data_type_id is kept.
              self.payload = None
              self.data_type_signature = None
              self.data_type_crc = None
          self.is_complete = bool(self.payload)

      def __repr__(self):
          return ("Transfer(id={0}, source_node_id={1}, dest_node_id={2}, "
                  "transfer_priority={3}, payload={4!r})").format(
                      self.transfer_id, self.source_node_id, self.dest_node_id,
                      self.transfer_priority, self.payload)

      @property
      def message_id(self):
          """The CAN ID for this transfer, assembled from its fields.

          Bits 24-28 priority, bit 7 service flag, bits 0-6 source node ID.
          Services add data type ID (bits 16-23), request flag (bit 15) and
          destination node ID (bits 8-14). Anonymous messages (source node 0)
          add the 14-bit discriminator (bits 10-23) and the low two data type
          ID bits (bits 8-9). Other messages add the data type ID (bits 8-23).
          """
          id_ = (((self.transfer_priority & 0x1F) << 24) |
                 (int(self.service_not_message) << 7) |
                 (self.source_node_id or 0))
          if self.service_not_message:
              assert 0 <= self.data_type_id <= 0xFF
              assert 1 <= self.dest_node_id <= 0x7F
              id_ |= self.data_type_id << 16
              id_ |= int(self.request_not_response) << 15
              id_ |= self.dest_node_id << 8
          elif self.source_node_id == 0:
              assert self.dest_node_id is None
              assert self.discriminator is not None
              # Mask so an oversized discriminator cannot spill into the
              # priority bits.
              id_ |= (self.discriminator & 0x3FFF) << 10
              id_ |= (self.data_type_id & 0x3) << 8
          else:
              assert 0 <= self.data_type_id <= 0xFFFF
              id_ |= self.data_type_id << 8
          return id_

      @message_id.setter
      def message_id(self, value):
          self.transfer_priority = (value >> 24) & 0x1F
          self.service_not_message = bool(value & 0x80)
          self.source_node_id = value & 0x7F
          if self.service_not_message:
              self.data_type_id = (value >> 16) & 0xFF
              self.request_not_response = bool(value & 0x8000)
              self.dest_node_id = (value >> 8) & 0x7F
          elif self.source_node_id == 0:
              self.discriminator = (value >> 10) & 0x3FFF
              self.data_type_id = (value >> 8) & 0x3
          else:
              self.data_type_id = (value >> 8) & 0xFFFF

      def to_frames(self):
          """Split the payload into frames of up to 7 data bytes plus a tail.

          A payload longer than 7 bytes is prefixed with a CRC-16 of the
          payload, seeded with ``data_type_crc`` and written low byte first.
          """
          out_frames = []
          remaining_payload = self.payload
          if len(remaining_payload) > 7:
              crc = common.crc16_from_bytes(self.payload,
                                            initial=self.data_type_crc)
              remaining_payload = bytearray([crc & 0xFF, crc >> 8]) + \
                  remaining_payload
          message_id = self.message_id
          # Start with the toggle bit set so the first frame is emitted with it
          # cleared; it then alternates on every frame.
          tail = 0x20
          while True:
              tail = ((0x80 if not out_frames else 0) |
                      (0x40 if len(remaining_payload) <= 7 else 0) |
                      ((tail ^ 0x20) & 0x20) |
                      (self.transfer_id & 0x1F))
              out_frames.append(Frame(message_id=message_id,
                                      bytes=remaining_payload[0:7] +
                                      bytearray([tail])))
              remaining_payload = remaining_payload[7:]
              if not remaining_payload:
                  break
          return out_frames

      @staticmethod
      def _validate_tail_bytes(frames):
          """Check tail-byte flags across *frames* and return the transfer ID.

          Every frame must carry the first frame's transfer ID; only the first
          frame may set start-of-transfer and only the last end-of-transfer;
          the toggle bit must start clear and alternate.

          :raises ValueError: describing the first inconsistency found.
          """
          expected_toggle = 0
          expected_transfer_id = frames[0].bytes[-1] & 0x1F
          last_idx = len(frames) - 1
          for idx, f in enumerate(frames):
              tail = f.bytes[-1]
              if (tail & 0x1F) != expected_transfer_id:
                  raise ValueError(("Transfer ID {0} incorrect, expected " +
                                    "{1}").format(
                                        tail & 0x1F, expected_transfer_id))
              elif idx == 0 and not (tail & 0x80):
                  raise ValueError("Start of transmission not set on frame 0")
              elif idx > 0 and tail & 0x80:
                  raise ValueError(("Start of transmission set unexpectedly " +
                                    "on frame {0}").format(idx))
              elif idx == last_idx and not (tail & 0x40):
                  raise ValueError("End of transmission not set on last frame")
              elif idx < last_idx and (tail & 0x40):
                  raise ValueError(("End of transmission set unexpectedly " +
                                    "on frame {0}").format(idx))
              elif (tail & 0x20) != expected_toggle:
                  raise ValueError(("Toggle bit value {0} incorrect on frame " +
                                    "{1}").format(tail & 0x20, idx))
              expected_toggle ^= 0x20
          return expected_transfer_id

      def from_frames(self, frames):
          """Rebuild this transfer, including its payload, from *frames*.

          The tail bytes are validated and the CAN ID fields are taken from
          the first frame. The payload type is looked up in
          ``uavcan.DATATYPES``. For multi-frame transfers the leading CRC-16
          is checked before the payload is decoded.

          :raises ValueError: on an empty frame list, inconsistent tail bytes,
              an unknown data type, or a missing or mismatched CRC.
          """
          if not frames:
              raise ValueError("Cannot rebuild a transfer from zero frames")
          self.transfer_id = self._validate_tail_bytes(frames)
          self.message_id = frames[0].message_id
          # sum() refuses a bytearray start value, so concatenate explicitly.
          payload_bytes = bytearray()
          for f in frames:
              payload_bytes += f.bytes[0:-1]

          if self.service_not_message:
              kind = dsdl.parser.CompoundType.KIND_SERVICE
          else:
              kind = dsdl.parser.CompoundType.KIND_MESSAGE
          datatype = uavcan.DATATYPES.get((self.data_type_id, kind))
          if not datatype:
              raise ValueError("Unrecognised {0} type ID {1}".format(
                  "service" if self.service_not_message else "message",
                  self.data_type_id))

          if len(frames) > 1:
              if len(payload_bytes) < 2:
                  raise ValueError("Multi-frame transfer too short to hold "
                                   "its CRC")
              transfer_crc = payload_bytes[0] + (payload_bytes[1] << 8)
              payload_bytes = payload_bytes[2:]
              crc = common.crc16_from_bytes(payload_bytes,
                                            initial=datatype.base_crc)
              if crc != transfer_crc:
                  raise ValueError(("CRC mismatch: expected {0:x}, got {1:x} " +
                                    "for payload {2!r} (DTID {3:d})").format(
                                        crc, transfer_crc, payload_bytes,
                                        self.data_type_id))

          self.data_type_id = datatype.default_dtid
          self.data_type_signature = datatype.get_data_type_signature()
          self.data_type_crc = datatype.base_crc
          if self.service_not_message:
              self.payload = datatype(
                  mode="request" if self.request_not_response else "response")
          else:
              self.payload = datatype()
          self.payload.unpack(bits_from_bytes(payload_bytes))

      @property
      def key(self):
          return (self.message_id, self.transfer_id)

      def is_response_to(self, transfer):
          """True if this is a service response matching request *transfer*."""
          return bool(transfer.service_not_message and
                      self.service_not_message and
                      transfer.request_not_response and
                      not self.request_not_response and
                      transfer.dest_node_id == self.source_node_id and
                      transfer.source_node_id == self.dest_node_id and
                      transfer.data_type_id == self.data_type_id)
  class TransferManager(object):
      """Groups received frames into transfers keyed by ``frame.transfer_key``."""

      def __init__(self):
          # transfer key -> frames received so far, in arrival order
          self.active_transfers = collections.defaultdict(list)
          # transfer key -> time.time() when its latest frame arrived
          self.active_transfer_timestamps = {}

      def receive_frame(self, frame):
          """Add *frame* to its transfer and return the frames once complete.

          A frame that neither starts a transfer nor continues one in progress
          is dropped. A start-of-transfer frame always begins a new frame list,
          discarding any incomplete transfer left under the same key, since
          transfer IDs are reused.

          :returns: the list of frames of the completed transfer, or None.
          """
          key = frame.transfer_key
          if frame.start_of_transfer:
              self.active_transfers[key] = [frame]
          elif key in self.active_transfers:
              self.active_transfers[key].append(frame)
          else:
              return None
          self.active_transfer_timestamps[key] = time.time()
          if not frame.end_of_transfer:
              return None
          del self.active_transfer_timestamps[key]
          return self.active_transfers.pop(key)

      def remove_inactive_transfers(self, timeout=1.0):
          """Drop transfers whose latest frame is more than *timeout* s old."""
          now = time.time()
          # Iterate over a snapshot of the keys: entries are deleted below.
          for key in list(self.active_transfers):
              if now - self.active_transfer_timestamps[key] > timeout:
                  del self.active_transfers[key]
                  del self.active_transfer_timestamps[key]