-
Notifications
You must be signed in to change notification settings - Fork 4
/
tcp-rl.h
107 lines (80 loc) · 2.19 KB
/
tcp-rl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#ifndef TCP_RL_H
#define TCP_RL_H
#include "ns3/tcp-congestion-ops.h"
#include "ns3/opengym-module.h"
#include "ns3/tcp-socket-base.h"
namespace ns3 {
class TcpSocketBase;
class Time;
class TcpGymEnv;
// used to get pointer to Congestion Algorithm
class TcpSocketDerived : public TcpSocketBase
{
public:
static TypeId GetTypeId (void);
virtual TypeId GetInstanceTypeId () const;
TcpSocketDerived (void);
virtual ~TcpSocketDerived (void);
Ptr<TcpCongestionOps> GetCongestionControlAlgorithm ();
};
class TcpRlBase : public TcpCongestionOps
{
public:
/**
* \brief Get the type ID.
* \return the object TypeId
*/
static TypeId GetTypeId (void);
TcpRlBase ();
/**
* \brief Copy constructor.
* \param sock object to copy.
*/
TcpRlBase (const TcpRlBase& sock);
~TcpRlBase ();
virtual std::string GetName () const;
virtual uint32_t GetSsThresh (Ptr<const TcpSocketState> tcb, uint32_t bytesInFlight);
virtual void IncreaseWindow (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked);
virtual void PktsAcked (Ptr<TcpSocketState> tcb, uint32_t segmentsAcked, const Time& rtt);
virtual void CongestionStateSet (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCongState_t newState);
virtual void CwndEvent (Ptr<TcpSocketState> tcb, const TcpSocketState::TcpCAEvent_t event);
virtual Ptr<TcpCongestionOps> Fork ();
protected:
static uint64_t GenerateUuid ();
virtual void CreateGymEnv();
void ConnectSocketCallbacks();
// OpenGymEnv interface
Ptr<TcpSocketBase> m_tcpSocket;
Ptr<TcpGymEnv> m_tcpGymEnv;
};
class TcpRl : public TcpRlBase
{
public:
static TypeId GetTypeId (void);
TcpRl ();
TcpRl (const TcpRl& sock);
~TcpRl ();
virtual std::string GetName () const;
private:
virtual void CreateGymEnv();
// OpenGymEnv env
float m_reward {1.0};
float m_penalty {-100.0};
};
class TcpRlTimeBased : public TcpRlBase
{
public:
static TypeId GetTypeId (void);
TcpRlTimeBased ();
TcpRlTimeBased (const TcpRlTimeBased& sock);
~TcpRlTimeBased ();
virtual std::string GetName () const;
private:
virtual void CreateGymEnv();
Time m_duration;
Time m_timeStep;
float m_reward;
float m_penalty;
};
} // namespace ns3
#endif /* TCP_RL_H */