Skip to content

Commit

Permalink
Fix a hard to find problem in DeviceRedirection
Browse files Browse the repository at this point in the history
Specifc file transfers would hang. A minimal reproducer was hard to create. It turns out that the issue was in the encapsulation of the TLSSecurityLayer. If a virtual channel packet at a specific length (0x80), it would confuse the next layer into thinking it was the Security Header's licensing bytes causing the payload to be skipped from its usual processing.

See the recv() method of the TLSSecurityLayer class in `pyrdp/layer/rdp/security.py`.

Turns out that the security layer is not required if we are using modern RDP access mechanisms (anything more recent than RC4) so we can just remove the layer when the MITM is setup.

This fix happens to fix #139 as well.
  • Loading branch information
obilodeau committed Nov 15, 2022
1 parent 240b6d0 commit 7f67368
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions pyrdp/mitm/RDPMITM.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,18 +330,22 @@ def buildClipboardChannel(self, client: MCSServerChannel, server: MCSClientChann
:param server: MCS channel for the server side
"""

clientSecurity = self.state.createSecurityLayer(ParserMode.SERVER, True)
clientVirtualChannel = VirtualChannelLayer()
clientLayer = ClipboardLayer()
serverSecurity = self.state.createSecurityLayer(ParserMode.CLIENT, True)
serverVirtualChannel = VirtualChannelLayer()
serverLayer = ClipboardLayer()

clientLayer.addObserver(LayerLogger(self.getClientLog(MCSChannelName.CLIPBOARD)))
serverLayer.addObserver(LayerLogger(self.getServerLog(MCSChannelName.CLIPBOARD)))

LayerChainItem.chain(client, clientSecurity, clientVirtualChannel, clientLayer)
LayerChainItem.chain(server, serverSecurity, serverVirtualChannel, serverLayer)
if self.state.useTLS:
LayerChainItem.chain(client, clientVirtualChannel, clientLayer)
LayerChainItem.chain(server, serverVirtualChannel, serverLayer)
else:
clientSecurity = self.state.createSecurityLayer(ParserMode.SERVER, True)
serverSecurity = self.state.createSecurityLayer(ParserMode.CLIENT, True)
LayerChainItem.chain(client, clientSecurity, clientVirtualChannel, clientLayer)
LayerChainItem.chain(server, serverSecurity, serverVirtualChannel, serverLayer)

if self.config.disableActiveClipboardStealing:
mitm = PassiveClipboardStealer(self.config, clientLayer, serverLayer, self.getLog(MCSChannelName.CLIPBOARD),
Expand All @@ -358,18 +362,22 @@ def buildDeviceChannel(self, client: MCSServerChannel, server: MCSClientChannel)
:param server: MCS channel for the server side
"""

clientSecurity = self.state.createSecurityLayer(ParserMode.SERVER, True)
clientVirtualChannel = VirtualChannelLayer(activateShowProtocolFlag=False)
clientLayer = DeviceRedirectionLayer()
serverSecurity = self.state.createSecurityLayer(ParserMode.CLIENT, True)
serverVirtualChannel = VirtualChannelLayer(activateShowProtocolFlag=False)
serverLayer = DeviceRedirectionLayer()

clientLayer.addObserver(LayerLogger(self.getClientLog(MCSChannelName.DEVICE_REDIRECTION)))
serverLayer.addObserver(LayerLogger(self.getServerLog(MCSChannelName.DEVICE_REDIRECTION)))

LayerChainItem.chain(client, clientSecurity, clientVirtualChannel, clientLayer)
LayerChainItem.chain(server, serverSecurity, serverVirtualChannel, serverLayer)
if self.state.useTLS:
LayerChainItem.chain(client, clientVirtualChannel, clientLayer)
LayerChainItem.chain(server, serverVirtualChannel, serverLayer)
else:
clientSecurity = self.state.createSecurityLayer(ParserMode.SERVER, True)
serverSecurity = self.state.createSecurityLayer(ParserMode.CLIENT, True)
LayerChainItem.chain(client, clientSecurity, clientVirtualChannel, clientLayer)
LayerChainItem.chain(server, serverSecurity, serverVirtualChannel, serverLayer)

deviceRedirection = DeviceRedirectionMITM(clientLayer, serverLayer, self.getLog(MCSChannelName.DEVICE_REDIRECTION), self.statCounter, self.state, self.tcp)
self.channelMITMs[client.channelID] = deviceRedirection
Expand All @@ -387,13 +395,17 @@ def buildVirtualChannel(self, client: MCSServerChannel, server: MCSClientChannel
:param server: MCS channel for the server side
"""

clientSecurity = self.state.createSecurityLayer(ParserMode.SERVER, True)
clientLayer = RawLayer()
serverSecurity = self.state.createSecurityLayer(ParserMode.CLIENT, True)
serverLayer = RawLayer()

LayerChainItem.chain(client, clientSecurity, clientLayer)
LayerChainItem.chain(server, serverSecurity, serverLayer)
if self.state.useTLS:
LayerChainItem.chain(client, clientLayer)
LayerChainItem.chain(server, serverLayer)
else:
clientSecurity = self.state.createSecurityLayer(ParserMode.SERVER, True)
serverSecurity = self.state.createSecurityLayer(ParserMode.CLIENT, True)
LayerChainItem.chain(client, clientSecurity, clientLayer)
LayerChainItem.chain(server, serverSecurity, serverLayer)

mitm = VirtualChannelMITM(clientLayer, serverLayer, self.statCounter)
self.channelMITMs[client.channelID] = mitm
Expand Down

0 comments on commit 7f67368

Please sign in to comment.