diff --git a/api/dataplane/v1/service/service_cluster_end.go b/api/dataplane/v1/service/service_cluster_end.go index 04b464e..42ce48f 100644 --- a/api/dataplane/v1/service/service_cluster_end.go +++ b/api/dataplane/v1/service/service_cluster_end.go @@ -2,6 +2,7 @@ package service import ( "context" + "io" "net" "sync" "time" @@ -17,17 +18,17 @@ import ( "google.golang.org/grpc/credentials/insecure" ) -type frontierNservice struct { +type frontierNend struct { frontier *clusterv1.Frontier - service Service + end *serviceEnd } type serviceClusterEnd struct { *delegate.UnimplementedDelegate cc clusterv1.ClusterServiceClient - bimap *mapmap.BiMap // bidirectional edgeID and frontierID - frontiers sync.Map // key: frontierID; value: frontierNservice + edgefrontiers *mapmap.BiMap // bidirectional edgeID and frontierID + frontiers sync.Map // key: frontierID; value: frontierNservice // options *serviceOption @@ -52,7 +53,7 @@ func newServiceClusterEnd(addr string, opts ...ServiceOption) (*serviceClusterEn } cc := clusterv1.NewClusterServiceClient(conn) - serviceClusterEnd := &serviceClusterEnd{ + end := &serviceClusterEnd{ cc: cc, serviceOption: &serviceOption{}, rpcs: map[string]geminio.RPC{}, @@ -61,71 +62,69 @@ func newServiceClusterEnd(addr string, opts ...ServiceOption) (*serviceClusterEn acceptMsgCh: make(chan geminio.Message, 128), closed: make(chan struct{}), } - serviceClusterEnd.serviceOption.delegate = serviceClusterEnd + end.serviceOption.delegate = end for _, opt := range opts { - opt(serviceClusterEnd.serviceOption) + opt(end.serviceOption) } - if serviceClusterEnd.serviceOption.logger == nil { - serviceClusterEnd.serviceOption.logger = armlog.DefaultLog + if end.serviceOption.logger == nil { + end.serviceOption.logger = armlog.DefaultLog } - return serviceClusterEnd, nil + return end, nil } -func (service *serviceClusterEnd) start() { +func (end *serviceClusterEnd) start() { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() for { select { case <-ticker.C: - err := service.update() + err := end.update() if err != nil { - service.logger.Warnf("cluster update err: %s", err) + end.logger.Warnf("cluster update err: %s", err) continue } - case <-service.closed: + case <-end.closed: return } } } -func (service *serviceClusterEnd) clear(frontierID string) { - service.updating.Lock() - defer service.updating.Unlock() +func (end *serviceClusterEnd) clear(frontierID string) { + end.updating.Lock() + defer end.updating.Unlock() - frontier, ok := service.frontiers.LoadAndDelete(frontierID) + frontier, ok := end.frontiers.LoadAndDelete(frontierID) if ok { - frontier.(*frontierNservice).service.Close() + frontier.(*frontierNend).end.Close() } // clear map for edgeID and frontierID + end.edgefrontiers.DelValue(frontierID) } -func (service *serviceClusterEnd) update() error { - rsp, err := service.cc.ListFrontiers(context.TODO(), &clusterv1.ListFrontiersRequest{}) +func (end *serviceClusterEnd) update() error { + rsp, err := end.cc.ListFrontiers(context.TODO(), &clusterv1.ListFrontiersRequest{}) if err != nil { - service.logger.Errorf("list frontiers err: %s", err) + end.logger.Errorf("list frontiers err: %s", err) return err } - service.updating.Lock() - defer service.updating.Unlock() - keeps := []string{} - removes := []Service{} + removes := []*frontierNend{} - service.frontiers.Range(func(key, value interface{}) bool { + end.frontiers.Range(func(key, value interface{}) bool { frontierID := key.(string) - frontierNservice := value.(*frontierNservice) + frontierNend := value.(*frontierNend) for _, frontier := range rsp.Frontiers { - if frontierEqual(frontierNservice.frontier, frontier) { + if frontierEqual(frontierNend.frontier, frontier) { keeps = append(keeps, frontierID) return true } } // out of date frontier - service.logger.Debugf("frontier: %v needs to be removed", key) - service.frontiers.Delete(key) - removes = append(removes, frontierNservice.service) + end.logger.Debugf("frontier: %v needs to be removed", key) + end.frontiers.Delete(key) + removes = append(removes, frontierNend) return true }) @@ -144,27 +143,80 @@ FOUND: // aysnc connect and close go func() { for _, remove := range removes { - remove.Close() + remove.end.Close() + // clear unavaiable frontier and it's edges + end.edgefrontiers.DelValue(remove.frontier.FrontierId) } for _, new := range news { - serviceEnd, err := service.newServiceEnd(new.AdvertisedSbAddr) + serviceEnd, err := end.newServiceEnd(new.AdvertisedSbAddr) if err != nil { - service.logger.Errorf("new service end err: %s", err) + end.logger.Errorf("new service end err: %s", err) continue } // new frontier - prev, ok := service.frontiers.Swap(new.FrontierId, &frontierNservice{ + prev, ok := end.frontiers.Swap(new.FrontierId, &frontierNend{ frontier: new, - service: serviceEnd, + end: serviceEnd, }) if ok { - prev.(*frontierNservice).service.Close() + prev.(*frontierNend).end.Close() } } }() return nil } +func (end *serviceClusterEnd) lookup(edgeID uint64) (string, *serviceEnd, error) { + var ( + frontier *clusterv1.Frontier + serviceEnd *serviceEnd + err error + ) + frontierID, ok := end.edgefrontiers.GetValue(edgeID) + // get or set edgeID to frontierID map + if !ok { + rsp, err := end.cc.GetFrontierByEdge(context.TODO(), &clusterv1.GetFrontierByEdgeIDRequest{ + EdgeId: edgeID, + }) + if err != nil { + end.logger.Errorf("get frontier by edge: %d err: %s", edgeID, err) + return "", nil, err + } + frontier = rsp.Fontier + frontierID = frontier.FrontierId + end.edgefrontiers.Set(edgeID, frontierID) + } + + fe, ok := end.frontiers.Load(frontierID) + if !ok { + serviceEnd, err = end.newServiceEnd(frontier.AdvertisedSbAddr) + if err != nil { + end.logger.Errorf("new service end err: %s while lookup", err) + return "", nil, err + } + found, ok := end.frontiers.Swap(frontierID, &frontierNend{ + frontier: frontier, + end: serviceEnd, + }) + if ok { + found.(*frontierNend).end.Close() + } + } else { + serviceEnd = fe.(*frontierNend).end + } + return frontierID.(string), serviceEnd, nil +} + +func (end *serviceClusterEnd) pickone() *serviceEnd { + var serviceEnd *serviceEnd + end.frontiers.Range(func(_, value interface{}) bool { + // return first one + serviceEnd = value.(*frontierNend).end + return false + }) + return serviceEnd +} + func frontierEqual(a, b *clusterv1.Frontier) bool { return a.AdvertisedSbAddr == b.AdvertisedEbAddr && a.FrontierId == b.FrontierId @@ -217,3 +269,134 @@ ERR: serviceEnd.Close() return nil, err } + +// multiplexer +func (end *serviceClusterEnd) AcceptStream() (geminio.Stream, error) { + st, ok := <-end.acceptStreamCh + if !ok { + return nil, io.EOF + } + return st, nil +} + +func (end *serviceClusterEnd) OpenStream(ctx context.Context, edgeID uint64) (geminio.Stream, error) { + frontierID, serviceEnd, err := end.lookup(edgeID) + if err != nil { + return nil, err + } + stream, err := serviceEnd.OpenStream(ctx, edgeID) + if err != nil { + end.clear(frontierID) + return stream, err + } + return stream, nil +} + +func (end *serviceClusterEnd) ListStreams() []geminio.Stream { + streams := []geminio.Stream{} + end.frontiers.Range(func(_, value interface{}) bool { + sts := value.(*frontierNend).end.ListStreams() + if sts != nil { + streams = append(streams, sts...) + } + return true + }) + return streams +} + +// Messager +func (end *serviceClusterEnd) NewMessage(data []byte) geminio.Message { + serviceEnd := end.pickone() + if serviceEnd == nil { + return nil + } + return serviceEnd.NewMessage(data) +} + +func (end *serviceClusterEnd) Publish(ctx context.Context, edgeID uint64, msg geminio.Message) error { + fronterID, serviceEnd, err := end.lookup(edgeID) + if err != nil { + return err + } + err = serviceEnd.Publish(ctx, edgeID, msg) + if err != nil { + end.clear(fronterID) + return err + } + return nil +} + +func (end *serviceClusterEnd) PublishAsync(ctx context.Context, edgeID uint64, msg geminio.Message, ch chan *geminio.Publish) (*geminio.Publish, error) { + fronterID, serviceEnd, err := end.lookup(edgeID) + if err != nil { + return nil, err + } + pub, err := serviceEnd.PublishAsync(ctx, edgeID, msg, ch) + if err != nil { + end.clear(fronterID) + return nil, err + } + return pub, err +} + +func (end *serviceClusterEnd) Receive(ctx context.Context) (geminio.Message, error) { + msg, ok := <-end.acceptMsgCh + if !ok { + return nil, io.EOF + } + return msg, nil +} + +// RPCer +func (end *serviceClusterEnd) NewRequest(data []byte) geminio.Request { + serviceEnd := end.pickone() + if serviceEnd == nil { + return nil + } + return serviceEnd.NewRequest(data) +} + +func (end *serviceClusterEnd) Call(ctx context.Context, edgeID uint64, method string, req geminio.Request) (geminio.Response, error) { + fronterID, serviceEnd, err := end.lookup(edgeID) + if err != nil { + return nil, err + } + rsp, err := serviceEnd.Call(ctx, edgeID, method, req) + if err != nil { + end.clear(fronterID) + return nil, err + } + return rsp, nil +} + +func (end *serviceClusterEnd) CallAsync(ctx context.Context, edgeID uint64, method string, req geminio.Request, ch chan *geminio.Call) (*geminio.Call, error) { + fronterID, serviceEnd, err := end.lookup(edgeID) + if err != nil { + return nil, err + } + call, err := serviceEnd.CallAsync(ctx, edgeID, method, req, ch) + if err != nil { + end.clear(fronterID) + return nil, err + } + return call, nil +} + +func (end *serviceClusterEnd) Register(ctx context.Context, method string, rpc geminio.RPC) error { + end.appMtx.Lock() + end.rpcs[method] = rpc + end.appMtx.Unlock() + + var ( + err error + ) + // TODO optimize it + end.frontiers.Range(func(key, value interface{}) bool { + err = value.(*frontierNend).end.Register(ctx, method, rpc) + if err != nil { + return false + } + return true + }) + return err +} diff --git a/api/dataplane/v1/service/service_end.go b/api/dataplane/v1/service/service_end.go index 7368006..3964fdf 100644 --- a/api/dataplane/v1/service/service_end.go +++ b/api/dataplane/v1/service/service_end.go @@ -58,8 +58,8 @@ func newServiceEnd(dialer client.Dialer, opts ...ServiceOption) (*serviceEnd, er } // Control Register -func (service *serviceEnd) RegisterGetEdgeID(ctx context.Context, getEdgeID GetEdgeID) error { - return service.End.Register(ctx, apis.RPCGetEdgeID, func(ctx context.Context, req geminio.Request, rsp geminio.Response) { +func (end *serviceEnd) RegisterGetEdgeID(ctx context.Context, getEdgeID GetEdgeID) error { + return end.End.Register(ctx, apis.RPCGetEdgeID, func(ctx context.Context, req geminio.Request, rsp geminio.Response) { id, err := getEdgeID(req.Data()) if err != nil { // we just deliver the err back @@ -73,8 +73,8 @@ func (service *serviceEnd) RegisterGetEdgeID(ctx context.Context, getEdgeID GetE }) } -func (service *serviceEnd) RegisterEdgeOnline(ctx context.Context, edgeOnline EdgeOnline) error { - return service.End.Register(ctx, apis.RPCEdgeOnline, func(ctx context.Context, req geminio.Request, rsp geminio.Response) { +func (end *serviceEnd) RegisterEdgeOnline(ctx context.Context, edgeOnline EdgeOnline) error { + return end.End.Register(ctx, apis.RPCEdgeOnline, func(ctx context.Context, req geminio.Request, rsp geminio.Response) { on := &apis.OnEdgeOnline{} err := json.Unmarshal(req.Data(), on) if err != nil { @@ -92,8 +92,8 @@ func (service *serviceEnd) RegisterEdgeOnline(ctx context.Context, edgeOnline Ed }) } -func (service *serviceEnd) RegisterEdgeOffline(ctx context.Context, edgeOffline EdgeOffline) error { - return service.End.Register(ctx, apis.RPCEdgeOffline, func(ctx context.Context, req geminio.Request, rsp geminio.Response) { +func (end *serviceEnd) RegisterEdgeOffline(ctx context.Context, edgeOffline EdgeOffline) error { + return end.End.Register(ctx, apis.RPCEdgeOffline, func(ctx context.Context, req geminio.Request, rsp geminio.Response) { off := &apis.OnEdgeOffline{} err := json.Unmarshal(req.Data(), off) if err != nil { @@ -110,11 +110,11 @@ func (service *serviceEnd) RegisterEdgeOffline(ctx context.Context, edgeOffline } // RPCer -func (service *serviceEnd) NewRequest(data []byte) geminio.Request { - return service.End.NewRequest(data) +func (end *serviceEnd) NewRequest(data []byte) geminio.Request { + return end.End.NewRequest(data) } -func (service *serviceEnd) Call(ctx context.Context, edgeID uint64, method string, req geminio.Request) (geminio.Response, error) { +func (end *serviceEnd) Call(ctx context.Context, edgeID uint64, method string, req geminio.Request) (geminio.Response, error) { // we append the likely short one to slice tail := make([]byte, 8) binary.BigEndian.PutUint64(tail, edgeID) @@ -127,7 +127,7 @@ func (service *serviceEnd) Call(ctx context.Context, edgeID uint64, method strin req.SetCustom(custom) // call real end - rsp, err := service.End.Call(ctx, method, req) + rsp, err := end.End.Call(ctx, method, req) if err != nil { return nil, err } @@ -136,7 +136,7 @@ func (service *serviceEnd) Call(ctx context.Context, edgeID uint64, method strin } // It's just like the go rpc way -func (service *serviceEnd) CallAsync(ctx context.Context, edgeID uint64, method string, req geminio.Request, ch chan *geminio.Call) (*geminio.Call, error) { +func (end *serviceEnd) CallAsync(ctx context.Context, edgeID uint64, method string, req geminio.Request, ch chan *geminio.Call) (*geminio.Call, error) { // we append the likely short one to slice // the last 8 bytes is for frontier tail := make([]byte, 8) @@ -150,7 +150,7 @@ func (service *serviceEnd) CallAsync(ctx context.Context, edgeID uint64, method req.SetCustom(custom) // call real end - call, err := service.End.CallAsync(ctx, method, req, ch) + call, err := end.End.CallAsync(ctx, method, req, ch) if err != nil { return nil, err } @@ -158,7 +158,7 @@ func (service *serviceEnd) CallAsync(ctx context.Context, edgeID uint64, method return call, nil } -func (service *serviceEnd) Register(ctx context.Context, method string, rpc geminio.RPC) error { +func (end *serviceEnd) Register(ctx context.Context, method string, rpc geminio.RPC) error { wrap := func(_ context.Context, req geminio.Request, rsp geminio.Response) { custom := req.Custom() if len(custom) < 8 { @@ -171,15 +171,15 @@ func (service *serviceEnd) Register(ctx context.Context, method string, rpc gemi rpc(ctx, req, rsp) return } - return service.End.Register(ctx, method, wrap) + return end.End.Register(ctx, method, wrap) } // Messager -func (service *serviceEnd) NewMessage(data []byte) geminio.Message { - return service.End.NewMessage(data) +func (end *serviceEnd) NewMessage(data []byte) geminio.Message { + return end.End.NewMessage(data) } -func (service *serviceEnd) Publish(ctx context.Context, edgeID uint64, msg geminio.Message) error { +func (end *serviceEnd) Publish(ctx context.Context, edgeID uint64, msg geminio.Message) error { tail := make([]byte, 8) binary.BigEndian.PutUint64(tail, edgeID) custom := msg.Custom() @@ -191,13 +191,13 @@ func (service *serviceEnd) Publish(ctx context.Context, edgeID uint64, msg gemin msg.SetCustom(custom) // publish real end - err := service.End.Publish(ctx, msg) + err := end.End.Publish(ctx, msg) msg.SetClientID(edgeID) // TODO we need to set EdgeID to let user know return err } -func (service *serviceEnd) PublishAsync(ctx context.Context, edgeID uint64, msg geminio.Message, ch chan *geminio.Publish) (*geminio.Publish, error) { +func (end *serviceEnd) PublishAsync(ctx context.Context, edgeID uint64, msg geminio.Message, ch chan *geminio.Publish) (*geminio.Publish, error) { tail := make([]byte, 8) binary.BigEndian.PutUint64(tail, edgeID) custom := msg.Custom() @@ -209,13 +209,13 @@ func (service *serviceEnd) PublishAsync(ctx context.Context, edgeID uint64, msg msg.SetCustom(custom) // publish async - pub, err := service.End.PublishAsync(ctx, msg, ch) + pub, err := end.End.PublishAsync(ctx, msg, ch) // TODO we need to set EdgeID to let user know return pub, err } -func (service *serviceEnd) Receive(ctx context.Context) (geminio.Message, error) { - msg, err := service.End.Receive(ctx) +func (end *serviceEnd) Receive(ctx context.Context) (geminio.Message, error) { + msg, err := end.End.Receive(ctx) if err != nil { return nil, err } @@ -232,21 +232,21 @@ func (service *serviceEnd) Receive(ctx context.Context) (geminio.Message, error) } // Multiplexer -func (service *serviceEnd) OpenStream(ctx context.Context, edgeID uint64) (geminio.Stream, error) { +func (end *serviceEnd) OpenStream(ctx context.Context, edgeID uint64) (geminio.Stream, error) { id := strconv.FormatUint(edgeID, 10) opt := options.OpenStream() opt.SetPeer(id) - return service.End.OpenStream(opt) + return end.End.OpenStream(opt) } -func (service *serviceEnd) AcceptStream() (geminio.Stream, error) { - return service.End.AcceptStream() +func (end *serviceEnd) AcceptStream() (geminio.Stream, error) { + return end.End.AcceptStream() } -func (service *serviceEnd) ListStreams() []geminio.Stream { - return service.End.ListStreams() +func (end *serviceEnd) ListStreams() []geminio.Stream { + return end.End.ListStreams() } -func (service *serviceEnd) Close() error { - return service.End.Close() +func (end *serviceEnd) Close() error { + return end.End.Close() } diff --git a/pkg/mapmap/bimap.go b/pkg/mapmap/bimap.go index 4473b6f..9927288 100644 --- a/pkg/mapmap/bimap.go +++ b/pkg/mapmap/bimap.go @@ -2,16 +2,18 @@ package mapmap import "sync" +// 1 value: n keys +// 1 key: 1 value type BiMap struct { mtx sync.RWMutex kv map[any]any - vk map[any]any + vk map[any]map[any]struct{} } func NewBiMap() *BiMap { return &BiMap{ kv: map[any]any{}, - vk: map[any]any{}, + vk: map[any]map[any]struct{}{}, } } @@ -20,5 +22,69 @@ func (bm *BiMap) Set(key, value any) { defer bm.mtx.Unlock() bm.kv[key] = value - bm.vk[value] = key + ks, ok := bm.vk[value] + if ok { + ks = map[any]struct{}{} + } + ks[key] = struct{}{} + bm.vk[value] = ks +} + +func (bm *BiMap) GetValue(key any) (any, bool) { + bm.mtx.RLock() + defer bm.mtx.RUnlock() + + value, ok := bm.kv[key] + return value, ok +} + +func (bm *BiMap) Del(key any) bool { + bm.mtx.Lock() + defer bm.mtx.Unlock() + + value, ok := bm.kv[key] + if !ok { + return false + } + delete(bm.kv, key) + ks, ok := bm.vk[value] + if ok { + delete(ks, key) + if len(ks) == 0 { + delete(bm.vk, value) + } else { + bm.vk[value] = ks + } + } + return true +} + +func (bm *BiMap) GetKeys(value any) ([]any, bool) { + bm.mtx.RLock() + defer bm.mtx.RUnlock() + + ks, ok := bm.vk[value] + if !ok { + return nil, false + } + slice := []any{} + for _, k := range ks { + slice = append(slice, k) + } + return slice, true +} + +func (bm *BiMap) DelValue(value any) bool { + bm.mtx.Lock() + defer bm.mtx.Unlock() + + ks, ok := bm.vk[value] + if !ok { + return false + } + delete(bm.vk, value) + for _, k := range ks { + delete(bm.kv, k) + } + return true }