-
Notifications
You must be signed in to change notification settings - Fork 0
/
communication.py
121 lines (92 loc) · 3.14 KB
/
communication.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from math import nan
import struct
from dataclasses import dataclass, field
from typing import List
from serial import Serial
from enum import IntEnum
from datetime import timedelta
class PacketType(IntEnum):
    """Numeric packet type ids carried in the first field of a packet header."""

    InvalidPacket = 0  # default type of a Packet whose header has not been read
    IterationData = 1  # payload decoded by ArduinoIterationPacket
    BenchmarkData = 2  # payload decoded by ArduinoBenchmarkPacket
@dataclass
class Packet:
    """Common header of every packet received over the serial link.

    On the wire a packet starts with two unsigned 16-bit integers: the packet
    type id and a length. Subclasses registered in ``packetMap`` decode the
    bytes that follow.
    """

    # Set from the header's raw type id (an int; not converted to PacketType).
    type: PacketType = PacketType.InvalidPacket
    # Byte count announced by the header.
    length: int = 0

    def unpackHeader(self, serial: Serial) -> None:
        """Read the 4-byte header from *serial* and store its type and length.

        Raises ``struct.error`` if the read returns fewer than 4 bytes.
        """
        # '<' pins little-endian order and standard sizes so decoding does not
        # depend on the host; on little-endian hosts this is the same layout
        # the previous native 'HH' produced. Assumes a little-endian sender —
        # confirm against the firmware.
        packet_type, length = struct.unpack('<HH', serial.read(2 * 2))
        self.type = packet_type
        self.length = length

    @staticmethod
    def readFromSerial(serial: Serial):
        """Read one packet from *serial*.

        Returns an instance of the class registered in ``packetMap`` for the
        header's type, fully decoded, or ``None`` if the type is not
        registered.
        """
        header = Packet()
        header.unpackHeader(serial)
        packet_class = packetMap.get(header.type)
        if packet_class is None:
            # Consume the announced payload so the next call starts at the
            # following header instead of inside this packet's body (assumes
            # unregistered types carry exactly ``length`` payload bytes).
            serial.read(header.length)
            return None
        packet = packet_class()
        packet.type = header.type
        packet.length = header.length
        # The header is already consumed; decode only the payload.
        packet.unpackFromSerial(serial, unpackHeader=False)
        return packet
@dataclass
class ArduinoIterationPacket(Packet):
    """Packet carrying one iteration's scalar state plus a list of nodes.

    Payload layout (little-endian): float32 ``tau``, int16 ``iteration``,
    uint16 ``nIterations``, uint16 ``nNodes``, uint16 ``nodeSize``, followed
    by ``nNodes`` records of ``nodeSize`` bytes each.
    """

    tau: float = 0
    iteration: int = 0
    nIterations: int = 0
    nNodes: int = 0
    nodeSize: int = 0  # size in bytes of one node record on the wire
    nodes: List['Node'] = field(default_factory=list)

    @dataclass
    class Node:
        """One (t, x) pair from the packet's node list."""

        t: float = 0
        x: float = 0

        def unpack(self, data: bytes):
            """Decode ``t`` and ``x`` as two float32 from the start of *data*.

            Bytes past the first 8 are ignored, so records wider than two
            floats (``nodeSize`` > 8) still decode. Fewer than 8 bytes raises
            ``struct.error``.
            """
            t, x = struct.unpack_from('<ff', data)
            self.t = t
            self.x = x

    def unpackFromSerial(self, serial: Serial, unpackHeader=True) -> None:
        """Read this packet's payload (and optionally its header) from *serial*.

        ``nodes`` is emptied first, then refilled with one ``Node`` per
        record. Pass ``unpackHeader=False`` when the header has already been
        consumed, as ``Packet.readFromSerial`` does.

        Raises ``struct.error`` if the fixed part is not exactly 12 bytes or a
        node record is shorter than 8 bytes.
        """
        self.nodes.clear()
        if unpackHeader:
            self.unpackHeader(serial)

        # The header's ``length`` covers only these fixed fields; the node
        # records that follow are read separately below. '<' gives standard
        # sizes and no alignment padding, independent of the host.
        fixed = serial.read(self.length)
        (self.tau, self.iteration, self.nIterations,
         self.nNodes, self.nodeSize) = struct.unpack('<fhHHH', fixed)

        # Nothing to read (also avoids a zero step in the range below).
        if self.nodeSize * self.nNodes == 0:
            return

        records = serial.read(self.nodeSize * self.nNodes)
        for offset in range(0, len(records), self.nodeSize):
            node = ArduinoIterationPacket.Node()
            node.unpack(records[offset:offset + self.nodeSize])
            self.nodes.append(node)
def totalMicrosecons(diff: timedelta) -> int:
    """Return the length of *diff* as an exact whole number of microseconds.

    Dividing one timedelta by another yields an int computed without any
    float rounding, so long (or negative) intervals stay exact.
    """
    one_microsecond = timedelta(microseconds=1)
    return diff // one_microsecond
@dataclass
class ArduinoBenchmarkPacket(Packet):
    """Packet carrying a start/end pair of microsecond counters from the Arduino."""

    # nan means "not received yet"; after unpackFromSerial these hold ints.
    arduinoMicrosStart: int = nan
    arduinoMicrosEnd: int = nan

    def unpackFromSerial(self, serial: Serial, unpackHeader=True):
        """Read this packet's payload (and optionally its header) from *serial*.

        Raises ``struct.error`` if the payload is not exactly 8 bytes.
        """
        if unpackHeader:
            self.unpackHeader(serial)
        # '<LL' = two little-endian uint32 with standard sizes. The previous
        # native 'LL' uses C ``unsigned long``, which is 8 bytes on 64-bit
        # Linux/macOS, so it expected a 16-byte payload there. Assumes the
        # sender transmits 32-bit values — confirm against the firmware.
        start, end = struct.unpack('<LL', serial.read(self.length))
        self.arduinoMicrosStart = start
        self.arduinoMicrosEnd = end

    @property
    def arduinoTimeElapsedMicros(self):
        """Microseconds from start to end (nan until a packet is decoded).

        The difference is reduced modulo 2**32 so that one wrap of a 32-bit
        counter between the two samples still gives the correct positive
        duration. Presumably the counters come from Arduino ``micros()`` —
        confirm.
        """
        return (self.arduinoMicrosEnd - self.arduinoMicrosStart) % (1 << 32)
# Dispatch table used by Packet.readFromSerial: header type id -> class that
# decodes the payload. Type ids missing here make readFromSerial return None.
packetMap = {
    PacketType.IterationData: ArduinoIterationPacket,
    PacketType.BenchmarkData: ArduinoBenchmarkPacket,
}