diff --git a/pkg/agent/controller/packetsampling/packetin_test.go b/pkg/agent/controller/packetsampling/packetin_test.go new file mode 100644 index 00000000000..8f79925d998 --- /dev/null +++ b/pkg/agent/controller/packetsampling/packetin_test.go @@ -0,0 +1,215 @@ +package packetsampling + +import ( + "context" + "io" + "net" + "testing" + + "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcapgo" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ssh" + "golang.org/x/time/rate" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/klog/v2" + + "antrea.io/libOpenflow/openflow15" + "antrea.io/libOpenflow/protocol" + "antrea.io/libOpenflow/util" + "antrea.io/ofnet/ofctrl" + + "antrea.io/antrea/pkg/agent/config" + openflowtest "antrea.io/antrea/pkg/agent/openflow/testing" + crdv1alpha1 "antrea.io/antrea/pkg/apis/crd/v1alpha1" +) + +const ( + MaxNum = 5 +) + +var ( + testTag = uint8(7) + testUID = "1-2-3-4" + testSFTPUrl = "sftp://10.220.175.92:22/root/packetsamplings" +) + +func getTestPacketBytes(dstIP string, dscp uint8) []byte { + ipPacket := &protocol.IPv4{ + Version: 0x4, + IHL: 5, + Protocol: uint8(8), + DSCP: dscp, + Length: 20, + NWSrc: net.IP(pod1IPv4), + NWDst: net.IP(dstIP), + } + ethernetPkt := protocol.NewEthernet() + ethernetPkt.HWSrc = pod1MAC + ethernetPkt.Ethertype = protocol.IPv4_MSG + ethernetPkt.Data = ipPacket + pktBytes, _ := ethernetPkt.MarshalBinary() + return pktBytes +} + +func generateTestPsState(name string, writer *pcapgo.NgWriter, num int32) *packetSamplingState { + return &packetSamplingState{ + name: name, + maxNumCapturedPackets: MaxNum, + numCapturedPackets: num, + tag: testTag, + pcapngWriter: writer, + shouldSyncPackets: true, + updateRateLimiter: rate.NewLimiter(rate.Every(samplingStatusUpdatePeriod), 1), + } +} + +func generatePacketSampling(name string) *crdv1alpha1.PacketSampling { + return &crdv1alpha1.PacketSampling{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + UID: types.UID(testUID), + }, + Status: crdv1alpha1.PacketSamplingStatus{ + DataplaneTag: int8(testTag), + }, + Spec: crdv1alpha1.PacketSamplingSpec{ + FirstNSamplingConfig: &crdv1alpha1.FirstNSamplingConfig{ + Number: 5, + }, + FileServer: crdv1alpha1.BundleFileServer{ + URL: testSFTPUrl, + }, + Authentication: crdv1alpha1.BundleServerAuthConfiguration{ + AuthType: crdv1alpha1.BasicAuthentication, + AuthSecret: &v1.SecretReference{ + Name: "AAA", + Namespace: "default", + }, + }, + }, + } +} + +func generateTestSecret() *v1.Secret { + return &v1.Secret{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "AAA", + Namespace: "default", + }, + Data: map[string][]byte{ + "username": []byte("AAA"), + "password": []byte("BBBCCC"), + }, + } +} + +type testUploader struct { +} + +func (uploader *testUploader) upload(addr string, path string, config *ssh.ClientConfig, tarGzFile io.Reader) error { + klog.Info("Called test uploader") + return nil +} + +func TestHandlePacketSamplingPacketIn(t *testing.T) { + + invalidPktBytes := getTestPacketBytes("89.207.132.170", 0) + pktBytesPodToPod := getTestPacketBytes(pod2IPv4, testTag) + + // create test os + appFS := afero.NewMemMapFs() + appFS.MkdirAll("/tmp/packetsampling/packets", 0755) + file, err := appFS.Create(uidToPath(testUID)) + if err != nil { + t.Fatal("create pcapng file error: ", err) + } + + testWriter, err := pcapgo.NewNgWriter(file, layers.LinkTypeEthernet) + if err != nil { + t.Fatal("create test pcapng writer failed: ", err) + } + + tests := []struct { + name string + networkConfig *config.NetworkConfig + nodeConfig *config.NodeConfig + psState *packetSamplingState + pktIn *ofctrl.PacketIn + expectedPS *crdv1alpha1.PacketSampling + expectedErrStr string + expectedCalls func(mockOFClient *openflowtest.MockClient) + expectedNum int32 + }{ + { + name: "unrelated packets", + psState: generateTestPsState("ps-with-invalid-packet", testWriter, 0), + expectedPS: generatePacketSampling("ps-with-invalid-packet"), + pktIn: &ofctrl.PacketIn{ + PacketIn: &openflow15.PacketIn{ + Data: util.NewBuffer(invalidPktBytes), + }, + }, + expectedErrStr: "parsePacketIn error: PacketSampling for dataplane tag 0 not found in cache", + }, + { + name: "not hitting target number", + psState: generateTestPsState("ps-with-less-num", testWriter, 1), + expectedPS: generatePacketSampling("ps-with-less-num"), + expectedNum: 2, + pktIn: &ofctrl.PacketIn{ + PacketIn: &openflow15.PacketIn{ + Data: util.NewBuffer(pktBytesPodToPod), + }, + }, + }, + { + name: "hit target number", + psState: generateTestPsState("ps-with-max-num", testWriter, MaxNum-1), + expectedPS: generatePacketSampling("ps-with-max-num"), + expectedNum: MaxNum, + pktIn: &ofctrl.PacketIn{ + PacketIn: &openflow15.PacketIn{ + Data: util.NewBuffer(pktBytesPodToPod), + }, + }, + expectedCalls: func(mockOFClient *openflowtest.MockClient) { + mockOFClient.EXPECT().UninstallPacketSamplingFlows(testTag) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + psc := newFakePacketSamplingController(t, []runtime.Object{tt.expectedPS}, nil, &config.NodeConfig{Name: "node1"}) + if tt.expectedCalls != nil { + tt.expectedCalls(psc.mockOFClient) + } + stopCh := make(chan struct{}) + defer close(stopCh) + psc.crdInformerFactory.Start(stopCh) + psc.crdInformerFactory.WaitForCacheSync(stopCh) + psc.runningPacketSamplings[uint8(tt.expectedPS.Status.DataplaneTag)] = tt.psState + + err := psc.HandlePacketIn(tt.pktIn) + if err == nil { + assert.Equal(t, tt.expectedErrStr, "") + // check target num in status + ps, err := psc.crdClient.CrdV1alpha1().PacketSamplings().Get(context.TODO(), tt.expectedPS.Name, metav1.GetOptions{}) + assert.Nil(t, err) + assert.Equal(t, tt.expectedNum, ps.Status.NumCapturedPackets) + } else { + assert.Equal(t, tt.expectedErrStr, err.Error()) + } + + }) + } +} diff --git a/pkg/agent/controller/packetsampling/packetsampling_controller_test.go b/pkg/agent/controller/packetsampling/packetsampling_controller_test.go index bfb17445e29..143764ceaf7 100644 --- a/pkg/agent/controller/packetsampling/packetsampling_controller_test.go +++ b/pkg/agent/controller/packetsampling/packetsampling_controller_test.go @@ -88,7 +88,7 @@ type fakePacketSamplingController struct { func newFakePacketSamplingController(t *testing.T, initObjects []runtime.Object, networkConfig *config.NetworkConfig, nodeConfig *config.NodeConfig) *fakePacketSamplingController { controller := gomock.NewController(t) - kubeClient := fake.NewSimpleClientset(&pod1, &pod2, &pod3) + kubeClient := fake.NewSimpleClientset(&pod1, &pod2, &pod3, generateTestSecret()) mockOFClient := openflowtest.NewMockClient(controller) crdClient := fakeversioned.NewSimpleClientset(initObjects...) crdInformerFactory := crdinformers.NewSharedInformerFactory(crdClient, 0) @@ -100,7 +100,7 @@ func newFakePacketSamplingController(t *testing.T, initObjects []runtime.Object, _, serviceCIDRNet, _ := net.ParseCIDR("10.96.0.0/12") - tfController := &Controller{ + psController := &Controller{ kubeClient: kubeClient, crdClient: crdClient, packetSamplingInformer: packetSamplingInformer, @@ -113,10 +113,11 @@ func newFakePacketSamplingController(t *testing.T, initObjects []runtime.Object, serviceCIDR: serviceCIDRNet, queue: workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(minRetryDelay, maxRetryDelay), "PacketSampling"), runningPacketSamplings: make(map[uint8]*packetSamplingState), + sftpUploader: &testUploader{}, } return &fakePacketSamplingController{ - Controller: tfController, + Controller: psController, kubeClient: kubeClient, mockController: controller, mockOFClient: mockOFClient,