From 2e4014ae1c84a5bde4381959c0ee8a79de32fdbc Mon Sep 17 00:00:00 2001 From: Winlin Date: Mon, 9 Sep 2024 12:06:02 +0800 Subject: [PATCH] Proxy: Support proxy server for SRS. v7.0.16 (#4158) Please note that the proxy server is a new architecture or the next version of the Origin Cluster, which allows the publication of multiple streams. The SRS origin cluster consists of a group of origin servers designed to handle a large number of streams. ```text +-----------------------+ +---+ SRS Proxy(Deployment) +------+---------------------+ +-----------------+ | +-----------+-----------+ + + | LB(K8s Service) +--+ +(Redis/MESH) + SRS Origin Servers + +-----------------+ | +-----------+-----------+ + (Deployment) + +---+ SRS Proxy(Deployment) +------+---------------------+ +-----------------------+ ``` The new origin cluster is designed as a collection of proxy servers. For more information, see [Discussion #3634](https://github.com/ossrs/srs/discussions/3634). If you prefer to use the old origin cluster, please switch to a version before SRS 6.0. A proxy server can be used for a set of origin servers, which are isolated and dedicated origin servers. The main improvement in the new architecture is to store the state for origin servers in the proxy server, rather than using MESH to communicate between origin servers. With a proxy server, you can deploy origin servers as stateless servers, such as in a Kubernetes (K8s) deployment. Now that the proxy server is a stateful server, it uses Redis to store the states. For faster development, we use Go to develop the proxy server, instead of C/C++. Therefore, the proxy server itself is also stateless, with all states stored in the Redis server or cluster. This makes the new origin cluster architecture very powerful and robust. The proxy server is also an architecture designed to solve multiple process bottlenecks. You can run hundreds of SRS origin servers with one proxy server on the same machine. This solution can utilize multi-core machines, such as servers with 128 CPUs. Thus, we can keep SRS single-threaded and very simple. See https://github.com/ossrs/srs/discussions/3665#discussioncomment-6474441 for details. ```text +--------------------+ +-------+ SRS Origin Server + + +--------------------+ + +-----------------------+ + +--------------------+ + SRS Proxy(Deployment) +------+-------+ SRS Origin Server + +-----------------------+ + +--------------------+ + + +--------------------+ +-------+ SRS Origin Server + +--------------------+ ``` Keep in mind that the proxy server for the Origin Cluster is designed to handle many streams. To address the issue of many viewers, we will enhance the Edge Cluster to support more protocols. ```text +------------------+ +--------------------+ + SRS Edge Server +--+ +-------+ SRS Origin Server + +------------------+ + + +--------------------+ + + +------------------+ + +-----------------------+ + +--------------------+ + SRS Edge Server +--+-----+ SRS Proxy(Deployment) +------+-------+ SRS Origin Server + +------------------+ + +-----------------------+ + +--------------------+ + + +------------------+ + + +--------------------+ + SRS Edge Server +--+ +-------+ SRS Origin Server + +------------------+ +--------------------+ ``` With the new Origin Cluster and Edge Cluster, you have a media system capable of supporting a large number of streams and viewers. For example, you can publish 10,000 streams, each with 100,000 viewers. --------- Co-authored-by: Jacob Su --- proxy/.gitignore | 4 + proxy/Makefile | 23 + proxy/api.go | 272 ++++ proxy/debug.go | 20 + proxy/env.go | 197 +++ proxy/errors/errors.go | 270 ++++ proxy/errors/stack.go | 187 +++ proxy/go.mod | 13 + proxy/go.sum | 17 + proxy/http.go | 419 ++++++ proxy/logger/context.go | 43 + proxy/logger/log.go | 87 ++ proxy/main.go | 121 ++ proxy/rtc.go | 515 ++++++++ proxy/rtmp.go | 655 ++++++++++ proxy/rtmp/amf0.go | 771 +++++++++++ proxy/rtmp/rtmp.go | 1792 ++++++++++++++++++++++++++ proxy/signal.go | 44 + proxy/srs.go | 553 ++++++++ proxy/srt.go | 574 +++++++++ proxy/sync/map.go | 45 + proxy/utils.go | 276 ++++ proxy/version.go | 27 + trunk/conf/origin1-for-proxy.conf | 57 + trunk/conf/origin2-for-proxy.conf | 57 + trunk/conf/origin3-for-proxy.conf | 57 + trunk/doc/CHANGELOG.md | 1 + trunk/src/app/srs_app_st.cpp | 7 +- trunk/src/core/srs_core_version7.hpp | 2 +- 29 files changed, 7104 insertions(+), 2 deletions(-) create mode 100644 proxy/.gitignore create mode 100644 proxy/Makefile create mode 100644 proxy/api.go create mode 100644 proxy/debug.go create mode 100644 proxy/env.go create mode 100644 proxy/errors/errors.go create mode 100644 proxy/errors/stack.go create mode 100644 proxy/go.mod create mode 100644 proxy/go.sum create mode 100644 proxy/http.go create mode 100644 proxy/logger/context.go create mode 100644 proxy/logger/log.go create mode 100644 proxy/main.go create mode 100644 proxy/rtc.go create mode 100644 proxy/rtmp.go create mode 100644 proxy/rtmp/amf0.go create mode 100644 proxy/rtmp/rtmp.go create mode 100644 proxy/signal.go create mode 100644 proxy/srs.go create mode 100644 proxy/srt.go create mode 100644 proxy/sync/map.go create mode 100644 proxy/utils.go create mode 100644 proxy/version.go create mode 100644 trunk/conf/origin1-for-proxy.conf create mode 100644 trunk/conf/origin2-for-proxy.conf create mode 100644 trunk/conf/origin3-for-proxy.conf diff --git a/proxy/.gitignore b/proxy/.gitignore new file mode 100644 index 0000000000..c20f4b6782 --- /dev/null +++ b/proxy/.gitignore @@ -0,0 +1,4 @@ +.idea +srs-proxy +.env +.go-formarted \ No newline at end of file diff --git a/proxy/Makefile b/proxy/Makefile new file mode 100644 index 0000000000..29084d5b76 --- /dev/null +++ b/proxy/Makefile @@ -0,0 +1,23 @@ +.PHONY: all build test fmt clean run + +all: build + +build: fmt ./srs-proxy + +./srs-proxy: *.go + go build -o srs-proxy . + +test: + go test ./... + +fmt: ./.go-formarted + +./.go-formarted: *.go + touch .go-formarted + go fmt ./... + +clean: + rm -f srs-proxy .go-formarted + +run: fmt + go run . diff --git a/proxy/api.go b/proxy/api.go new file mode 100644 index 0000000000..04baa92526 --- /dev/null +++ b/proxy/api.go @@ -0,0 +1,272 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "fmt" + "net/http" + "os" + "strings" + "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +// srsHTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP, +// to proxy other HTTP API of SRS like the streams and clients, etc. +type srsHTTPAPIServer struct { + // The underlayer HTTP server. + server *http.Server + // The WebRTC server. + rtc *srsWebRTCServer + // The gracefully quit timeout, wait server to quit. + gracefulQuitTimeout time.Duration + // The wait group for all goroutines. + wg sync.WaitGroup +} + +func NewSRSHTTPAPIServer(opts ...func(*srsHTTPAPIServer)) *srsHTTPAPIServer { + v := &srsHTTPAPIServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsHTTPAPIServer) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil +} + +func (v *srsHTTPAPIServer) Run(ctx context.Context) error { + // Parse address to listen. + addr := envHttpAPI() + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + // Create server and handler. + mux := http.NewServeMux() + v.server = &http.Server{Addr: addr, Handler: mux} + logger.Df(ctx, "HTTP API server listen at %v", addr) + + // Shutdown the server gracefully when quiting. + go func() { + ctxParent := ctx + <-ctxParent.Done() + + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + + v.server.Shutdown(ctx) + }() + + // The basic version handler, also can be used as health check API. + logger.Df(ctx, "Handle /api/v1/versions by %v", addr) + mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { + apiResponse(ctx, w, r, map[string]string{ + "signature": Signature(), + "version": Version(), + }) + }) + + // The WebRTC WHIP API handler. + logger.Df(ctx, "Handle /rtc/v1/whip/ by %v", addr) + mux.HandleFunc("/rtc/v1/whip/", func(w http.ResponseWriter, r *http.Request) { + if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil { + apiError(ctx, w, r, err) + } + }) + + // The WebRTC WHEP API handler. + logger.Df(ctx, "Handle /rtc/v1/whep/ by %v", addr) + mux.HandleFunc("/rtc/v1/whep/", func(w http.ResponseWriter, r *http.Request) { + if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil { + apiError(ctx, w, r, err) + } + }) + + // Run HTTP API server. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If HTTP API server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "HTTP API accept err %+v", err) + } else { + logger.Df(ctx, "HTTP API server done") + } + } + }() + + return nil +} + +// systemAPI is the system HTTP API of the proxy server, for SRS media server to register the service +// to proxy server. It also provides some other system APIs like the status of proxy server, like exporter +// for Prometheus metrics. +type systemAPI struct { + // The underlayer HTTP server. + server *http.Server + // The gracefully quit timeout, wait server to quit. + gracefulQuitTimeout time.Duration + // The wait group for all goroutines. + wg sync.WaitGroup +} + +func NewSystemAPI(opts ...func(*systemAPI)) *systemAPI { + v := &systemAPI{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *systemAPI) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil +} + +func (v *systemAPI) Run(ctx context.Context) error { + // Parse address to listen. + addr := envSystemAPI() + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + // Create server and handler. + mux := http.NewServeMux() + v.server = &http.Server{Addr: addr, Handler: mux} + logger.Df(ctx, "System API server listen at %v", addr) + + // Shutdown the server gracefully when quiting. + go func() { + ctxParent := ctx + <-ctxParent.Done() + + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + + v.server.Shutdown(ctx) + }() + + // The basic version handler, also can be used as health check API. + logger.Df(ctx, "Handle /api/v1/versions by %v", addr) + mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { + apiResponse(ctx, w, r, map[string]string{ + "signature": Signature(), + "version": Version(), + }) + }) + + // The register service for SRS media servers. + logger.Df(ctx, "Handle /api/v1/srs/register by %v", addr) + mux.HandleFunc("/api/v1/srs/register", func(w http.ResponseWriter, r *http.Request) { + if err := func() error { + var deviceID, ip, serverID, serviceID, pid string + var rtmp, stream, api, srt, rtc []string + if err := ParseBody(r.Body, &struct { + // The IP of SRS, mandatory. + IP *string `json:"ip"` + // The server id of SRS, store in file, may not change, mandatory. + ServerID *string `json:"server"` + // The service id of SRS, always change when restarted, mandatory. + ServiceID *string `json:"service"` + // The process id of SRS, always change when restarted, mandatory. + PID *string `json:"pid"` + // The RTMP listen endpoints, mandatory. + RTMP *[]string `json:"rtmp"` + // The HTTP Stream listen endpoints, optional. + HTTP *[]string `json:"http"` + // The API listen endpoints, optional. + API *[]string `json:"api"` + // The SRT listen endpoints, optional. + SRT *[]string `json:"srt"` + // The RTC listen endpoints, optional. + RTC *[]string `json:"rtc"` + // The device id of SRS, optional. + DeviceID *string `json:"device_id"` + }{ + IP: &ip, DeviceID: &deviceID, + ServerID: &serverID, ServiceID: &serviceID, PID: &pid, + RTMP: &rtmp, HTTP: &stream, API: &api, SRT: &srt, RTC: &rtc, + }); err != nil { + return errors.Wrapf(err, "parse body") + } + + if ip == "" { + return errors.Errorf("empty ip") + } + if serverID == "" { + return errors.Errorf("empty server") + } + if serviceID == "" { + return errors.Errorf("empty service") + } + if pid == "" { + return errors.Errorf("empty pid") + } + if len(rtmp) == 0 { + return errors.Errorf("empty rtmp") + } + + server := NewSRSServer(func(srs *SRSServer) { + srs.IP, srs.DeviceID = ip, deviceID + srs.ServerID, srs.ServiceID, srs.PID = serverID, serviceID, pid + srs.RTMP, srs.HTTP, srs.API = rtmp, stream, api + srs.SRT, srs.RTC = srt, rtc + srs.UpdatedAt = time.Now() + }) + if err := srsLoadBalancer.Update(ctx, server); err != nil { + return errors.Wrapf(err, "update SRS server %+v", server) + } + + logger.Df(ctx, "Register SRS media server, %+v", server) + return nil + }(); err != nil { + apiError(ctx, w, r, err) + } + + type Response struct { + Code int `json:"code"` + PID string `json:"pid"` + } + + apiResponse(ctx, w, r, &Response{ + Code: 0, PID: fmt.Sprintf("%v", os.Getpid()), + }) + }) + + // Run System API server. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If System API server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "System API accept err %+v", err) + } else { + logger.Df(ctx, "System API server done") + } + } + }() + + return nil +} diff --git a/proxy/debug.go b/proxy/debug.go new file mode 100644 index 0000000000..3a389b8bbd --- /dev/null +++ b/proxy/debug.go @@ -0,0 +1,20 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "net/http" + + "srs-proxy/logger" +) + +func handleGoPprof(ctx context.Context) { + if addr := envGoPprof(); addr != "" { + go func() { + logger.Df(ctx, "Start Go pprof at %v", addr) + http.ListenAndServe(addr, nil) + }() + } +} diff --git a/proxy/env.go b/proxy/env.go new file mode 100644 index 0000000000..0c201bb1d6 --- /dev/null +++ b/proxy/env.go @@ -0,0 +1,197 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "os" + "path" + + "github.com/joho/godotenv" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +// loadEnvFile loads the environment variables from file. Note that we only use .env file. +func loadEnvFile(ctx context.Context) error { + if workDir, err := os.Getwd(); err != nil { + return errors.Wrapf(err, "getpwd") + } else { + envFile := path.Join(workDir, ".env") + if _, err := os.Stat(envFile); err == nil { + if err := godotenv.Load(envFile); err != nil { + return errors.Wrapf(err, "load %v", envFile) + } + } + } + + return nil +} + +// buildDefaultEnvironmentVariables setups the default environment variables. +func buildDefaultEnvironmentVariables(ctx context.Context) { + // Whether enable the Go pprof. + setEnvDefault("GO_PPROF", "") + // Force shutdown timeout. + setEnvDefault("PROXY_FORCE_QUIT_TIMEOUT", "30s") + // Graceful quit timeout. + setEnvDefault("PROXY_GRACE_QUIT_TIMEOUT", "20s") + + // The HTTP API server. + setEnvDefault("PROXY_HTTP_API", "11985") + // The HTTP web server. + setEnvDefault("PROXY_HTTP_SERVER", "18080") + // The RTMP media server. + setEnvDefault("PROXY_RTMP_SERVER", "11935") + // The WebRTC media server, via UDP protocol. + setEnvDefault("PROXY_WEBRTC_SERVER", "18000") + // The SRT media server, via UDP protocol. + setEnvDefault("PROXY_SRT_SERVER", "20080") + // The API server of proxy itself. + setEnvDefault("PROXY_SYSTEM_API", "12025") + // The static directory for web server. + setEnvDefault("PROXY_STATIC_FILES", "../trunk/research") + + // The load balancer, use redis or memory. + setEnvDefault("PROXY_LOAD_BALANCER_TYPE", "memory") + // The redis server host. + setEnvDefault("PROXY_REDIS_HOST", "127.0.0.1") + // The redis server port. + setEnvDefault("PROXY_REDIS_PORT", "6379") + // The redis server password. + setEnvDefault("PROXY_REDIS_PASSWORD", "") + // The redis server db. + setEnvDefault("PROXY_REDIS_DB", "0") + + // Whether enable the default backend server, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_ENABLED", "off") + // Default backend server IP, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1") + // Default backend server port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_RTMP", "1935") + // Default backend api port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_API", "1985") + // Default backend udp rtc port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_RTC", "8000") + // Default backend udp srt port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_SRT", "10080") + + logger.Df(ctx, "load .env as GO_PPROF=%v, "+ + "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ + "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+ + "PROXY_WEBRTC_SERVER=%v, PROXY_SRT_SERVER=%v, "+ + "PROXY_SYSTEM_API=%v, PROXY_STATIC_FILES=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+ + "PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+ + "PROXY_DEFAULT_BACKEND_HTTP=%v, PROXY_DEFAULT_BACKEND_API=%v, "+ + "PROXY_DEFAULT_BACKEND_RTC=%v, PROXY_DEFAULT_BACKEND_SRT=%v, "+ + "PROXY_LOAD_BALANCER_TYPE=%v, PROXY_REDIS_HOST=%v, PROXY_REDIS_PORT=%v, "+ + "PROXY_REDIS_PASSWORD=%v, PROXY_REDIS_DB=%v", + envGoPprof(), + envForceQuitTimeout(), envGraceQuitTimeout(), + envHttpAPI(), envHttpServer(), envRtmpServer(), + envWebRTCServer(), envSRTServer(), + envSystemAPI(), envStaticFiles(), envDefaultBackendEnabled(), + envDefaultBackendIP(), envDefaultBackendRTMP(), + envDefaultBackendHttp(), envDefaultBackendAPI(), + envDefaultBackendRTC(), envDefaultBackendSRT(), + envLoadBalancerType(), envRedisHost(), envRedisPort(), + envRedisPassword(), envRedisDB(), + ) +} + +func envStaticFiles() string { + return os.Getenv("PROXY_STATIC_FILES") +} + +func envDefaultBackendSRT() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_SRT") +} + +func envDefaultBackendRTC() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_RTC") +} + +func envDefaultBackendAPI() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_API") +} + +func envSRTServer() string { + return os.Getenv("PROXY_SRT_SERVER") +} + +func envWebRTCServer() string { + return os.Getenv("PROXY_WEBRTC_SERVER") +} + +func envDefaultBackendHttp() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_HTTP") +} + +func envRedisDB() string { + return os.Getenv("PROXY_REDIS_DB") +} + +func envRedisPassword() string { + return os.Getenv("PROXY_REDIS_PASSWORD") +} + +func envRedisPort() string { + return os.Getenv("PROXY_REDIS_PORT") +} + +func envRedisHost() string { + return os.Getenv("PROXY_REDIS_HOST") +} + +func envLoadBalancerType() string { + return os.Getenv("PROXY_LOAD_BALANCER_TYPE") +} + +func envDefaultBackendRTMP() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_RTMP") +} + +func envDefaultBackendIP() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_IP") +} + +func envDefaultBackendEnabled() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_ENABLED") +} + +func envGraceQuitTimeout() string { + return os.Getenv("PROXY_GRACE_QUIT_TIMEOUT") +} + +func envForceQuitTimeout() string { + return os.Getenv("PROXY_FORCE_QUIT_TIMEOUT") +} + +func envGoPprof() string { + return os.Getenv("GO_PPROF") +} + +func envSystemAPI() string { + return os.Getenv("PROXY_SYSTEM_API") +} + +func envRtmpServer() string { + return os.Getenv("PROXY_RTMP_SERVER") +} + +func envHttpServer() string { + return os.Getenv("PROXY_HTTP_SERVER") +} + +func envHttpAPI() string { + return os.Getenv("PROXY_HTTP_API") +} + +// setEnvDefault set env key=value if not set. +func setEnvDefault(key, value string) { + if os.Getenv(key) == "" { + os.Setenv(key, value) + } +} diff --git a/proxy/errors/errors.go b/proxy/errors/errors.go new file mode 100644 index 0000000000..257bc3ccda --- /dev/null +++ b/proxy/errors/errors.go @@ -0,0 +1,270 @@ +// Package errors provides simple error handling primitives. +// +// The traditional error handling idiom in Go is roughly akin to +// +// if err != nil { +// return err +// } +// +// which applied recursively up the call stack results in error reports +// without context or debugging information. The errors package allows +// programmers to add context to the failure path in their code in a way +// that does not destroy the original value of the error. +// +// Adding context to an error +// +// The errors.Wrap function returns a new error that adds context to the +// original error by recording a stack trace at the point Wrap is called, +// and the supplied message. For example +// +// _, err := ioutil.ReadAll(r) +// if err != nil { +// return errors.Wrap(err, "read failed") +// } +// +// If additional control is required the errors.WithStack and errors.WithMessage +// functions destructure errors.Wrap into its component operations of annotating +// an error with a stack trace and an a message, respectively. +// +// Retrieving the cause of an error +// +// Using errors.Wrap constructs a stack of errors, adding context to the +// preceding error. Depending on the nature of the error it may be necessary +// to reverse the operation of errors.Wrap to retrieve the original error +// for inspection. Any error value which implements this interface +// +// type causer interface { +// Cause() error +// } +// +// can be inspected by errors.Cause. errors.Cause will recursively retrieve +// the topmost error which does not implement causer, which is assumed to be +// the original cause. For example: +// +// switch err := errors.Cause(err).(type) { +// case *MyError: +// // handle specifically +// default: +// // unknown error +// } +// +// causer interface is not exported by this package, but is considered a part +// of stable public API. +// +// Formatted printing of errors +// +// All error values returned from this package implement fmt.Formatter and can +// be formatted by the fmt package. The following verbs are supported +// +// %s print the error. If the error has a Cause it will be +// printed recursively +// %v see %s +// %+v extended format. Each Frame of the error's StackTrace will +// be printed in detail. +// +// Retrieving the stack trace of an error or wrapper +// +// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are +// invoked. This information can be retrieved with the following interface. +// +// type stackTracer interface { +// StackTrace() errors.StackTrace +// } +// +// Where errors.StackTrace is defined as +// +// type StackTrace []Frame +// +// The Frame type represents a call site in the stack trace. Frame supports +// the fmt.Formatter interface that can be used for printing information about +// the stack trace of this error. For example: +// +// if err, ok := err.(stackTracer); ok { +// for _, f := range err.StackTrace() { +// fmt.Printf("%+s:%d", f) +// } +// } +// +// stackTracer interface is not exported by this package, but is considered a part +// of stable public API. +// +// See the documentation for Frame.Format for more details. +// Fork from https://github.com/pkg/errors +package errors + +import ( + "fmt" + "io" +) + +// New returns an error with the supplied message. +// New also records the stack trace at the point it was called. +func New(message string) error { + return &fundamental{ + msg: message, + stack: callers(), + } +} + +// Errorf formats according to a format specifier and returns the string +// as a value that satisfies error. +// Errorf also records the stack trace at the point it was called. +func Errorf(format string, args ...interface{}) error { + return &fundamental{ + msg: fmt.Sprintf(format, args...), + stack: callers(), + } +} + +// fundamental is an error that has a message and a stack, but no caller. +type fundamental struct { + msg string + *stack +} + +func (f *fundamental) Error() string { return f.msg } + +func (f *fundamental) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + io.WriteString(s, f.msg) + f.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, f.msg) + case 'q': + fmt.Fprintf(s, "%q", f.msg) + } +} + +// WithStack annotates err with a stack trace at the point WithStack was called. +// If err is nil, WithStack returns nil. +func WithStack(err error) error { + if err == nil { + return nil + } + return &withStack{ + err, + callers(), + } +} + +type withStack struct { + error + *stack +} + +func (w *withStack) Cause() error { return w.error } + +func (w *withStack) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v", w.Cause()) + w.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, w.Error()) + case 'q': + fmt.Fprintf(s, "%q", w.Error()) + } +} + +// Wrap returns an error annotating err with a stack trace +// at the point Wrap is called, and the supplied message. +// If err is nil, Wrap returns nil. +func Wrap(err error, message string) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: message, + } + return &withStack{ + err, + callers(), + } +} + +// Wrapf returns an error annotating err with a stack trace +// at the point Wrapf is call, and the format specifier. +// If err is nil, Wrapf returns nil. +func Wrapf(err error, format string, args ...interface{}) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: fmt.Sprintf(format, args...), + } + return &withStack{ + err, + callers(), + } +} + +// WithMessage annotates err with a new message. +// If err is nil, WithMessage returns nil. +func WithMessage(err error, message string) error { + if err == nil { + return nil + } + return &withMessage{ + cause: err, + msg: message, + } +} + +type withMessage struct { + cause error + msg string +} + +func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } +func (w *withMessage) Cause() error { return w.cause } + +func (w *withMessage) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v\n", w.Cause()) + io.WriteString(s, w.msg) + return + } + fallthrough + case 's', 'q': + io.WriteString(s, w.Error()) + } +} + +// Cause returns the underlying cause of the error, if possible. +// An error value has a cause if it implements the following +// interface: +// +// type causer interface { +// Cause() error +// } +// +// If the error does not implement Cause, the original error will +// be returned. If the error is nil, nil will be returned without further +// investigation. +func Cause(err error) error { + type causer interface { + Cause() error + } + + for err != nil { + cause, ok := err.(causer) + if !ok { + break + } + err = cause.Cause() + } + return err +} diff --git a/proxy/errors/stack.go b/proxy/errors/stack.go new file mode 100644 index 0000000000..6c42db5a85 --- /dev/null +++ b/proxy/errors/stack.go @@ -0,0 +1,187 @@ +// Fork from https://github.com/pkg/errors +package errors + +import ( + "fmt" + "io" + "path" + "runtime" + "strings" +) + +// Frame represents a program counter inside a stack frame. +type Frame uintptr + +// pc returns the program counter for this frame; +// multiple frames may have the same PC value. +func (f Frame) pc() uintptr { return uintptr(f) - 1 } + +// file returns the full path to the file that contains the +// function for this Frame's pc. +func (f Frame) file() string { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return "unknown" + } + file, _ := fn.FileLine(f.pc()) + return file +} + +// line returns the line number of source code of the +// function for this Frame's pc. +func (f Frame) line() int { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return 0 + } + _, line := fn.FileLine(f.pc()) + return line +} + +// Format formats the frame according to the fmt.Formatter interface. +// +// %s source file +// %d source line +// %n function name +// %v equivalent to %s:%d +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+s path of source file relative to the compile time GOPATH +// %+v equivalent to %+s:%d +func (f Frame) Format(s fmt.State, verb rune) { + switch verb { + case 's': + switch { + case s.Flag('+'): + pc := f.pc() + fn := runtime.FuncForPC(pc) + if fn == nil { + io.WriteString(s, "unknown") + } else { + file, _ := fn.FileLine(pc) + fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file) + } + default: + io.WriteString(s, path.Base(f.file())) + } + case 'd': + fmt.Fprintf(s, "%d", f.line()) + case 'n': + name := runtime.FuncForPC(f.pc()).Name() + io.WriteString(s, funcname(name)) + case 'v': + f.Format(s, 's') + io.WriteString(s, ":") + f.Format(s, 'd') + } +} + +// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). +type StackTrace []Frame + +// Format formats the stack of Frames according to the fmt.Formatter interface. +// +// %s lists source files for each Frame in the stack +// %v lists the source file and line number for each Frame in the stack +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+v Prints filename, function, and line number for each Frame in the stack. +func (st StackTrace) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case s.Flag('+'): + for _, f := range st { + fmt.Fprintf(s, "\n%+v", f) + } + case s.Flag('#'): + fmt.Fprintf(s, "%#v", []Frame(st)) + default: + fmt.Fprintf(s, "%v", []Frame(st)) + } + case 's': + fmt.Fprintf(s, "%s", []Frame(st)) + } +} + +// stack represents a stack of program counters. +type stack []uintptr + +func (s *stack) Format(st fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case st.Flag('+'): + for _, pc := range *s { + f := Frame(pc) + fmt.Fprintf(st, "\n%+v", f) + } + } + } +} + +func (s *stack) StackTrace() StackTrace { + f := make([]Frame, len(*s)) + for i := 0; i < len(f); i++ { + f[i] = Frame((*s)[i]) + } + return f +} + +func callers() *stack { + const depth = 32 + var pcs [depth]uintptr + n := runtime.Callers(3, pcs[:]) + var st stack = pcs[0:n] + return &st +} + +// funcname removes the path prefix component of a function's name reported by func.Name(). +func funcname(name string) string { + i := strings.LastIndex(name, "/") + name = name[i+1:] + i = strings.Index(name, ".") + return name[i+1:] +} + +func trimGOPATH(name, file string) string { + // Here we want to get the source file path relative to the compile time + // GOPATH. As of Go 1.6.x there is no direct way to know the compiled + // GOPATH at runtime, but we can infer the number of path segments in the + // GOPATH. We note that fn.Name() returns the function name qualified by + // the import path, which does not include the GOPATH. Thus we can trim + // segments from the beginning of the file path until the number of path + // separators remaining is one more than the number of path separators in + // the function name. For example, given: + // + // GOPATH /home/user + // file /home/user/src/pkg/sub/file.go + // fn.Name() pkg/sub.Type.Method + // + // We want to produce: + // + // pkg/sub/file.go + // + // From this we can easily see that fn.Name() has one less path separator + // than our desired output. We count separators from the end of the file + // path until it finds two more than in the function name and then move + // one character forward to preserve the initial path segment without a + // leading separator. + const sep = "/" + goal := strings.Count(name, sep) + 2 + i := len(file) + for n := 0; n < goal; n++ { + i = strings.LastIndex(file[:i], sep) + if i == -1 { + // not enough separators found, set i so that the slice expression + // below leaves file unmodified + i = -len(sep) + break + } + } + // get back to 0 or trim the leading separator + file = file[i+len(sep):] + return file +} diff --git a/proxy/go.mod b/proxy/go.mod new file mode 100644 index 0000000000..2e2a17ab34 --- /dev/null +++ b/proxy/go.mod @@ -0,0 +1,13 @@ +module srs-proxy + +go 1.18 + +require ( + github.com/go-redis/redis/v8 v8.11.5 + github.com/joho/godotenv v1.5.1 +) + +require ( + github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect +) diff --git a/proxy/go.sum b/proxy/go.sum new file mode 100644 index 0000000000..1efc5318ed --- /dev/null +++ b/proxy/go.sum @@ -0,0 +1,17 @@ +github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= +github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/proxy/http.go b/proxy/http.go new file mode 100644 index 0000000000..f02af02a30 --- /dev/null +++ b/proxy/http.go @@ -0,0 +1,419 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "strconv" + "strings" + stdSync "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +// srsHTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS, +// HLS, etc. The proxy server will figure out which SRS origin server to proxy to, then proxy +// the request to the origin server. +type srsHTTPStreamServer struct { + // The underlayer HTTP server. + server *http.Server + // The gracefully quit timeout, wait server to quit. + gracefulQuitTimeout time.Duration + // The wait group for all goroutines. + wg stdSync.WaitGroup +} + +func NewSRSHTTPStreamServer(opts ...func(*srsHTTPStreamServer)) *srsHTTPStreamServer { + v := &srsHTTPStreamServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsHTTPStreamServer) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil +} + +func (v *srsHTTPStreamServer) Run(ctx context.Context) error { + // Parse address to listen. + addr := envHttpServer() + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + // Create server and handler. + mux := http.NewServeMux() + v.server = &http.Server{Addr: addr, Handler: mux} + logger.Df(ctx, "HTTP Stream server listen at %v", addr) + + // Shutdown the server gracefully when quiting. + go func() { + ctxParent := ctx + <-ctxParent.Done() + + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + + v.server.Shutdown(ctx) + }() + + // The basic version handler, also can be used as health check API. + logger.Df(ctx, "Handle /api/v1/versions by %v", addr) + mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { + type Response struct { + Code int `json:"code"` + PID string `json:"pid"` + Data struct { + Major int `json:"major"` + Minor int `json:"minor"` + Revision int `json:"revision"` + Version string `json:"version"` + } `json:"data"` + } + + res := Response{Code: 0, PID: fmt.Sprintf("%v", os.Getpid())} + res.Data.Major = VersionMajor() + res.Data.Minor = VersionMinor() + res.Data.Revision = VersionRevision() + res.Data.Version = Version() + + apiResponse(ctx, w, r, &res) + }) + + // The static web server, for the web pages. + var staticServer http.Handler + if staticFiles := envStaticFiles(); staticFiles != "" { + if _, err := os.Stat(staticFiles); err != nil { + return errors.Wrapf(err, "invalid static files %v", staticFiles) + } + + staticServer = http.FileServer(http.Dir(staticFiles)) + logger.Df(ctx, "Handle static files at %v", staticFiles) + } + + // The default handler, for both static web server and streaming server. + logger.Df(ctx, "Handle / by %v", addr) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + // For HLS streaming, we will proxy the request to the streaming server. + if strings.HasSuffix(r.URL.Path, ".m3u8") { + unifiedURL, fullURL := convertURLToStreamURL(r) + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + http.Error(w, fmt.Sprintf("build stream url by %v from %v", unifiedURL, fullURL), http.StatusBadRequest) + return + } + + stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSPlayStream(func(s *HLSPlayStream) { + s.SRSProxyBackendHLSID = logger.GenerateContextID() + s.StreamURL, s.FullURL = streamURL, fullURL + })) + + stream.Initialize(ctx).ServeHTTP(w, r) + return + } + + // For HTTP streaming, we will proxy the request to the streaming server. + if strings.HasSuffix(r.URL.Path, ".flv") || + strings.HasSuffix(r.URL.Path, ".ts") { + // If SPBHID is specified, it must be a HLS stream client. + if srsProxyBackendID := r.URL.Query().Get("spbhid"); srsProxyBackendID != "" { + if stream, err := srsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil { + http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest) + } else { + stream.Initialize(ctx).ServeHTTP(w, r) + } + return + } + + // Use HTTP pseudo streaming to proxy the request. + NewHTTPFlvTsConnection(func(c *HTTPFlvTsConnection) { + c.ctx = ctx + }).ServeHTTP(w, r) + return + } + + // Serve by static server. + if staticServer != nil { + staticServer.ServeHTTP(w, r) + return + } + + http.NotFound(w, r) + }) + + // Run HTTP server. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If HTTP Stream server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "HTTP Stream accept err %+v", err) + } else { + logger.Df(ctx, "HTTP Stream server done") + } + } + }() + + return nil +} + +// HTTPFlvTsConnection is an HTTP pseudo streaming connection, such as an HTTP-FLV or HTTP-TS +// connection. There is no state need to be sync between proxy servers. +// +// When we got an HTTP FLV or TS request, we will parse the stream URL from the HTTP request, +// then proxy to the corresponding backend server. All state is in the HTTP request, so this +// connection is stateless. +type HTTPFlvTsConnection struct { + // The context for HTTP streaming. + ctx context.Context +} + +func NewHTTPFlvTsConnection(opts ...func(*HTTPFlvTsConnection)) *HTTPFlvTsConnection { + v := &HTTPFlvTsConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + ctx := logger.WithContext(v.ctx) + + if err := v.serve(ctx, w, r); err != nil { + apiError(ctx, w, r, err) + } else { + logger.Df(ctx, "HTTP client done") + } +} + +func (v *HTTPFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Build the stream URL in vhost/app/stream schema. + unifiedURL, fullURL := convertURLToStreamURL(r) + logger.Df(ctx, "Got HTTP client from %v for %v", r.RemoteAddr, fullURL) + + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + return errors.Wrapf(err, "build stream url %v", unifiedURL) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.serveByBackend(ctx, w, r, backend); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { + // Parse HTTP port from backend. + if len(backend.HTTP) == 0 { + return errors.Errorf("no http stream server") + } + + var httpPort int + if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse http port %v", backend.HTTP[0]) + } else { + httpPort = int(iv) + } + + // Connect to backend SRS server via HTTP client. + backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) + req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil) + if err != nil { + return errors.Wrapf(err, "create request to %v", backendURL) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return errors.Wrapf(err, "do request to %v", backendURL) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) + } + + // Copy all headers from backend to client. + w.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + + logger.Df(ctx, "HTTP start streaming") + + // Proxy the stream from backend to client. + if _, err := io.Copy(w, resp.Body); err != nil { + return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL) + } + + return nil +} + +// HLSPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS +// clients will share this object, and they do not use the same ctx among proxy servers. +// +// Unlike the HTTP FLV or TS connection, HLS client may request the m3u8 or ts via different HTTP connections. +// Especially for requesting ts, we need to identify the stream URl or backend server for it. So we create +// the spbhid which can be seen as the hash of stream URL or backend server. The spbhid enable us to convert +// to the stream URL and then query the backend server to serve it. +type HLSPlayStream struct { + // The context for HLS streaming. + ctx context.Context + + // The spbhid, used to identify the backend server. + SRSProxyBackendHLSID string `json:"spbhid"` + // The stream URL in vhost/app/stream schema. + StreamURL string `json:"stream_url"` + // The full request URL for HLS streaming + FullURL string `json:"full_url"` +} + +func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream { + v := &HLSPlayStream{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *HLSPlayStream) Initialize(ctx context.Context) *HLSPlayStream { + if v.ctx == nil { + v.ctx = logger.WithContext(ctx) + } + return v +} + +func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + if err := v.serve(v.ctx, w, r); err != nil { + apiError(v.ctx, w, r, err) + } else { + logger.Df(v.ctx, "HLS client %v for %v with %v done", + v.SRSProxyBackendHLSID, v.StreamURL, r.URL.Path) + } +} + +func (v *HLSPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + ctx, streamURL, fullURL := v.ctx, v.StreamURL, v.FullURL + + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.serveByBackend(ctx, w, r, backend); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *HLSPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { + // Parse HTTP port from backend. + if len(backend.HTTP) == 0 { + return errors.Errorf("no rtmp server") + } + + var httpPort int + if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse http port %v", backend.HTTP[0]) + } else { + httpPort = int(iv) + } + + // Connect to backend SRS server via HTTP client. + backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) + if r.URL.RawQuery != "" { + backendURL += "?" + r.URL.RawQuery + } + + req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil) + if err != nil { + return errors.Wrapf(err, "create request to %v", backendURL) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return errors.Errorf("do request to %v EOF", backendURL) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) + } + + // Copy all headers from backend to client. + w.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + + // For TS file, directly copy it. + if !strings.HasSuffix(r.URL.Path, ".m3u8") { + if _, err := io.Copy(w, resp.Body); err != nil { + return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL) + } + + return nil + } + + // Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts + // URL, to identify the stream to specified backend server. The spbhid is the SRS Proxy Backend HLS ID. + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return errors.Wrapf(err, "read stream from %v", backendURL) + } + + m3u8 := string(b) + if strings.Contains(m3u8, ".ts?") { + m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID)) + } else { + m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID)) + } + + if _, err := io.Copy(w, strings.NewReader(m3u8)); err != nil { + return errors.Wrapf(err, "proxy m3u8 client to %v", backendURL) + } + + return nil +} diff --git a/proxy/logger/context.go b/proxy/logger/context.go new file mode 100644 index 0000000000..ef15a7d4fb --- /dev/null +++ b/proxy/logger/context.go @@ -0,0 +1,43 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package logger + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" +) + +type key string + +var cidKey key = "cid.proxy.ossrs.org" + +// generateContextID generates a random context id in string. +func GenerateContextID() string { + randomBytes := make([]byte, 32) + _, _ = rand.Read(randomBytes) + hash := sha256.Sum256(randomBytes) + hashString := hex.EncodeToString(hash[:]) + cid := hashString[:7] + return cid +} + +// WithContext creates a new context with cid, which will be used for log. +func WithContext(ctx context.Context) context.Context { + return WithContextID(ctx, GenerateContextID()) +} + +// WithContextID creates a new context with cid, which will be used for log. +func WithContextID(ctx context.Context, cid string) context.Context { + return context.WithValue(ctx, cidKey, cid) +} + +// ContextID returns the cid in context, or empty string if not set. +func ContextID(ctx context.Context) string { + if cid, ok := ctx.Value(cidKey).(string); ok { + return cid + } + return "" +} diff --git a/proxy/logger/log.go b/proxy/logger/log.go new file mode 100644 index 0000000000..debbe1a847 --- /dev/null +++ b/proxy/logger/log.go @@ -0,0 +1,87 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package logger + +import ( + "context" + "io/ioutil" + stdLog "log" + "os" +) + +type logger interface { + Printf(ctx context.Context, format string, v ...any) +} + +type loggerPlus struct { + logger *stdLog.Logger + level string +} + +func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus { + v := &loggerPlus{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) { + format, args := f, a + if cid := ContextID(ctx); cid != "" { + format, args = "[%v][%v][%v] "+format, append([]interface{}{v.level, os.Getpid(), cid}, a...) + } + + v.logger.Printf(format, args...) +} + +var verboseLogger logger + +func Vf(ctx context.Context, format string, a ...interface{}) { + verboseLogger.Printf(ctx, format, a...) +} + +var debugLogger logger + +func Df(ctx context.Context, format string, a ...interface{}) { + debugLogger.Printf(ctx, format, a...) +} + +var warnLogger logger + +func Wf(ctx context.Context, format string, a ...interface{}) { + warnLogger.Printf(ctx, format, a...) +} + +var errorLogger logger + +func Ef(ctx context.Context, format string, a ...interface{}) { + errorLogger.Printf(ctx, format, a...) +} + +const ( + logVerboseLabel = "verb" + logDebugLabel = "debug" + logWarnLabel = "warn" + logErrorLabel = "error" +) + +func init() { + verboseLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logVerboseLabel + }) + debugLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logDebugLabel + }) + warnLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logWarnLabel + }) + errorLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logErrorLabel + }) +} diff --git a/proxy/main.go b/proxy/main.go new file mode 100644 index 0000000000..6327a7cf80 --- /dev/null +++ b/proxy/main.go @@ -0,0 +1,121 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "os" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +func main() { + ctx := logger.WithContext(context.Background()) + logger.Df(ctx, "%v/%v started", Signature(), Version()) + + // Install signals. + ctx, cancel := context.WithCancel(ctx) + installSignals(ctx, cancel) + + // Start the main loop, ignore the user cancel error. + err := doMain(ctx) + if err != nil && ctx.Err() != context.Canceled { + logger.Ef(ctx, "main: %+v", err) + os.Exit(-1) + } + + logger.Df(ctx, "%v done", Signature()) +} + +func doMain(ctx context.Context) error { + // Setup the environment variables. + if err := loadEnvFile(ctx); err != nil { + return errors.Wrapf(err, "load env") + } + + buildDefaultEnvironmentVariables(ctx) + + // When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur + // because the main thread exits after the context is cancelled. However, sometimes the main thread + // may be blocked for some reason, so a forced exit is necessary to ensure the program terminates. + if err := installForceQuit(ctx); err != nil { + return errors.Wrapf(err, "install force quit") + } + + // Start the Go pprof if enabled. + handleGoPprof(ctx) + + // Initialize SRS load balancers. + switch lbType := envLoadBalancerType(); lbType { + case "memory": + srsLoadBalancer = NewMemoryLoadBalancer() + case "redis": + srsLoadBalancer = NewRedisLoadBalancer() + default: + return errors.Errorf("invalid load balancer %v", lbType) + } + + if err := srsLoadBalancer.Initialize(ctx); err != nil { + return errors.Wrapf(err, "initialize srs load balancer") + } + + // Parse the gracefully quit timeout. + gracefulQuitTimeout, err := parseGracefullyQuitTimeout() + if err != nil { + return errors.Wrapf(err, "parse gracefully quit timeout") + } + + // Start the RTMP server. + srsRTMPServer := NewSRSRTMPServer() + defer srsRTMPServer.Close() + if err := srsRTMPServer.Run(ctx); err != nil { + return errors.Wrapf(err, "rtmp server") + } + + // Start the WebRTC server. + srsWebRTCServer := NewSRSWebRTCServer() + defer srsWebRTCServer.Close() + if err := srsWebRTCServer.Run(ctx); err != nil { + return errors.Wrapf(err, "rtc server") + } + + // Start the HTTP API server. + srsHTTPAPIServer := NewSRSHTTPAPIServer(func(server *srsHTTPAPIServer) { + server.gracefulQuitTimeout, server.rtc = gracefulQuitTimeout, srsWebRTCServer + }) + defer srsHTTPAPIServer.Close() + if err := srsHTTPAPIServer.Run(ctx); err != nil { + return errors.Wrapf(err, "http api server") + } + + // Start the SRT server. + srsSRTServer := NewSRSSRTServer() + defer srsSRTServer.Close() + if err := srsSRTServer.Run(ctx); err != nil { + return errors.Wrapf(err, "srt server") + } + + // Start the System API server. + systemAPI := NewSystemAPI(func(server *systemAPI) { + server.gracefulQuitTimeout = gracefulQuitTimeout + }) + defer systemAPI.Close() + if err := systemAPI.Run(ctx); err != nil { + return errors.Wrapf(err, "system api server") + } + + // Start the HTTP web server. + srsHTTPStreamServer := NewSRSHTTPStreamServer(func(server *srsHTTPStreamServer) { + server.gracefulQuitTimeout = gracefulQuitTimeout + }) + defer srsHTTPStreamServer.Close() + if err := srsHTTPStreamServer.Run(ctx); err != nil { + return errors.Wrapf(err, "http server") + } + + // Wait for the main loop to quit. + <-ctx.Done() + return nil +} diff --git a/proxy/rtc.go b/proxy/rtc.go new file mode 100644 index 0000000000..5a7d9936c7 --- /dev/null +++ b/proxy/rtc.go @@ -0,0 +1,515 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "encoding/binary" + "fmt" + "io/ioutil" + "net" + "net/http" + "strconv" + "strings" + stdSync "sync" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/sync" +) + +// srsWebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out +// which backend server to proxy to. It will also replace the UDP port to the proxy server's in the +// SDP answer. +type srsWebRTCServer struct { + // The UDP listener for WebRTC server. + listener *net.UDPConn + + // Fast cache for the username to identify the connection. + // The key is username, the value is the UDP address. + usernames sync.Map[string, *RTCConnection] + // Fast cache for the udp address to identify the connection. + // The key is UDP address, the value is the username. + // TODO: Support fast earch by uint64 address. + addresses sync.Map[string, *RTCConnection] + + // The wait group for server. + wg stdSync.WaitGroup +} + +func NewSRSWebRTCServer(opts ...func(*srsWebRTCServer)) *srsWebRTCServer { + v := &srsWebRTCServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsWebRTCServer) Close() error { + if v.listener != nil { + _ = v.listener.Close() + } + + v.wg.Wait() + return nil +} + +func (v *srsWebRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + ctx = logger.WithContext(ctx) + + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Read remote SDP offer from body. + remoteSDPOffer, err := ioutil.ReadAll(r.Body) + if err != nil { + return errors.Wrapf(err, "read remote sdp offer") + } + + // Build the stream URL in vhost/app/stream schema. + unifiedURL, fullURL := convertURLToStreamURL(r) + logger.Df(ctx, "Got WebRTC WHIP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL) + + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + return errors.Wrapf(err, "build stream url %v", unifiedURL) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *srsWebRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + ctx = logger.WithContext(ctx) + + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Read remote SDP offer from body. + remoteSDPOffer, err := ioutil.ReadAll(r.Body) + if err != nil { + return errors.Wrapf(err, "read remote sdp offer") + } + + // Build the stream URL in vhost/app/stream schema. + unifiedURL, fullURL := convertURLToStreamURL(r) + logger.Df(ctx, "Got WebRTC WHEP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL) + + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + return errors.Wrapf(err, "build stream url %v", unifiedURL) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *srsWebRTCServer) proxyApiToBackend( + ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer, + remoteSDPOffer string, streamURL string, +) error { + // Parse HTTP port from backend. + if len(backend.API) == 0 { + return errors.Errorf("no http api server") + } + + var apiPort int + if iv, err := strconv.ParseInt(backend.API[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse http port %v", backend.API[0]) + } else { + apiPort = int(iv) + } + + // Connect to backend SRS server via HTTP client. + backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path) + if r.URL.RawQuery != "" { + backendURL += "?" + r.URL.RawQuery + } + + req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, strings.NewReader(remoteSDPOffer)) + if err != nil { + return errors.Wrapf(err, "create request to %v", backendURL) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return errors.Errorf("do request to %v EOF", backendURL) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + return errors.Errorf("proxy api to %v failed, status=%v", backendURL, resp.Status) + } + + // Copy all headers from backend to client. + w.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + + // Parse the local SDP answer from backend. + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return errors.Wrapf(err, "read stream from %v", backendURL) + } + + // Replace the WebRTC UDP port in answer. + localSDPAnswer := string(b) + for _, endpoint := range backend.RTC { + _, _, port, err := parseListenEndpoint(endpoint) + if err != nil { + return errors.Wrapf(err, "parse endpoint %v", endpoint) + } + + from := fmt.Sprintf(" %v typ host", port) + to := fmt.Sprintf(" %v typ host", envWebRTCServer()) + localSDPAnswer = strings.Replace(localSDPAnswer, from, to, -1) + } + + // Fetch the ice-ufrag and ice-pwd from local SDP answer. + remoteICEUfrag, remoteICEPwd, err := parseIceUfragPwd(remoteSDPOffer) + if err != nil { + return errors.Wrapf(err, "parse remote sdp offer") + } + + localICEUfrag, localICEPwd, err := parseIceUfragPwd(localSDPAnswer) + if err != nil { + return errors.Wrapf(err, "parse local sdp answer") + } + + // Save the new WebRTC connection to LB. + icePair := &RTCICEPair{ + RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd, + LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd, + } + if err := srsLoadBalancer.StoreWebRTC(ctx, streamURL, NewRTCConnection(func(c *RTCConnection) { + c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag() + c.Initialize(ctx, v.listener) + + // Cache the connection for fast search by username. + v.usernames.Store(c.Ufrag, c) + })); err != nil { + return errors.Wrapf(err, "load or store webrtc %v", streamURL) + } + + // Response client with local answer. + if _, err = w.Write([]byte(localSDPAnswer)); err != nil { + return errors.Wrapf(err, "write local sdp answer %v", localSDPAnswer) + } + + logger.Df(ctx, "Create WebRTC connection with local answer %vB with ice-ufrag=%v, ice-pwd=%vB", + len(localSDPAnswer), localICEUfrag, len(localICEPwd)) + return nil +} + +func (v *srsWebRTCServer) Run(ctx context.Context) error { + // Parse address to listen. + endpoint := envWebRTCServer() + if !strings.Contains(endpoint, ":") { + endpoint = fmt.Sprintf(":%v", endpoint) + } + + saddr, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + return errors.Wrapf(err, "resolve udp addr %v", endpoint) + } + + listener, err := net.ListenUDP("udp", saddr) + if err != nil { + return errors.Wrapf(err, "listen udp %v", saddr) + } + v.listener = listener + logger.Df(ctx, "WebRTC server listen at %v", saddr) + + // Consume all messages from UDP media transport. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, caddr, err := listener.ReadFromUDP(buf) + if err != nil { + // TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "read from udp failed, err=%+v", err) + continue + } + + if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil { + logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err) + } + } + }() + + return nil +} + +func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { + var connection *RTCConnection + + // If STUN binding request, parse the ufrag and identify the connection. + if err := func() error { + if rtcIsRTPOrRTCP(data) || !rtcIsSTUN(data) { + return nil + } + + var pkt RTCStunPacket + if err := pkt.UnmarshalBinary(data); err != nil { + return errors.Wrapf(err, "unmarshal stun packet") + } + + // Search the connection in fast cache. + if s, ok := v.usernames.Load(pkt.Username); ok { + connection = s + return nil + } + + // Load connection by username. + if s, err := srsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil { + return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username) + } else { + connection = s.Initialize(ctx, v.listener) + logger.Df(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL) + } + + // Cache connection for fast search. + if connection != nil { + v.usernames.Store(pkt.Username, connection) + } + return nil + }(); err != nil { + return err + } + + // Search the connection by addr. + if s, ok := v.addresses.Load(addr.String()); ok { + connection = s + } else if connection != nil { + // Cache the address for fast search. + v.addresses.Store(addr.String(), connection) + } + + // If connection is not found, ignore the packet. + if connection == nil { + // TODO: Should logging the dropped packet, only logging the first one for each address. + return nil + } + + // Proxy the packet to backend. + if err := connection.HandlePacket(addr, data); err != nil { + return errors.Wrapf(err, "proxy %vB for %v", len(data), connection.StreamURL) + } + + return nil +} + +// RTCConnection is a WebRTC connection proxy, for both WHIP and WHEP. It represents a WebRTC +// connection, identify by the ufrag in sdp offer/answer and ICE binding request. +// +// It's not like RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is +// in the client request. The RTCConnection is stateful, and need to sync the ufrag between +// proxy servers. +// +// The media transport is UDP, which is also a special thing for WebRTC. So if the client switch +// to another UDP address, it may connect to another WebRTC proxy, then we should discover the +// RTCConnection by the ufrag from the ICE binding request. +type RTCConnection struct { + // The stream context for WebRTC streaming. + ctx context.Context + + // The stream URL in vhost/app/stream schema. + StreamURL string `json:"stream_url"` + // The ufrag for this WebRTC connection. + Ufrag string `json:"ufrag"` + + // The UDP connection proxy to backend. + backendUDP *net.UDPConn + // The client UDP address. Note that it may change. + clientUDP *net.UDPAddr + // The listener UDP connection, used to send messages to client. + listenerUDP *net.UDPConn +} + +func NewRTCConnection(opts ...func(*RTCConnection)) *RTCConnection { + v := &RTCConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *RTCConnection { + if v.ctx == nil { + v.ctx = logger.WithContext(ctx) + } + if listener != nil { + v.listenerUDP = listener + } + return v +} + +func (v *RTCConnection) HandlePacket(addr *net.UDPAddr, data []byte) error { + ctx := v.ctx + + // Update the current UDP address. + v.clientUDP = addr + + // Start the UDP proxy to backend. + if err := v.connectBackend(ctx); err != nil { + return errors.Wrapf(err, "connect backend for %v", v.StreamURL) + } + + // Proxy client message to backend. + if v.backendUDP == nil { + return nil + } + + // Proxy all messages from backend to client. + go func() { + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, _, err := v.backendUDP.ReadFromUDP(buf) + if err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "read from backend failed, err=%v", err) + break + } + + if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "write to client failed, err=%v", err) + break + } + } + }() + + if _, err := v.backendUDP.Write(data); err != nil { + return errors.Wrapf(err, "write to backend %v", v.StreamURL) + } + + return nil +} + +func (v *RTCConnection) connectBackend(ctx context.Context) error { + if v.backendUDP != nil { + return nil + } + + // Pick a backend SRS server to proxy the RTC stream. + backend, err := srsLoadBalancer.Pick(ctx, v.StreamURL) + if err != nil { + return errors.Wrapf(err, "pick backend") + } + + // Parse UDP port from backend. + if len(backend.RTC) == 0 { + return errors.Errorf("no udp server") + } + + _, _, udpPort, err := parseListenEndpoint(backend.RTC[0]) + if err != nil { + return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.RTC[0], backend, v.StreamURL) + } + + // Connect to backend SRS server via UDP client. + // TODO: FIXME: Support close the connection when timeout or DTLS alert. + backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)} + if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { + return errors.Wrapf(err, "dial udp to %v", backendAddr) + } else { + v.backendUDP = backendUDP + } + + return nil +} + +type RTCICEPair struct { + // The remote ufrag, used for ICE username and session id. + RemoteICEUfrag string `json:"remote_ufrag"` + // The remote pwd, used for ICE password. + RemoteICEPwd string `json:"remote_pwd"` + // The local ufrag, used for ICE username and session id. + LocalICEUfrag string `json:"local_ufrag"` + // The local pwd, used for ICE password. + LocalICEPwd string `json:"local_pwd"` +} + +// Generate the ICE ufrag for the WebRTC streaming, format is remote-ufrag:local-ufrag. +func (v *RTCICEPair) Ufrag() string { + return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag) +} + +type RTCStunPacket struct { + // The stun message type. + MessageType uint16 + // The stun username, or ufrag. + Username string +} + +func (v *RTCStunPacket) UnmarshalBinary(data []byte) error { + if len(data) < 20 { + return errors.Errorf("stun packet too short %v", len(data)) + } + + p := data + v.MessageType = binary.BigEndian.Uint16(p) + messageLen := binary.BigEndian.Uint16(p[2:]) + //magicCookie := p[:8] + //transactionID := p[:20] + p = p[20:] + + if len(p) != int(messageLen) { + return errors.Errorf("stun packet length invalid %v != %v", len(data), messageLen) + } + + for len(p) > 0 { + typ := binary.BigEndian.Uint16(p) + length := binary.BigEndian.Uint16(p[2:]) + p = p[4:] + + if len(p) < int(length) { + return errors.Errorf("stun attribute length invalid %v < %v", len(p), length) + } + + value := p[:length] + p = p[length:] + + if length%4 != 0 { + p = p[4-length%4:] + } + + switch typ { + case 0x0006: + v.Username = string(value) + } + } + + return nil +} diff --git a/proxy/rtmp.go b/proxy/rtmp.go new file mode 100644 index 0000000000..d93f04b3a6 --- /dev/null +++ b/proxy/rtmp.go @@ -0,0 +1,655 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "fmt" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/rtmp" +) + +// srsRTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS +// server. It will figure out the backend server to proxy to. Unlike the edge server, it will +// not cache the stream, but just proxy the stream to backend. +type srsRTMPServer struct { + // The TCP listener for RTMP server. + listener *net.TCPListener + // The random number generator. + rd *rand.Rand + // The wait group for all goroutines. + wg sync.WaitGroup +} + +func NewSRSRTMPServer(opts ...func(*srsRTMPServer)) *srsRTMPServer { + v := &srsRTMPServer{ + rd: rand.New(rand.NewSource(time.Now().UnixNano())), + } + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsRTMPServer) Close() error { + if v.listener != nil { + v.listener.Close() + } + + v.wg.Wait() + return nil +} + +func (v *srsRTMPServer) Run(ctx context.Context) error { + endpoint := envRtmpServer() + if !strings.Contains(endpoint, ":") { + endpoint = ":" + endpoint + } + + addr, err := net.ResolveTCPAddr("tcp", endpoint) + if err != nil { + return errors.Wrapf(err, "resolve rtmp addr %v", endpoint) + } + + listener, err := net.ListenTCP("tcp", addr) + if err != nil { + return errors.Wrapf(err, "listen rtmp addr %v", addr) + } + v.listener = listener + logger.Df(ctx, "RTMP server listen at %v", addr) + + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for { + conn, err := v.listener.AcceptTCP() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If RTMP server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "RTMP server accept err %+v", err) + } else { + logger.Df(ctx, "RTMP server done") + } + return + } + + v.wg.Add(1) + go func(ctx context.Context, conn *net.TCPConn) { + defer v.wg.Done() + defer conn.Close() + + handleErr := func(err error) { + if isPeerClosedError(err) { + logger.Df(ctx, "RTMP peer is closed") + } else { + logger.Wf(ctx, "RTMP serve err %+v", err) + } + } + + rc := NewRTMPConnection(func(client *RTMPConnection) { + client.rd = v.rd + }) + if err := rc.serve(ctx, conn); err != nil { + handleErr(err) + } else { + logger.Df(ctx, "RTMP client done") + } + }(logger.WithContext(ctx), conn) + } + }() + + return nil +} + +// RTMPConnection is an RTMP streaming connection. There is no state need to be sync between +// proxy servers. +// +// When we got an RTMP request, we will parse the stream URL from the RTMP publish or play request, +// then proxy to the corresponding backend server. All state is in the RTMP request, so this +// connection is stateless. +type RTMPConnection struct { + // The random number generator. + rd *rand.Rand +} + +func NewRTMPConnection(opts ...func(*RTMPConnection)) *RTMPConnection { + v := &RTMPConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { + logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr()) + + // If any goroutine quit, cancel another one. + parentCtx := ctx + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var backend *RTMPClientToBackend + if true { + go func() { + <-ctx.Done() + conn.Close() + if backend != nil { + backend.Close() + } + }() + } + + // Simple handshake with client. + hs := rtmp.NewHandshake(v.rd) + if _, err := hs.ReadC0S0(conn); err != nil { + return errors.Wrapf(err, "read c0") + } + if _, err := hs.ReadC1S1(conn); err != nil { + return errors.Wrapf(err, "read c1") + } + if err := hs.WriteC0S0(conn); err != nil { + return errors.Wrapf(err, "write s1") + } + if err := hs.WriteC1S1(conn); err != nil { + return errors.Wrapf(err, "write s1") + } + if err := hs.WriteC2S2(conn, hs.C1S1()); err != nil { + return errors.Wrapf(err, "write s2") + } + if _, err := hs.ReadC2S2(conn); err != nil { + return errors.Wrapf(err, "read c2") + } + + client := rtmp.NewProtocol(conn) + logger.Df(ctx, "RTMP simple handshake done") + + // Expect RTMP connect command with tcUrl. + var connectReq *rtmp.ConnectAppPacket + if _, err := rtmp.ExpectPacket(ctx, client, &connectReq); err != nil { + return errors.Wrapf(err, "expect connect req") + } + + if true { + ack := rtmp.NewWindowAcknowledgementSize() + ack.AckSize = 2500000 + if err := client.WritePacket(ctx, ack, 0); err != nil { + return errors.Wrapf(err, "write set ack size") + } + } + if true { + chunk := rtmp.NewSetChunkSize() + chunk.ChunkSize = 128 + if err := client.WritePacket(ctx, chunk, 0); err != nil { + return errors.Wrapf(err, "write set chunk size") + } + } + + connectRes := rtmp.NewConnectAppResPacket(connectReq.TransactionID) + connectRes.CommandObject.Set("fmsVer", rtmp.NewAmf0String("FMS/3,5,3,888")) + connectRes.CommandObject.Set("capabilities", rtmp.NewAmf0Number(127)) + connectRes.CommandObject.Set("mode", rtmp.NewAmf0Number(1)) + connectRes.Args.Set("level", rtmp.NewAmf0String("status")) + connectRes.Args.Set("code", rtmp.NewAmf0String("NetConnection.Connect.Success")) + connectRes.Args.Set("description", rtmp.NewAmf0String("Connection succeeded")) + connectRes.Args.Set("objectEncoding", rtmp.NewAmf0Number(0)) + connectResData := rtmp.NewAmf0EcmaArray() + connectResData.Set("version", rtmp.NewAmf0String("3,5,3,888")) + connectResData.Set("srs_version", rtmp.NewAmf0String(Version())) + connectResData.Set("srs_id", rtmp.NewAmf0String(logger.ContextID(ctx))) + connectRes.Args.Set("data", connectResData) + if err := client.WritePacket(ctx, connectRes, 0); err != nil { + return errors.Wrapf(err, "write connect res") + } + + tcUrl := connectReq.TcUrl() + logger.Df(ctx, "RTMP connect app %v", tcUrl) + + // Expect RTMP command to identify the client, a publisher or viewer. + var currentStreamID, nextStreamID int + var streamName string + var clientType RTMPClientType + for clientType == "" { + var identifyReq rtmp.Packet + if _, err := rtmp.ExpectPacket(ctx, client, &identifyReq); err != nil { + return errors.Wrapf(err, "expect identify req") + } + + var response rtmp.Packet + switch pkt := identifyReq.(type) { + case *rtmp.CallPacket: + if pkt.CommandName == "createStream" { + identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID) + response = identifyRes + + nextStreamID = 1 + identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID)) + } else if pkt.CommandName == "getStreamLength" { + // Ignore and do not reply these packets. + } else { + // For releaseStream, FCPublish, etc. + identifyRes := rtmp.NewCallPacket() + response = identifyRes + + identifyRes.TransactionID = pkt.TransactionID + identifyRes.CommandName = "_result" + identifyRes.CommandObject = rtmp.NewAmf0Null() + identifyRes.Args = rtmp.NewAmf0Undefined() + } + case *rtmp.PublishPacket: + streamName = string(pkt.StreamName) + clientType = RTMPClientTypePublisher + + identifyRes := rtmp.NewCallPacket() + response = identifyRes + + identifyRes.CommandName = "onFCPublish" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start")) + data.Set("description", rtmp.NewAmf0String("Started publishing stream.")) + identifyRes.Args = data + case *rtmp.PlayPacket: + streamName = string(pkt.StreamName) + clientType = RTMPClientTypeViewer + + identifyRes := rtmp.NewCallPacket() + response = identifyRes + + identifyRes.CommandName = "onStatus" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("level", rtmp.NewAmf0String("status")) + data.Set("code", rtmp.NewAmf0String("NetStream.Play.Reset")) + data.Set("description", rtmp.NewAmf0String("Playing and resetting stream.")) + data.Set("details", rtmp.NewAmf0String("stream")) + data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) + identifyRes.Args = data + } + + if response != nil { + if err := client.WritePacket(ctx, response, currentStreamID); err != nil { + return errors.Wrapf(err, "write identify res for req=%v, stream=%v", + identifyReq, currentStreamID) + } + } + + // Update the stream ID for next request. + currentStreamID = nextStreamID + } + logger.Df(ctx, "RTMP identify tcUrl=%v, stream=%v, id=%v, type=%v", + tcUrl, streamName, currentStreamID, clientType) + + // Find a backend SRS server to proxy the RTMP stream. + backend = NewRTMPClientToBackend(func(client *RTMPClientToBackend) { + client.rd, client.typ = v.rd, clientType + }) + defer backend.Close() + + if err := backend.Connect(ctx, tcUrl, streamName); err != nil { + return errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName) + } + + // Start the streaming. + if clientType == RTMPClientTypePublisher { + identifyRes := rtmp.NewCallPacket() + + identifyRes.CommandName = "onStatus" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("level", rtmp.NewAmf0String("status")) + data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start")) + data.Set("description", rtmp.NewAmf0String("Started publishing stream.")) + data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) + identifyRes.Args = data + + if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil { + return errors.Wrapf(err, "start publish") + } + } else if clientType == RTMPClientTypeViewer { + identifyRes := rtmp.NewCallPacket() + + identifyRes.CommandName = "onStatus" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("level", rtmp.NewAmf0String("status")) + data.Set("code", rtmp.NewAmf0String("NetStream.Play.Start")) + data.Set("description", rtmp.NewAmf0String("Started playing stream.")) + data.Set("details", rtmp.NewAmf0String("stream")) + data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) + identifyRes.Args = data + + if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil { + return errors.Wrapf(err, "start play") + } + } + logger.Df(ctx, "RTMP start streaming") + + // For all proxy goroutines. + var wg sync.WaitGroup + defer wg.Wait() + + // Proxy all message from backend to client. + wg.Add(1) + var r0 error + go func() { + defer wg.Done() + defer cancel() + + r0 = func() error { + for { + m, err := backend.client.ReadMessage(ctx) + if err != nil { + return errors.Wrapf(err, "read message") + } + //logger.Df(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) + + // TODO: Update the stream ID if not the same. + if err := client.WriteMessage(ctx, m); err != nil { + return errors.Wrapf(err, "write message") + } + } + }() + }() + + // Proxy all messages from client to backend. + wg.Add(1) + var r1 error + go func() { + defer wg.Done() + defer cancel() + + r1 = func() error { + for { + m, err := client.ReadMessage(ctx) + if err != nil { + return errors.Wrapf(err, "read message") + } + //logger.Df(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) + + // TODO: Update the stream ID if not the same. + if err := backend.client.WriteMessage(ctx, m); err != nil { + return errors.Wrapf(err, "write message") + } + } + }() + }() + + // Wait until all goroutine quit. + wg.Wait() + + // Reset the error if caused by another goroutine. + if r0 != nil { + return errors.Wrapf(r0, "proxy backend->client") + } + if r1 != nil { + return errors.Wrapf(r1, "proxy client->backend") + } + + return parentCtx.Err() +} + +type RTMPClientType string + +const ( + RTMPClientTypePublisher RTMPClientType = "publisher" + RTMPClientTypeViewer RTMPClientType = "viewer" +) + +// RTMPClientToBackend is a RTMP client to proxy the RTMP stream to backend. +type RTMPClientToBackend struct { + // The random number generator. + rd *rand.Rand + // The underlayer tcp client. + tcpConn *net.TCPConn + // The RTMP protocol client. + client *rtmp.Protocol + // The stream type. + typ RTMPClientType +} + +func NewRTMPClientToBackend(opts ...func(*RTMPClientToBackend)) *RTMPClientToBackend { + v := &RTMPClientToBackend{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *RTMPClientToBackend) Close() error { + if v.tcpConn != nil { + v.tcpConn.Close() + } + return nil +} + +func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error { + // Build the stream URL in vhost/app/stream schema. + streamURL, err := buildStreamURL(fmt.Sprintf("%v/%v", tcUrl, streamName)) + if err != nil { + return errors.Wrapf(err, "build stream url %v/%v", tcUrl, streamName) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + // Parse RTMP port from backend. + if len(backend.RTMP) == 0 { + return errors.Errorf("no rtmp server %+v for %v", backend, streamURL) + } + + var rtmpPort int + if iv, err := strconv.ParseInt(backend.RTMP[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse backend %+v rtmp port %v", backend, backend.RTMP[0]) + } else { + rtmpPort = int(iv) + } + + // Connect to backend SRS server via TCP client. + addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort} + c, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend) + } + v.tcpConn = c + + hs := rtmp.NewHandshake(v.rd) + client := rtmp.NewProtocol(c) + v.client = client + + // Simple RTMP handshake with server. + if err := hs.WriteC0S0(c); err != nil { + return errors.Wrapf(err, "write c0") + } + if err := hs.WriteC1S1(c); err != nil { + return errors.Wrapf(err, "write c1") + } + + if _, err = hs.ReadC0S0(c); err != nil { + return errors.Wrapf(err, "read s0") + } + if _, err := hs.ReadC1S1(c); err != nil { + return errors.Wrapf(err, "read s1") + } + if _, err = hs.ReadC2S2(c); err != nil { + return errors.Wrapf(err, "read c2") + } + logger.Df(ctx, "backend simple handshake done, server=%v", addr) + + if err := hs.WriteC2S2(c, hs.C1S1()); err != nil { + return errors.Wrapf(err, "write c2") + } + + // Connect RTMP app on tcUrl with server. + if true { + connectApp := rtmp.NewConnectAppPacket() + connectApp.CommandObject.Set("tcUrl", rtmp.NewAmf0String(tcUrl)) + if err := client.WritePacket(ctx, connectApp, 1); err != nil { + return errors.Wrapf(err, "write connect app") + } + } + + if true { + var connectAppRes *rtmp.ConnectAppResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &connectAppRes); err != nil { + return errors.Wrapf(err, "expect connect app res") + } + logger.Df(ctx, "backend connect RTMP app, tcUrl=%v, id=%v", tcUrl, connectAppRes.SrsID()) + } + + // Play or view RTMP stream with server. + if v.typ == RTMPClientTypeViewer { + return v.play(ctx, client, streamName) + } + + // Publish RTMP stream with server. + return v.publish(ctx, client, streamName) +} + +func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error { + if true { + identifyReq := rtmp.NewCallPacket() + identifyReq.CommandName = "releaseStream" + identifyReq.TransactionID = 2 + identifyReq.CommandObject = rtmp.NewAmf0Null() + identifyReq.Args = rtmp.NewAmf0String(streamName) + if err := client.WritePacket(ctx, identifyReq, 0); err != nil { + return errors.Wrapf(err, "releaseStream") + } + } + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect releaseStream res") + } + if identifyRes.CommandName == "_result" { + break + } + } + + if true { + identifyReq := rtmp.NewCallPacket() + identifyReq.CommandName = "FCPublish" + identifyReq.TransactionID = 3 + identifyReq.CommandObject = rtmp.NewAmf0Null() + identifyReq.Args = rtmp.NewAmf0String(streamName) + if err := client.WritePacket(ctx, identifyReq, 0); err != nil { + return errors.Wrapf(err, "FCPublish") + } + } + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect FCPublish res") + } + if identifyRes.CommandName == "_result" { + break + } + } + + var currentStreamID int + if true { + createStream := rtmp.NewCreateStreamPacket() + createStream.TransactionID = 4 + createStream.CommandObject = rtmp.NewAmf0Null() + if err := client.WritePacket(ctx, createStream, 0); err != nil { + return errors.Wrapf(err, "createStream") + } + } + for { + var identifyRes *rtmp.CreateStreamResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect createStream res") + } + if sid := identifyRes.StreamID; sid != 0 { + currentStreamID = int(sid) + break + } + } + + if true { + publishStream := rtmp.NewPublishPacket() + publishStream.TransactionID = 5 + publishStream.CommandObject = rtmp.NewAmf0Null() + publishStream.StreamName = *rtmp.NewAmf0String(streamName) + publishStream.StreamType = *rtmp.NewAmf0String("live") + if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil { + return errors.Wrapf(err, "publish") + } + } + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect publish res") + } + // Ignore onFCPublish, expect onStatus(NetStream.Publish.Start). + if identifyRes.CommandName == "onStatus" { + if data := rtmp.NewAmf0Converter(identifyRes.Args).ToObject(); data == nil { + return errors.Errorf("onStatus args not object") + } else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil { + return errors.Errorf("onStatus code not string") + } else if *code != "NetStream.Publish.Start" { + return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code) + } + break + } + } + logger.Df(ctx, "backend publish stream=%v, sid=%v", streamName, currentStreamID) + + return nil +} + +func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error { + var currentStreamID int + if true { + createStream := rtmp.NewCreateStreamPacket() + createStream.TransactionID = 4 + createStream.CommandObject = rtmp.NewAmf0Null() + if err := client.WritePacket(ctx, createStream, 0); err != nil { + return errors.Wrapf(err, "createStream") + } + } + for { + var identifyRes *rtmp.CreateStreamResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect createStream res") + } + if sid := identifyRes.StreamID; sid != 0 { + currentStreamID = int(sid) + break + } + } + + playStream := rtmp.NewPlayPacket() + playStream.StreamName = *rtmp.NewAmf0String(streamName) + if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil { + return errors.Wrapf(err, "play") + } + + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect releaseStream res") + } + if identifyRes.CommandName == "onStatus" && identifyRes.ArgsCode() == "NetStream.Play.Start" { + break + } + } + return nil +} diff --git a/proxy/rtmp/amf0.go b/proxy/rtmp/amf0.go new file mode 100644 index 0000000000..a013d5eccb --- /dev/null +++ b/proxy/rtmp/amf0.go @@ -0,0 +1,771 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp + +import ( + "bytes" + "encoding" + "encoding/binary" + "fmt" + "math" + "sync" + + "srs-proxy/errors" +) + +// Please read @doc amf0_spec_121207.pdf, @page 4, @section 2.1 Types Overview +type amf0Marker uint8 + +const ( + amf0MarkerNumber amf0Marker = iota // 0 + amf0MarkerBoolean // 1 + amf0MarkerString // 2 + amf0MarkerObject // 3 + amf0MarkerMovieClip // 4 + amf0MarkerNull // 5 + amf0MarkerUndefined // 6 + amf0MarkerReference // 7 + amf0MarkerEcmaArray // 8 + amf0MarkerObjectEnd // 9 + amf0MarkerStrictArray // 10 + amf0MarkerDate // 11 + amf0MarkerLongString // 12 + amf0MarkerUnsupported // 13 + amf0MarkerRecordSet // 14 + amf0MarkerXmlDocument // 15 + amf0MarkerTypedObject // 16 + amf0MarkerAvmPlusObject // 17 + + amf0MarkerForbidden amf0Marker = 0xff +) + +func (v amf0Marker) String() string { + switch v { + case amf0MarkerNumber: + return "Amf0Number" + case amf0MarkerBoolean: + return "amf0Boolean" + case amf0MarkerString: + return "Amf0String" + case amf0MarkerObject: + return "Amf0Object" + case amf0MarkerNull: + return "Null" + case amf0MarkerUndefined: + return "Undefined" + case amf0MarkerReference: + return "Reference" + case amf0MarkerEcmaArray: + return "EcmaArray" + case amf0MarkerObjectEnd: + return "ObjectEnd" + case amf0MarkerStrictArray: + return "StrictArray" + case amf0MarkerDate: + return "Date" + case amf0MarkerLongString: + return "LongString" + case amf0MarkerUnsupported: + return "Unsupported" + case amf0MarkerXmlDocument: + return "XmlDocument" + case amf0MarkerTypedObject: + return "TypedObject" + case amf0MarkerAvmPlusObject: + return "AvmPlusObject" + case amf0MarkerMovieClip: + return "MovieClip" + case amf0MarkerRecordSet: + return "RecordSet" + default: + return "Forbidden" + } +} + +// For utest to mock it. +type amf0Buffer interface { + Bytes() []byte + WriteByte(c byte) error + Write(p []byte) (n int, err error) +} + +var createBuffer = func() amf0Buffer { + return &bytes.Buffer{} +} + +// All AMF0 things. +type amf0Any interface { + // Binary marshaler and unmarshaler. + encoding.BinaryUnmarshaler + encoding.BinaryMarshaler + // Get the size of bytes to marshal this object. + Size() int + + // Get the Marker of any AMF0 stuff. + amf0Marker() amf0Marker +} + +type amf0Converter struct { + from amf0Any +} + +func NewAmf0Converter(from amf0Any) *amf0Converter { + return &amf0Converter{from: from} +} + +func (v *amf0Converter) ToNumber() *amf0Number { + return amf0AnyTo[*amf0Number](v.from) +} + +func (v *amf0Converter) ToBoolean() *amf0Boolean { + return amf0AnyTo[*amf0Boolean](v.from) +} + +func (v *amf0Converter) ToString() *amf0String { + return amf0AnyTo[*amf0String](v.from) +} + +func (v *amf0Converter) ToObject() *amf0Object { + return amf0AnyTo[*amf0Object](v.from) +} + +func (v *amf0Converter) ToNull() *amf0Null { + return amf0AnyTo[*amf0Null](v.from) +} + +func (v *amf0Converter) ToUndefined() *amf0Undefined { + return amf0AnyTo[*amf0Undefined](v.from) +} + +func (v *amf0Converter) ToEcmaArray() *amf0EcmaArray { + return amf0AnyTo[*amf0EcmaArray](v.from) +} + +func (v *amf0Converter) ToStrictArray() *amf0StrictArray { + return amf0AnyTo[*amf0StrictArray](v.from) +} + +// Convert any to specified object. +func amf0AnyTo[T amf0Any](a amf0Any) T { + var to T + if a != nil { + if v, ok := a.(T); ok { + return v + } + } + return to +} + +// Discovery the amf0 object from the bytes b. +func Amf0Discovery(p []byte) (a amf0Any, err error) { + if len(p) < 1 { + return nil, errors.Errorf("require 1 bytes only %v", len(p)) + } + m := amf0Marker(p[0]) + + switch m { + case amf0MarkerNumber: + return NewAmf0Number(0), nil + case amf0MarkerBoolean: + return NewAmf0Boolean(false), nil + case amf0MarkerString: + return NewAmf0String(""), nil + case amf0MarkerObject: + return NewAmf0Object(), nil + case amf0MarkerNull: + return NewAmf0Null(), nil + case amf0MarkerUndefined: + return NewAmf0Undefined(), nil + case amf0MarkerReference: + case amf0MarkerEcmaArray: + return NewAmf0EcmaArray(), nil + case amf0MarkerObjectEnd: + return &amf0ObjectEOF{}, nil + case amf0MarkerStrictArray: + return NewAmf0StrictArray(), nil + case amf0MarkerDate, amf0MarkerLongString, amf0MarkerUnsupported, amf0MarkerXmlDocument, + amf0MarkerTypedObject, amf0MarkerAvmPlusObject, amf0MarkerForbidden, amf0MarkerMovieClip, + amf0MarkerRecordSet: + return nil, errors.Errorf("Marker %v is not supported", m) + } + return nil, errors.Errorf("Marker %v is invalid", m) +} + +// The UTF8 string, please read @doc amf0_spec_121207.pdf, @page 3, @section 1.3.1 Strings and UTF-8 +type amf0UTF8 string + +func (v *amf0UTF8) Size() int { + return 2 + len(string(*v)) +} + +func (v *amf0UTF8) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 2 { + return errors.Errorf("require 2 bytes only %v", len(p)) + } + size := uint16(p[0])<<8 | uint16(p[1]) + + if p = data[2:]; len(p) < int(size) { + return errors.Errorf("require %v bytes only %v", int(size), len(p)) + } + *v = amf0UTF8(string(p[:size])) + + return +} + +func (v *amf0UTF8) MarshalBinary() (data []byte, err error) { + data = make([]byte, v.Size()) + + size := uint16(len(string(*v))) + data[0] = byte(size >> 8) + data[1] = byte(size) + + if size > 0 { + copy(data[2:], []byte(*v)) + } + + return +} + +// The number object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.2 Number Type +type amf0Number float64 + +func NewAmf0Number(f float64) *amf0Number { + v := amf0Number(f) + return &v +} + +func (v *amf0Number) amf0Marker() amf0Marker { + return amf0MarkerNumber +} + +func (v *amf0Number) Size() int { + return 1 + 8 +} + +func (v *amf0Number) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 9 { + return errors.Errorf("require 9 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerNumber { + return errors.Errorf("Amf0Number amf0Marker %v is illegal", m) + } + + f := binary.BigEndian.Uint64(p[1:]) + *v = amf0Number(math.Float64frombits(f)) + return +} + +func (v *amf0Number) MarshalBinary() (data []byte, err error) { + data = make([]byte, 9) + data[0] = byte(amf0MarkerNumber) + f := math.Float64bits(float64(*v)) + binary.BigEndian.PutUint64(data[1:], f) + return +} + +// The string objet, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.4 String Type +type amf0String string + +func NewAmf0String(s string) *amf0String { + v := amf0String(s) + return &v +} + +func (v *amf0String) amf0Marker() amf0Marker { + return amf0MarkerString +} + +func (v *amf0String) Size() int { + u := amf0UTF8(*v) + return 1 + u.Size() +} + +func (v *amf0String) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 1 { + return errors.Errorf("require 1 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerString { + return errors.Errorf("Amf0String amf0Marker %v is illegal", m) + } + + var sv amf0UTF8 + if err = sv.UnmarshalBinary(p[1:]); err != nil { + return errors.WithMessage(err, "utf8") + } + *v = amf0String(string(sv)) + return +} + +func (v *amf0String) MarshalBinary() (data []byte, err error) { + u := amf0UTF8(*v) + + var pb []byte + if pb, err = u.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "utf8") + } + + data = append([]byte{byte(amf0MarkerString)}, pb...) + return +} + +// The AMF0 object end type, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.11 Object End Type +type amf0ObjectEOF struct { +} + +func (v *amf0ObjectEOF) amf0Marker() amf0Marker { + return amf0MarkerObjectEnd +} + +func (v *amf0ObjectEOF) Size() int { + return 3 +} + +func (v *amf0ObjectEOF) UnmarshalBinary(data []byte) (err error) { + p := data + + if len(p) < 3 { + return errors.Errorf("require 3 bytes only %v", len(p)) + } + + if p[0] != 0 || p[1] != 0 || p[2] != 9 { + return errors.Errorf("EOF amf0Marker %v is illegal", p[0:3]) + } + return +} + +func (v *amf0ObjectEOF) MarshalBinary() (data []byte, err error) { + return []byte{0, 0, 9}, nil +} + +// Use array for object and ecma array, to keep the original order. +type amf0Property struct { + key amf0UTF8 + value amf0Any +} + +// The object-like AMF0 structure, like object and ecma array and strict array. +type amf0ObjectBase struct { + properties []*amf0Property + lock sync.Mutex +} + +func (v *amf0ObjectBase) Size() int { + v.lock.Lock() + defer v.lock.Unlock() + + var size int + + for _, p := range v.properties { + key, value := p.key, p.value + size += key.Size() + value.Size() + } + + return size +} + +func (v *amf0ObjectBase) Get(key string) amf0Any { + v.lock.Lock() + defer v.lock.Unlock() + + for _, p := range v.properties { + if string(p.key) == key { + return p.value + } + } + + return nil +} + +func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase { + v.lock.Lock() + defer v.lock.Unlock() + + prop := &amf0Property{key: amf0UTF8(key), value: value} + + var ok bool + for i, p := range v.properties { + if string(p.key) == key { + v.properties[i] = prop + ok = true + } + } + + if !ok { + v.properties = append(v.properties, prop) + } + + return v +} + +func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) { + // if no eof, elems specified by maxElems. + if !eof && maxElems < 0 { + return errors.Errorf("maxElems=%v without eof", maxElems) + } + // if eof, maxElems must be -1. + if eof && maxElems != -1 { + return errors.Errorf("maxElems=%v with eof", maxElems) + } + + readOne := func() (amf0UTF8, amf0Any, error) { + var u amf0UTF8 + if err = u.UnmarshalBinary(p); err != nil { + return "", nil, errors.WithMessage(err, "prop name") + } + + p = p[u.Size():] + var a amf0Any + if a, err = Amf0Discovery(p); err != nil { + return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u))) + } + return u, a, nil + } + + pushOne := func(u amf0UTF8, a amf0Any) error { + // For object property, consume the whole bytes. + if err = a.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u))) + } + + v.Set(string(u), a) + p = p[a.Size():] + return nil + } + + for eof { + u, a, err := readOne() + if err != nil { + return errors.WithMessage(err, "read") + } + + // For object EOF, we should only consume total 3bytes. + if u.Size() == 2 && a.amf0Marker() == amf0MarkerObjectEnd { + // 2 bytes is consumed by u(name), the a(eof) should only consume 1 byte. + p = p[1:] + return nil + } + + if err := pushOne(u, a); err != nil { + return errors.WithMessage(err, "push") + } + } + + for len(v.properties) < maxElems { + u, a, err := readOne() + if err != nil { + return errors.WithMessage(err, "read") + } + + if err := pushOne(u, a); err != nil { + return errors.WithMessage(err, "push") + } + } + + return +} + +func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) { + v.lock.Lock() + defer v.lock.Unlock() + + var pb []byte + for _, p := range v.properties { + key, value := p.key, p.value + + if pb, err = key.MarshalBinary(); err != nil { + return errors.WithMessage(err, fmt.Sprintf("marshal %v", string(key))) + } + if _, err = b.Write(pb); err != nil { + return errors.Wrapf(err, "write %v", string(key)) + } + + if pb, err = value.MarshalBinary(); err != nil { + return errors.WithMessage(err, fmt.Sprintf("marshal value for %v", string(key))) + } + if _, err = b.Write(pb); err != nil { + return errors.Wrapf(err, "marshal value for %v", string(key)) + } + } + + return +} + +// The AMF0 object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.5 Object Type +type amf0Object struct { + amf0ObjectBase + eof amf0ObjectEOF +} + +func NewAmf0Object() *amf0Object { + v := &amf0Object{} + v.properties = []*amf0Property{} + return v +} + +func (v *amf0Object) amf0Marker() amf0Marker { + return amf0MarkerObject +} + +func (v *amf0Object) Size() int { + return int(1) + v.eof.Size() + v.amf0ObjectBase.Size() +} + +func (v *amf0Object) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 1 { + return errors.Errorf("require 1 byte only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerObject { + return errors.Errorf("Amf0Object amf0Marker %v is illegal", m) + } + p = p[1:] + + if err = v.unmarshal(p, true, -1); err != nil { + return errors.WithMessage(err, "unmarshal") + } + + return +} + +func (v *amf0Object) MarshalBinary() (data []byte, err error) { + b := createBuffer() + + if err = b.WriteByte(byte(amf0MarkerObject)); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = v.marshal(b); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + + var pb []byte + if pb, err = v.eof.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + if _, err = b.Write(pb); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + return b.Bytes(), nil +} + +// The AMF0 ecma array, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.10 ECMA Array Type +type amf0EcmaArray struct { + amf0ObjectBase + count uint32 + eof amf0ObjectEOF +} + +func NewAmf0EcmaArray() *amf0EcmaArray { + v := &amf0EcmaArray{} + v.properties = []*amf0Property{} + return v +} + +func (v *amf0EcmaArray) amf0Marker() amf0Marker { + return amf0MarkerEcmaArray +} + +func (v *amf0EcmaArray) Size() int { + return int(1) + 4 + v.eof.Size() + v.amf0ObjectBase.Size() +} + +func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 5 { + return errors.Errorf("require 5 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerEcmaArray { + return errors.Errorf("EcmaArray amf0Marker %v is illegal", m) + } + v.count = binary.BigEndian.Uint32(p[1:]) + p = p[5:] + + if err = v.unmarshal(p, true, -1); err != nil { + return errors.WithMessage(err, "unmarshal") + } + return +} + +func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) { + b := createBuffer() + + if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = binary.Write(b, binary.BigEndian, v.count); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = v.marshal(b); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + + var pb []byte + if pb, err = v.eof.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + if _, err = b.Write(pb); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + return b.Bytes(), nil +} + +// The AMF0 strict array, please read @doc amf0_spec_121207.pdf, @page 7, @section 2.12 Strict Array Type +type amf0StrictArray struct { + amf0ObjectBase + count uint32 +} + +func NewAmf0StrictArray() *amf0StrictArray { + v := &amf0StrictArray{} + v.properties = []*amf0Property{} + return v +} + +func (v *amf0StrictArray) amf0Marker() amf0Marker { + return amf0MarkerStrictArray +} + +func (v *amf0StrictArray) Size() int { + return int(1) + 4 + v.amf0ObjectBase.Size() +} + +func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 5 { + return errors.Errorf("require 5 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerStrictArray { + return errors.Errorf("StrictArray amf0Marker %v is illegal", m) + } + v.count = binary.BigEndian.Uint32(p[1:]) + p = p[5:] + + if int(v.count) <= 0 { + return + } + + if err = v.unmarshal(p, false, int(v.count)); err != nil { + return errors.WithMessage(err, "unmarshal") + } + return +} + +func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) { + b := createBuffer() + + if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = binary.Write(b, binary.BigEndian, v.count); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = v.marshal(b); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + + return b.Bytes(), nil +} + +// The single amf0Marker object, for all AMF0 which only has the amf0Marker, like null and undefined. +type amf0SingleMarkerObject struct { + target amf0Marker +} + +func newAmf0SingleMarkerObject(m amf0Marker) amf0SingleMarkerObject { + return amf0SingleMarkerObject{target: m} +} + +func (v *amf0SingleMarkerObject) amf0Marker() amf0Marker { + return v.target +} + +func (v *amf0SingleMarkerObject) Size() int { + return int(1) +} + +func (v *amf0SingleMarkerObject) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 1 { + return errors.Errorf("require 1 byte only %v", len(p)) + } + if m := amf0Marker(p[0]); m != v.target { + return errors.Errorf("%v amf0Marker %v is illegal", v.target, m) + } + return +} + +func (v *amf0SingleMarkerObject) MarshalBinary() (data []byte, err error) { + return []byte{byte(v.target)}, nil +} + +// The AMF0 null, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.7 null Type +type amf0Null struct { + amf0SingleMarkerObject +} + +func NewAmf0Null() *amf0Null { + v := amf0Null{} + v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerNull) + return &v +} + +// The AMF0 undefined, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.8 undefined Type +type amf0Undefined struct { + amf0SingleMarkerObject +} + +func NewAmf0Undefined() amf0Any { + v := amf0Undefined{} + v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerUndefined) + return &v +} + +// The AMF0 boolean, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.3 Boolean Type +type amf0Boolean bool + +func NewAmf0Boolean(b bool) amf0Any { + v := amf0Boolean(b) + return &v +} + +func (v *amf0Boolean) amf0Marker() amf0Marker { + return amf0MarkerBoolean +} + +func (v *amf0Boolean) Size() int { + return int(2) +} + +func (v *amf0Boolean) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 2 { + return errors.Errorf("require 2 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerBoolean { + return errors.Errorf("BOOL amf0Marker %v is illegal", m) + } + if p[1] == 0 { + *v = false + } else { + *v = true + } + return +} + +func (v *amf0Boolean) MarshalBinary() (data []byte, err error) { + var b byte + if *v { + b = 1 + } + return []byte{byte(amf0MarkerBoolean), b}, nil +} diff --git a/proxy/rtmp/rtmp.go b/proxy/rtmp/rtmp.go new file mode 100644 index 0000000000..ee0970e960 --- /dev/null +++ b/proxy/rtmp/rtmp.go @@ -0,0 +1,1792 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp + +import ( + "bufio" + "bytes" + "context" + "encoding" + "encoding/binary" + "fmt" + "io" + "math/rand" + "sync" + + "srs-proxy/errors" +) + +// The handshake implements the RTMP handshake protocol. +type Handshake struct { + // The random number generator. + r *rand.Rand + // The c1s1 cache. + c1s1 []byte +} + +func NewHandshake(r *rand.Rand) *Handshake { + return &Handshake{r: r} +} + +func (v *Handshake) C1S1() []byte { + return v.c1s1 +} + +func (v *Handshake) WriteC0S0(w io.Writer) (err error) { + r := bytes.NewReader([]byte{0x03}) + if _, err = io.Copy(w, r); err != nil { + return errors.Wrap(err, "write c0s0") + } + + return +} + +func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) { + b := &bytes.Buffer{} + if _, err = io.CopyN(b, r, 1); err != nil { + return nil, errors.Wrap(err, "read c0s0") + } + + c0 = b.Bytes() + + return +} + +func (v *Handshake) WriteC1S1(w io.Writer) (err error) { + p := make([]byte, 1536) + + for i := 8; i < len(p); i++ { + p[i] = byte(v.r.Int()) + } + + r := bytes.NewReader(p) + if _, err = io.Copy(w, r); err != nil { + return errors.Wrap(err, "write c0s1") + } + + return +} + +func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { + b := &bytes.Buffer{} + if _, err = io.CopyN(b, r, 1536); err != nil { + return nil, errors.Wrap(err, "read c1s1") + } + + c1s1 = b.Bytes() + v.c1s1 = c1s1 + + return +} + +func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { + r := bytes.NewReader(s1c1[:]) + if _, err = io.Copy(w, r); err != nil { + return errors.Wrap(err, "write c2s2") + } + + return +} + +func (v *Handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) { + b := &bytes.Buffer{} + if _, err = io.CopyN(b, r, 1536); err != nil { + return nil, errors.Wrap(err, "read c2s2") + } + + c2 = b.Bytes() + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 16, @section 6.1. Chunk Format +// Extended timestamp: 0 or 4 bytes +// This field MUST be sent when the normal timsestamp is set to +// 0xffffff, it MUST NOT be sent if the normal timestamp is set to +// anything else. So for values less than 0xffffff the normal +// timestamp field SHOULD be used in which case the extended timestamp +// MUST NOT be present. For values greater than or equal to 0xffffff +// the normal timestamp field MUST NOT be used and MUST be set to +// 0xffffff and the extended timestamp MUST be sent. +const extendedTimestamp = uint64(0xffffff) + +// The default chunk size of RTMP is 128 bytes. +const defaultChunkSize = 128 + +// The intput or output settings for RTMP protocol. +type settings struct { + chunkSize uint32 +} + +func newSettings() *settings { + return &settings{ + chunkSize: defaultChunkSize, + } +} + +// The chunk stream which transport a message once. +type chunkStream struct { + format formatType + cid chunkID + header messageHeader + message *Message + count uint64 + extendedTimestamp bool +} + +func newChunkStream() *chunkStream { + return &chunkStream{} +} + +// The protocol implements the RTMP command and chunk stack. +type Protocol struct { + r *bufio.Reader + w *bufio.Writer + input struct { + opt *settings + chunks map[chunkID]*chunkStream + + transactions map[amf0Number]amf0String + ltransactions sync.Mutex + } + output struct { + opt *settings + } +} + +func NewProtocol(rw io.ReadWriter) *Protocol { + v := &Protocol{ + r: bufio.NewReader(rw), + w: bufio.NewWriter(rw), + } + + v.input.opt = newSettings() + v.input.chunks = map[chunkID]*chunkStream{} + v.input.transactions = map[amf0Number]amf0String{} + + v.output.opt = newSettings() + + return v +} + +func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Message, err error) { + for { + if m, err = v.ReadMessage(ctx); err != nil { + return nil, errors.WithMessage(err, "read message") + } + + var pkt Packet + if pkt, err = v.DecodeMessage(m); err != nil { + return nil, errors.WithMessage(err, "decode message") + } + + if p, ok := pkt.(T); ok { + *ppkt = p + break + } + } + + return +} + +// Deprecated: Please use rtmp.ExpectPacket instead. +func (v *Protocol) ExpectPacket(ctx context.Context, ppkt any) (m *Message, err error) { + panic("Please use rtmp.ExpectPacket instead") +} + +func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *Message, err error) { + for { + if m, err = v.ReadMessage(ctx); err != nil { + return nil, errors.WithMessage(err, "read message") + } + + if len(types) == 0 { + return + } + + for _, t := range types { + if m.MessageType == t { + return + } + } + } + + return +} + +func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { + var commandName amf0String + if err = commandName.UnmarshalBinary(p); err != nil { + return nil, errors.WithMessage(err, "unmarshal command name") + } + + switch commandName { + case commandResult, commandError: + var transactionID amf0Number + if err = transactionID.UnmarshalBinary(p[commandName.Size():]); err != nil { + return nil, errors.WithMessage(err, "unmarshal tid") + } + + var requestName amf0String + if err = func() error { + v.input.ltransactions.Lock() + defer v.input.ltransactions.Unlock() + + var ok bool + if requestName, ok = v.input.transactions[transactionID]; !ok { + return errors.Errorf("No matched request for tid=%v", transactionID) + } + delete(v.input.transactions, transactionID) + + return nil + }(); err != nil { + return nil, errors.WithMessage(err, "discovery request name") + } + + switch requestName { + case commandConnect: + return NewConnectAppResPacket(transactionID), nil + case commandCreateStream: + return NewCreateStreamResPacket(transactionID), nil + case commandReleaseStream, commandFCPublish, commandFCUnpublish: + call := NewCallPacket() + call.TransactionID = transactionID + return call, nil + default: + return nil, errors.Errorf("No request for %v", string(requestName)) + } + case commandConnect: + return NewConnectAppPacket(), nil + case commandPublish: + return NewPublishPacket(), nil + case commandPlay: + return NewPlayPacket(), nil + default: + return NewCallPacket(), nil + } +} + +func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) { + p := m.Payload[:] + if len(p) == 0 { + return nil, errors.New("Empty packet") + } + + switch m.MessageType { + case MessageTypeAMF3Command, MessageTypeAMF3Data: + p = p[1:] + } + + switch m.MessageType { + case MessageTypeSetChunkSize: + pkt = NewSetChunkSize() + case MessageTypeWindowAcknowledgementSize: + pkt = NewWindowAcknowledgementSize() + case MessageTypeSetPeerBandwidth: + pkt = NewSetPeerBandwidth() + case MessageTypeAMF0Command, MessageTypeAMF3Command, MessageTypeAMF0Data, MessageTypeAMF3Data: + if pkt, err = v.parseAMFObject(p); err != nil { + return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType)) + } + case MessageTypeUserControl: + pkt = NewUserControl() + default: + return nil, errors.Errorf("Unknown message %v", m.MessageType) + } + + if err = pkt.UnmarshalBinary(p); err != nil { + return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType)) + } + + return +} + +func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) { + for m == nil { + // TODO: We should convert buffered io to async io, because we will be stuck in block io here, + // TODO: but the risk is acceptable because we literally will set the underlay io timeout. + if ctx.Err() != nil { + return nil, ctx.Err() + } + + var cid chunkID + var format formatType + if format, cid, err = v.readBasicHeader(ctx); err != nil { + return nil, errors.WithMessage(err, "read basic header") + } + + var ok bool + var chunk *chunkStream + if chunk, ok = v.input.chunks[cid]; !ok { + chunk = newChunkStream() + v.input.chunks[cid] = chunk + chunk.header.betterCid = cid + } + + if err = v.readMessageHeader(ctx, chunk, format); err != nil { + return nil, errors.WithMessage(err, "read message header") + } + + if m, err = v.readMessagePayload(ctx, chunk); err != nil { + return nil, errors.WithMessage(err, "read message payload") + } + + if err = v.onMessageArrivated(m); err != nil { + return nil, errors.WithMessage(err, "on message") + } + } + + return +} + +func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (m *Message, err error) { + // Empty payload message. + if chunk.message.payloadLength == 0 { + m = chunk.message + chunk.message = nil + return + } + + // Calculate the chunk payload size. + chunkedPayloadSize := int(chunk.message.payloadLength) - len(chunk.message.Payload) + if chunkedPayloadSize > int(v.input.opt.chunkSize) { + chunkedPayloadSize = int(v.input.opt.chunkSize) + } + + b := make([]byte, chunkedPayloadSize) + if _, err = io.ReadFull(v.r, b); err != nil { + return nil, errors.Wrapf(err, "read chunk %vB", chunkedPayloadSize) + } + chunk.message.Payload = append(chunk.message.Payload, b...) + + // Got entire RTMP message? + if int(chunk.message.payloadLength) == len(chunk.message.Payload) { + m = chunk.message + chunk.message = nil + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 18, @section 6.1.2. Chunk Message Header +// There are four different formats for the chunk message header, +// selected by the "fmt" field in the chunk basic header. +type formatType uint8 + +const ( + // 6.1.2.1. Type 0 + // Chunks of Type 0 are 11 bytes long. This type MUST be used at the + // start of a chunk stream, and whenever the stream timestamp goes + // backward (e.g., because of a backward seek). + formatType0 formatType = iota + // 6.1.2.2. Type 1 + // Chunks of Type 1 are 7 bytes long. The message stream ID is not + // included; this chunk takes the same stream ID as the preceding chunk. + // Streams with variable-sized messages (for example, many video + // formats) SHOULD use this format for the first chunk of each new + // message after the first. + formatType1 + // 6.1.2.3. Type 2 + // Chunks of Type 2 are 3 bytes long. Neither the stream ID nor the + // message length is included; this chunk has the same stream ID and + // message length as the preceding chunk. Streams with constant-sized + // messages (for example, some audio and data formats) SHOULD use this + // format for the first chunk of each message after the first. + formatType2 + // 6.1.2.4. Type 3 + // Chunks of Type 3 have no header. Stream ID, message length and + // timestamp delta are not present; chunks of this type take values from + // the preceding chunk. When a single message is split into chunks, all + // chunks of a message except the first one, SHOULD use this type. Refer + // to example 2 in section 6.2.2. Stream consisting of messages of + // exactly the same size, stream ID and spacing in time SHOULD use this + // type for all chunks after chunk of Type 2. Refer to example 1 in + // section 6.2.1. If the delta between the first message and the second + // message is same as the time stamp of first message, then chunk of + // type 3 would immediately follow the chunk of type 0 as there is no + // need for a chunk of type 2 to register the delta. If Type 3 chunk + // follows a Type 0 chunk, then timestamp delta for this Type 3 chunk is + // the same as the timestamp of Type 0 chunk. + formatType3 +) + +// The message header size, index is format. +var messageHeaderSizes = []int{11, 7, 3, 0} + +// Parse the chunk message header. +// 3bytes: timestamp delta, fmt=0,1,2 +// 3bytes: payload length, fmt=0,1 +// 1bytes: message type, fmt=0,1 +// 4bytes: stream id, fmt=0 +// where: +// fmt=0, 0x0X +// fmt=1, 0x4X +// fmt=2, 0x8X +// fmt=3, 0xCX +func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, format formatType) (err error) { + // We should not assert anything about fmt, for the first packet. + // (when first packet, the chunk.message is nil). + // the fmt maybe 0/1/2/3, the FMLE will send a 0xC4 for some audio packet. + // the previous packet is: + // 04 // fmt=0, cid=4 + // 00 00 1a // timestamp=26 + // 00 00 9d // payload_length=157 + // 08 // message_type=8(audio) + // 01 00 00 00 // stream_id=1 + // the current packet maybe: + // c4 // fmt=3, cid=4 + // it's ok, for the packet is audio, and timestamp delta is 26. + // the current packet must be parsed as: + // fmt=0, cid=4 + // timestamp=26+26=52 + // payload_length=157 + // message_type=8(audio) + // stream_id=1 + // so we must update the timestamp even fmt=3 for first packet. + // + // The fresh packet used to update the timestamp even fmt=3 for first packet. + // fresh packet always means the chunk is the first one of message. + var isFirstChunkOfMsg bool + if chunk.message == nil { + isFirstChunkOfMsg = true + } + + // But, we can ensure that when a chunk stream is fresh, + // the fmt must be 0, a new stream. + if chunk.count == 0 && format != formatType0 { + // For librtmp, if ping, it will send a fresh stream with fmt=1, + // 0x42 where: fmt=1, cid=2, protocol contorl user-control message + // 0x00 0x00 0x00 where: timestamp=0 + // 0x00 0x00 0x06 where: payload_length=6 + // 0x04 where: message_type=4(protocol control user-control message) + // 0x00 0x06 where: event Ping(0x06) + // 0x00 0x00 0x0d 0x0f where: event data 4bytes ping timestamp. + // @see: https://github.com/ossrs/srs/issues/98 + if chunk.cid == chunkIDProtocolControl && format == formatType1 { + // We accept cid=2, fmt=1 to make librtmp happy. + } else { + return errors.Errorf("For fresh chunk, fmt %v != %v(required), cid is %v", format, formatType0, chunk.cid) + } + } + + // When exists cache msg, means got an partial message, + // the fmt must not be type0 which means new message. + if chunk.message != nil && format == formatType0 { + return errors.Errorf("For exists chunk, fmt is %v, cid is %v", format, chunk.cid) + } + + // Create msg when new chunk stream start + if chunk.message == nil { + chunk.message = NewMessage() + } + + // Read the message header. + p := make([]byte, messageHeaderSizes[format]) + if _, err = io.ReadFull(v.r, p); err != nil { + return errors.Wrapf(err, "read %vB message header", len(p)) + } + + // Prse the message header. + // 3bytes: timestamp delta, fmt=0,1,2 + // 3bytes: payload length, fmt=0,1 + // 1bytes: message type, fmt=0,1 + // 4bytes: stream id, fmt=0 + // where: + // fmt=0, 0x0X + // fmt=1, 0x4X + // fmt=2, 0x8X + // fmt=3, 0xCX + if format <= formatType2 { + chunk.header.timestampDelta = uint32(p[0])<<16 | uint32(p[1])<<8 | uint32(p[2]) + p = p[3:] + + // fmt: 0 + // timestamp: 3 bytes + // If the timestamp is greater than or equal to 16777215 + // (hexadecimal 0x00ffffff), this value MUST be 16777215, and the + // 'extended timestamp header' MUST be present. Otherwise, this value + // SHOULD be the entire timestamp. + // + // fmt: 1 or 2 + // timestamp delta: 3 bytes + // If the delta is greater than or equal to 16777215 (hexadecimal + // 0x00ffffff), this value MUST be 16777215, and the 'extended + // timestamp header' MUST be present. Otherwise, this value SHOULD be + // the entire delta. + chunk.extendedTimestamp = uint64(chunk.header.timestampDelta) >= extendedTimestamp + if !chunk.extendedTimestamp { + // Extended timestamp: 0 or 4 bytes + // This field MUST be sent when the normal timsestamp is set to + // 0xffffff, it MUST NOT be sent if the normal timestamp is set to + // anything else. So for values less than 0xffffff the normal + // timestamp field SHOULD be used in which case the extended timestamp + // MUST NOT be present. For values greater than or equal to 0xffffff + // the normal timestamp field MUST NOT be used and MUST be set to + // 0xffffff and the extended timestamp MUST be sent. + if format == formatType0 { + // 6.1.2.1. Type 0 + // For a type-0 chunk, the absolute timestamp of the message is sent + // here. + chunk.header.Timestamp = uint64(chunk.header.timestampDelta) + } else { + // 6.1.2.2. Type 1 + // 6.1.2.3. Type 2 + // For a type-1 or type-2 chunk, the difference between the previous + // chunk's timestamp and the current chunk's timestamp is sent here. + chunk.header.Timestamp += uint64(chunk.header.timestampDelta) + } + } + + if format <= formatType1 { + payloadLength := uint32(p[0])<<16 | uint32(p[1])<<8 | uint32(p[2]) + p = p[3:] + + // For a message, if msg exists in cache, the size must not changed. + // always use the actual msg size to compare, for the cache payload length can changed, + // for the fmt type1(stream_id not changed), user can change the payload + // length(it's not allowed in the continue chunks). + if !isFirstChunkOfMsg && chunk.header.payloadLength != payloadLength { + return errors.Errorf("Chunk message size %v != %v(required)", payloadLength, chunk.header.payloadLength) + } + chunk.header.payloadLength = payloadLength + + chunk.header.MessageType = MessageType(p[0]) + p = p[1:] + + if format == formatType0 { + chunk.header.streamID = uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24 + p = p[4:] + } + } + } else { + // Update the timestamp even fmt=3 for first chunk packet + if isFirstChunkOfMsg && !chunk.extendedTimestamp { + chunk.header.Timestamp += uint64(chunk.header.timestampDelta) + } + } + + // Read extended-timestamp + if chunk.extendedTimestamp { + var timestamp uint32 + if err = binary.Read(v.r, binary.BigEndian, ×tamp); err != nil { + return errors.Wrapf(err, "read ext-ts, pkt-ts=%v", chunk.header.Timestamp) + } + + // We always use 31bits timestamp, for some server may use 32bits extended timestamp. + // @see https://github.com/ossrs/srs/issues/111 + timestamp &= 0x7fffffff + + // TODO: FIXME: Support detect the extended timestamp. + // @see http://blog.csdn.net/win_lin/article/details/13363699 + chunk.header.Timestamp = uint64(timestamp) + } + + // The extended-timestamp must be unsigned-int, + // 24bits timestamp: 0xffffff = 16777215ms = 16777.215s = 4.66h + // 32bits timestamp: 0xffffffff = 4294967295ms = 4294967.295s = 1193.046h = 49.71d + // because the rtmp protocol says the 32bits timestamp is about "50 days": + // 3. Byte Order, Alignment, and Time Format + // Because timestamps are generally only 32 bits long, they will roll + // over after fewer than 50 days. + // + // but, its sample says the timestamp is 31bits: + // An application could assume, for example, that all + // adjacent timestamps are within 2^31 milliseconds of each other, so + // 10000 comes after 4000000000, while 3000000000 comes before + // 4000000000. + // and flv specification says timestamp is 31bits: + // Extension of the Timestamp field to form a SI32 value. This + // field represents the upper 8 bits, while the previous + // Timestamp field represents the lower 24 bits of the time in + // milliseconds. + // in a word, 31bits timestamp is ok. + // convert extended timestamp to 31bits. + chunk.header.Timestamp &= 0x7fffffff + + // Copy header to msg + chunk.message.messageHeader = chunk.header + + // Increase the msg count, the chunk stream can accept fmt=1/2/3 message now. + chunk.count++ + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 17, @section 6.1.1. Chunk Basic Header +// The Chunk Basic Header encodes the chunk stream ID and the chunk +// type(represented by fmt field in the figure below). Chunk type +// determines the format of the encoded message header. Chunk Basic +// Header field may be 1, 2, or 3 bytes, depending on the chunk stream +// ID. +// +// The bits 0-5 (least significant) in the chunk basic header represent +// the chunk stream ID. +// +// Chunk stream IDs 2-63 can be encoded in the 1-byte version of this +// field. +// 0 1 2 3 4 5 6 7 +// +-+-+-+-+-+-+-+-+ +// |fmt| cs id | +// +-+-+-+-+-+-+-+-+ +// Figure 6 Chunk basic header 1 +// +// Chunk stream IDs 64-319 can be encoded in the 2-byte version of this +// field. ID is computed as (the second byte + 64). +// 0 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |fmt| 0 | cs id - 64 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// Figure 7 Chunk basic header 2 +// +// Chunk stream IDs 64-65599 can be encoded in the 3-byte version of +// this field. ID is computed as ((the third byte)*256 + the second byte +// + 64). +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |fmt| 1 | cs id - 64 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// Figure 8 Chunk basic header 3 +// +// cs id: 6 bits +// fmt: 2 bits +// cs id - 64: 8 or 16 bits +// +// Chunk stream IDs with values 64-319 could be represented by both 2- +// byte version and 3-byte version of this field. +func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid chunkID, err error) { + // 2-63, 1B chunk header + var t uint8 + if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { + return format, cid, errors.Wrap(err, "read basic header") + } + cid = chunkID(t & 0x3f) + format = formatType((t >> 6) & 0x03) + + if cid > 1 { + return + } + + // 64-319, 2B chunk header + if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { + return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) + } + cid = chunkID(64 + uint32(t)) + + // 64-65599, 3B chunk header + if cid == 1 { + if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { + return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) + } + cid += chunkID(uint32(t) * 256) + } + + return +} + +func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (err error) { + m := NewMessage() + + if m.Payload, err = pkt.MarshalBinary(); err != nil { + return errors.WithMessage(err, "marshal payload") + } + + m.MessageType = pkt.Type() + m.streamID = uint32(streamID) + m.betterCid = pkt.BetterCid() + + if err = v.WriteMessage(ctx, m); err != nil { + return errors.WithMessage(err, "write message") + } + + if err = v.onPacketWriten(m, pkt); err != nil { + return errors.WithMessage(err, "on write packet") + } + + return +} + +func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) { + var tid amf0Number + var name amf0String + + switch pkt := pkt.(type) { + case *ConnectAppPacket: + tid, name = pkt.TransactionID, pkt.CommandName + case *CreateStreamPacket: + tid, name = pkt.TransactionID, pkt.CommandName + case *CallPacket: + tid, name = pkt.TransactionID, pkt.CommandName + } + + if tid > 0 && len(name) > 0 { + v.input.ltransactions.Lock() + defer v.input.ltransactions.Unlock() + + v.input.transactions[tid] = name + } + + return +} + +func (v *Protocol) onMessageArrivated(m *Message) (err error) { + if m == nil { + return + } + + var pkt Packet + switch m.MessageType { + case MessageTypeSetChunkSize, MessageTypeUserControl, MessageTypeWindowAcknowledgementSize: + if pkt, err = v.DecodeMessage(m); err != nil { + return errors.Errorf("decode message %v", m.MessageType) + } + } + + switch pkt := pkt.(type) { + case *SetChunkSize: + v.input.opt.chunkSize = pkt.ChunkSize + } + + return +} + +func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { + m.payloadLength = uint32(len(m.Payload)) + + var c0h, c3h []byte + if c0h, err = m.generateC0Header(); err != nil { + return errors.WithMessage(err, "generate c0 header") + } + if c3h, err = m.generateC3Header(); err != nil { + return errors.WithMessage(err, "generate c3 header") + } + + var h []byte + p := m.Payload + for len(p) > 0 { + // TODO: We should convert buffered io to async io, because we will be stuck in block io here, + // TODO: but the risk is acceptable because we literally will set the underlay io timeout. + if ctx.Err() != nil { + return ctx.Err() + } + + if h == nil { + h = c0h + } else { + h = c3h + } + + if _, err = io.Copy(v.w, bytes.NewReader(h)); err != nil { + return errors.Wrapf(err, "write c0c3 header %x", h) + } + + size := len(p) + if size > int(v.output.opt.chunkSize) { + size = int(v.output.opt.chunkSize) + } + + if _, err = io.Copy(v.w, bytes.NewReader(p[:size])); err != nil { + return errors.Wrapf(err, "write chunk payload %vB", size) + } + p = p[size:] + } + + // TODO: We should convert buffered io to async io, because we will be stuck in block io here, + // TODO: but the risk is acceptable because we literally will set the underlay io timeout. + if ctx.Err() != nil { + return ctx.Err() + } + + // TODO: FIXME: Use writev to write for high performance. + if err = v.w.Flush(); err != nil { + return errors.Wrapf(err, "flush writer") + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header +// 1byte. One byte field to represent the message type. A range of type IDs +// (1-7) are reserved for protocol control messages. +type MessageType uint8 + +const ( + // Please read @doc rtmp_specification_1.0.pdf, @page 30, @section 5. Protocol Control Messages + // RTMP reserves message type IDs 1-7 for protocol control messages. + // These messages contain information needed by the RTM Chunk Stream + // protocol or RTMP itself. Protocol messages with IDs 1 & 2 are + // reserved for usage with RTM Chunk Stream protocol. Protocol messages + // with IDs 3-6 are reserved for usage of RTMP. Protocol message with ID + // 7 is used between edge server and origin server. + MessageTypeSetChunkSize MessageType = 0x01 + MessageTypeAbort MessageType = 0x02 // 0x02 + MessageTypeAcknowledgement MessageType = 0x03 // 0x03 + MessageTypeUserControl MessageType = 0x04 // 0x04 + MessageTypeWindowAcknowledgementSize MessageType = 0x05 // 0x05 + MessageTypeSetPeerBandwidth MessageType = 0x06 // 0x06 + MessageTypeEdgeAndOriginServerCommand MessageType = 0x07 // 0x07 + // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3. Types of messages + // The server and the client send messages over the network to + // communicate with each other. The messages can be of any type which + // includes audio messages, video messages, command messages, shared + // object messages, data messages, and user control messages. + // + // Please read @doc rtmp_specification_1.0.pdf, @page 41, @section 3.4. Audio message + // The client or the server sends this message to send audio data to the + // peer. The message type value of 8 is reserved for audio messages. + MessageTypeAudio MessageType = 0x08 + // Please read @doc rtmp_specification_1.0.pdf, @page 41, @section 3.5. Video message + // The client or the server sends this message to send video data to the + // peer. The message type value of 9 is reserved for video messages. + // These messages are large and can delay the sending of other type of + // messages. To avoid such a situation, the video message is assigned + // the lowest priority. + MessageTypeVideo MessageType = 0x09 // 0x09 + // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3.1. Command message + // Command messages carry the AMF-encoded commands between the client + // and the server. These messages have been assigned message type value + // of 20 for AMF0 encoding and message type value of 17 for AMF3 + // encoding. These messages are sent to perform some operations like + // connect, createStream, publish, play, pause on the peer. Command + // messages like onstatus, result etc. are used to inform the sender + // about the status of the requested commands. A command message + // consists of command name, transaction ID, and command object that + // contains related parameters. A client or a server can request Remote + // Procedure Calls (RPC) over streams that are communicated using the + // command messages to the peer. + MessageTypeAMF3Command MessageType = 17 // 0x11 + MessageTypeAMF0Command MessageType = 20 // 0x14 + // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3.2. Data message + // The client or the server sends this message to send Metadata or any + // user data to the peer. Metadata includes details about the + // data(audio, video etc.) like creation time, duration, theme and so + // on. These messages have been assigned message type value of 18 for + // AMF0 and message type value of 15 for AMF3. + MessageTypeAMF0Data MessageType = 18 // 0x12 + MessageTypeAMF3Data MessageType = 15 // 0x0f +) + +// The header of message. +type messageHeader struct { + // 3bytes. + // Three-byte field that contains a timestamp delta of the message. + // @remark, only used for decoding message from chunk stream. + timestampDelta uint32 + // 3bytes. + // Three-byte field that represents the size of the payload in bytes. + // It is set in big-endian format. + payloadLength uint32 + // 1byte. + // One byte field to represent the message type. A range of type IDs + // (1-7) are reserved for protocol control messages. + MessageType MessageType + // 4bytes. + // Four-byte field that identifies the stream of the message. These + // bytes are set in little-endian format. + streamID uint32 + + // The chunk stream id over which transport. + betterCid chunkID + + // Four-byte field that contains a timestamp of the message. + // The 4 bytes are packed in the big-endian order. + // @remark, we use 64bits for large time for jitter detect and for large tbn like HLS. + Timestamp uint64 +} + +// The RTMP message, transport over chunk stream in RTMP. +// Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header +type Message struct { + messageHeader + + // The payload which carries the RTMP packet. + Payload []byte +} + +func NewMessage() *Message { + return &Message{} +} + +func NewStreamMessage(streamID int) *Message { + v := NewMessage() + v.streamID = uint32(streamID) + v.betterCid = chunkIDOverStream + return v +} + +func (v *Message) generateC3Header() ([]byte, error) { + var c3h []byte + if v.Timestamp < extendedTimestamp { + c3h = make([]byte, 1) + } else { + c3h = make([]byte, 1+4) + } + + p := c3h + p[0] = 0xc0 | byte(v.betterCid&0x3f) + p = p[1:] + + // In RTMP protocol, there must not any timestamp in C3 header, + // but actually all products from adobe, such as FMS/AMS and Flash player and FMLE, + // always carry a extended timestamp in C3 header. + // @see: http://blog.csdn.net/win_lin/article/details/13363699 + if v.Timestamp >= extendedTimestamp { + p[0] = byte(v.Timestamp >> 24) + p[1] = byte(v.Timestamp >> 16) + p[2] = byte(v.Timestamp >> 8) + p[3] = byte(v.Timestamp) + } + + return c3h, nil +} + +func (v *Message) generateC0Header() ([]byte, error) { + var c0h []byte + if v.Timestamp < extendedTimestamp { + c0h = make([]byte, 1+3+3+1+4) + } else { + c0h = make([]byte, 1+3+3+1+4+4) + } + + p := c0h + p[0] = byte(v.betterCid) & 0x3f + p = p[1:] + + if v.Timestamp < extendedTimestamp { + p[0] = byte(v.Timestamp >> 16) + p[1] = byte(v.Timestamp >> 8) + p[2] = byte(v.Timestamp) + } else { + p[0] = 0xff + p[1] = 0xff + p[2] = 0xff + } + p = p[3:] + + p[0] = byte(v.payloadLength >> 16) + p[1] = byte(v.payloadLength >> 8) + p[2] = byte(v.payloadLength) + p = p[3:] + + p[0] = byte(v.MessageType) + p = p[1:] + + p[0] = byte(v.streamID) + p[1] = byte(v.streamID >> 8) + p[2] = byte(v.streamID >> 16) + p[3] = byte(v.streamID >> 24) + p = p[4:] + + if v.Timestamp >= extendedTimestamp { + p[0] = byte(v.Timestamp >> 24) + p[1] = byte(v.Timestamp >> 16) + p[2] = byte(v.Timestamp >> 8) + p[3] = byte(v.Timestamp) + } + + return c0h, nil +} + +// Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 17, @section 6.1.1. Chunk Basic Header +type chunkID uint32 + +const ( + chunkIDProtocolControl chunkID = 0x02 + chunkIDOverConnection chunkID = 0x03 + chunkIDOverConnection2 chunkID = 0x04 + chunkIDOverStream chunkID = 0x05 + chunkIDOverStream2 chunkID = 0x06 + chunkIDVideo chunkID = 0x07 + chunkIDAudio chunkID = 0x08 +) + +// The Command Name of message. +const ( + commandConnect amf0String = amf0String("connect") + commandCreateStream amf0String = amf0String("createStream") + commandCloseStream amf0String = amf0String("closeStream") + commandPlay amf0String = amf0String("play") + commandPause amf0String = amf0String("pause") + commandOnBWDone amf0String = amf0String("onBWDone") + commandOnStatus amf0String = amf0String("onStatus") + commandResult amf0String = amf0String("_result") + commandError amf0String = amf0String("_error") + commandReleaseStream amf0String = amf0String("releaseStream") + commandFCPublish amf0String = amf0String("FCPublish") + commandFCUnpublish amf0String = amf0String("FCUnpublish") + commandPublish amf0String = amf0String("publish") + commandRtmpSampleAccess amf0String = amf0String("|RtmpSampleAccess") +) + +// The RTMP packet, transport as payload of RTMP message. +type Packet interface { + // Marshaler and unmarshaler + Size() int + encoding.BinaryUnmarshaler + encoding.BinaryMarshaler + + // RTMP protocol fields for each packet. + BetterCid() chunkID + Type() MessageType +} + +// A Call packet, both object and args are AMF0 objects. +type objectCallPacket struct { + CommandName amf0String + TransactionID amf0Number + CommandObject *amf0Object + Args *amf0Object +} + +func (v *objectCallPacket) BetterCid() chunkID { + return chunkIDOverConnection +} + +func (v *objectCallPacket) Type() MessageType { + return MessageTypeAMF0Command +} + +func (v *objectCallPacket) Size() int { + size := v.CommandName.Size() + v.TransactionID.Size() + v.CommandObject.Size() + if v.Args != nil { + size += v.Args.Size() + } + return size +} + +func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.CommandName.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal command name") + } + p = p[v.CommandName.Size():] + + if err = v.TransactionID.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal tid") + } + p = p[v.TransactionID.Size():] + + if err = v.CommandObject.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal command") + } + p = p[v.CommandObject.Size():] + + if len(p) == 0 { + return + } + + v.Args = NewAmf0Object() + if err = v.Args.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal args") + } + + return +} + +func (v *objectCallPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.CommandName.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal command name") + } + data = append(data, pb...) + + if pb, err = v.TransactionID.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal tid") + } + data = append(data, pb...) + + if pb, err = v.CommandObject.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal command object") + } + data = append(data, pb...) + + if v.Args != nil { + if pb, err = v.Args.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal args") + } + data = append(data, pb...) + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 45, @section 4.1.1. connect +// The client sends the connect command to the server to request +// connection to a server application instance. +type ConnectAppPacket struct { + objectCallPacket +} + +func NewConnectAppPacket() *ConnectAppPacket { + v := &ConnectAppPacket{} + v.CommandName = commandConnect + v.CommandObject = NewAmf0Object() + v.TransactionID = amf0Number(1.0) + return v +} + +func (v *ConnectAppPacket) UnmarshalBinary(data []byte) (err error) { + if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + + if v.CommandName != commandConnect { + return errors.Errorf("Invalid command name %v", string(v.CommandName)) + } + + if v.TransactionID != 1.0 { + return errors.Errorf("Invalid transaction ID %v", float64(v.TransactionID)) + } + + return +} + +func (v *ConnectAppPacket) TcUrl() string { + if v.CommandObject != nil { + if v, ok := v.CommandObject.Get("tcUrl").(*amf0String); ok { + return string(*v) + } + } + return "" +} + +// The response for ConnectAppPacket. +type ConnectAppResPacket struct { + objectCallPacket +} + +func NewConnectAppResPacket(tid amf0Number) *ConnectAppResPacket { + v := &ConnectAppResPacket{} + v.CommandName = commandResult + v.CommandObject = NewAmf0Object() + v.Args = NewAmf0Object() + v.TransactionID = tid + return v +} + +func (v *ConnectAppResPacket) SrsID() string { + if v.Args != nil { + if v, ok := v.Args.Get("data").(*amf0EcmaArray); ok { + if v, ok := v.Get("srs_id").(*amf0String); ok { + return string(*v) + } + } + } + return "" +} + +func (v *ConnectAppResPacket) UnmarshalBinary(data []byte) (err error) { + if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + + if v.CommandName != commandResult { + return errors.Errorf("Invalid command name %v", string(v.CommandName)) + } + + return +} + +// A Call object, command object is variant. +type variantCallPacket struct { + CommandName amf0String + TransactionID amf0Number + CommandObject amf0Any // object or null +} + +func (v *variantCallPacket) BetterCid() chunkID { + return chunkIDOverConnection +} + +func (v *variantCallPacket) Type() MessageType { + return MessageTypeAMF0Command +} + +func (v *variantCallPacket) Size() int { + size := v.CommandName.Size() + v.TransactionID.Size() + + if v.CommandObject != nil { + size += v.CommandObject.Size() + } + + return size +} + +func (v *variantCallPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.CommandName.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal command name") + } + p = p[v.CommandName.Size():] + + if err = v.TransactionID.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal tid") + } + p = p[v.TransactionID.Size():] + + if len(p) > 0 { + if v.CommandObject, err = Amf0Discovery(p); err != nil { + return errors.WithMessage(err, "discovery command object") + } + if err = v.CommandObject.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal command object") + } + p = p[v.CommandObject.Size():] + } + + return +} + +func (v *variantCallPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.CommandName.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal command name") + } + data = append(data, pb...) + + if pb, err = v.TransactionID.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal tid") + } + data = append(data, pb...) + + if v.CommandObject != nil { + if pb, err = v.CommandObject.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal command object") + } + data = append(data, pb...) + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 51, @section 4.1.2. Call +// The call method of the NetConnection object runs remote procedure +// calls (RPC) at the receiving end. The called RPC name is passed as a +// parameter to the call command. +// @remark onStatus packet is a call packet. +type CallPacket struct { + variantCallPacket + Args amf0Any // optional or object or null +} + +func NewCallPacket() *CallPacket { + return &CallPacket{} +} + +func (v *CallPacket) ArgsCode() string { + if v.Args != nil { + if v, ok := v.Args.(*amf0Object); ok { + if code, ok := v.Get("code").(*amf0String); ok { + return string(*code) + } + } + } + return "" +} + +func (v *CallPacket) Size() int { + size := v.variantCallPacket.Size() + + if v.Args != nil { + size += v.Args.Size() + } + + return size +} + +func (v *CallPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if len(p) > 0 { + if v.Args, err = Amf0Discovery(p); err != nil { + return errors.WithMessage(err, "discovery args") + } + if err = v.Args.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal args") + } + } + + return +} + +func (v *CallPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if v.Args != nil { + if pb, err = v.Args.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal args") + } + data = append(data, pb...) + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 52, @section 4.1.3. createStream +// The client sends this command to the server to create a logical +// channel for message communication The publishing of audio, video, and +// metadata is carried out over stream channel created using the +// createStream command. +type CreateStreamPacket struct { + variantCallPacket +} + +func NewCreateStreamPacket() *CreateStreamPacket { + v := &CreateStreamPacket{} + v.CommandName = commandCreateStream + v.TransactionID = amf0Number(2) + v.CommandObject = NewAmf0Null() + return v +} + +// The response for create stream +type CreateStreamResPacket struct { + variantCallPacket + StreamID amf0Number +} + +func NewCreateStreamResPacket(tid amf0Number) *CreateStreamResPacket { + v := &CreateStreamResPacket{} + v.CommandName = commandResult + v.TransactionID = tid + v.CommandObject = NewAmf0Null() + v.StreamID = 0 + return v +} + +func (v *CreateStreamResPacket) Size() int { + return v.variantCallPacket.Size() + v.StreamID.Size() +} + +func (v *CreateStreamResPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if err = v.StreamID.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal sid") + } + + return +} + +func (v *CreateStreamResPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if pb, err = v.StreamID.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal sid") + } + data = append(data, pb...) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 64, @section 4.2.6. Publish +type PublishPacket struct { + variantCallPacket + StreamName amf0String + StreamType amf0String +} + +func NewPublishPacket() *PublishPacket { + v := &PublishPacket{} + v.CommandName = commandPublish + v.CommandObject = NewAmf0Null() + v.StreamType = "live" + return v +} + +func (v *PublishPacket) Size() int { + return v.variantCallPacket.Size() + v.StreamName.Size() + v.StreamType.Size() +} + +func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if err = v.StreamName.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal stream name") + } + p = p[v.StreamName.Size():] + + if err = v.StreamType.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal stream type") + } + + return +} + +func (v *PublishPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if pb, err = v.StreamName.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal stream name") + } + data = append(data, pb...) + + if pb, err = v.StreamType.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal stream type") + } + data = append(data, pb...) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 54, @section 4.2.1. play +type PlayPacket struct { + variantCallPacket + StreamName amf0String +} + +func NewPlayPacket() *PlayPacket { + v := &PlayPacket{} + v.CommandName = commandPlay + v.CommandObject = NewAmf0Null() + return v +} + +func (v *PlayPacket) Size() int { + return v.variantCallPacket.Size() + v.StreamName.Size() +} + +func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if err = v.StreamName.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal stream name") + } + p = p[v.StreamName.Size():] + + return +} + +func (v *PlayPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if pb, err = v.StreamName.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal stream name") + } + data = append(data, pb...) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 31, @section 5.1. Set Chunk Size +// Protocol control message 1, Set Chunk Size, is used to notify the +// peer about the new maximum chunk size. +type SetChunkSize struct { + ChunkSize uint32 +} + +func NewSetChunkSize() *SetChunkSize { + return &SetChunkSize{ + ChunkSize: defaultChunkSize, + } +} + +func (v *SetChunkSize) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *SetChunkSize) Type() MessageType { + return MessageTypeSetChunkSize +} + +func (v *SetChunkSize) Size() int { + return 4 +} + +func (v *SetChunkSize) UnmarshalBinary(data []byte) (err error) { + if len(data) < 4 { + return errors.Errorf("requires 4 only %v bytes, %x", len(data), data) + } + v.ChunkSize = binary.BigEndian.Uint32(data) + + return +} + +func (v *SetChunkSize) MarshalBinary() (data []byte, err error) { + data = make([]byte, 4) + binary.BigEndian.PutUint32(data, v.ChunkSize) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.5. Window Acknowledgement Size (5) +// The client or the server sends this message to inform the peer which +// window size to use when sending acknowledgment. +type WindowAcknowledgementSize struct { + AckSize uint32 +} + +func NewWindowAcknowledgementSize() *WindowAcknowledgementSize { + return &WindowAcknowledgementSize{} +} + +func (v *WindowAcknowledgementSize) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *WindowAcknowledgementSize) Type() MessageType { + return MessageTypeWindowAcknowledgementSize +} + +func (v *WindowAcknowledgementSize) Size() int { + return 4 +} + +func (v *WindowAcknowledgementSize) UnmarshalBinary(data []byte) (err error) { + if len(data) < 4 { + return errors.Errorf("requires 4 only %v bytes, %x", len(data), data) + } + v.AckSize = binary.BigEndian.Uint32(data) + + return +} + +func (v *WindowAcknowledgementSize) MarshalBinary() (data []byte, err error) { + data = make([]byte, 4) + binary.BigEndian.PutUint32(data, v.AckSize) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.6. Set Peer Bandwidth (6) +// The sender can mark this message hard (0), soft (1), or dynamic (2) +// using the Limit type field. +type LimitType uint8 + +const ( + LimitTypeHard LimitType = iota + LimitTypeSoft + LimitTypeDynamic +) + +// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.6. Set Peer Bandwidth (6) +// The client or the server sends this message to update the output +// bandwidth of the peer. +type SetPeerBandwidth struct { + Bandwidth uint32 + LimitType LimitType +} + +func NewSetPeerBandwidth() *SetPeerBandwidth { + return &SetPeerBandwidth{} +} + +func (v *SetPeerBandwidth) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *SetPeerBandwidth) Type() MessageType { + return MessageTypeSetPeerBandwidth +} + +func (v *SetPeerBandwidth) Size() int { + return 4 + 1 +} + +func (v *SetPeerBandwidth) UnmarshalBinary(data []byte) (err error) { + if len(data) < 5 { + return errors.Errorf("requires 5 only %v bytes, %x", len(data), data) + } + v.Bandwidth = binary.BigEndian.Uint32(data) + v.LimitType = LimitType(data[4]) + + return +} + +func (v *SetPeerBandwidth) MarshalBinary() (data []byte, err error) { + data = make([]byte, 5) + binary.BigEndian.PutUint32(data, v.Bandwidth) + data[4] = byte(v.LimitType) + + return +} + +type EventType uint16 + +const ( + // Generally, 4bytes event-data + + // The server sends this event to notify the client + // that a stream has become functional and can be + // used for communication. By default, this event + // is sent on ID 0 after the application connect + // command is successfully received from the + // client. The event data is 4-byte and represents + // The stream ID of the stream that became + // Functional. + EventTypeStreamBegin = 0x00 + + // The server sends this event to notify the client + // that the playback of data is over as requested + // on this stream. No more data is sent without + // issuing additional commands. The client discards + // The messages received for the stream. The + // 4 bytes of event data represent the ID of the + // stream on which playback has ended. + EventTypeStreamEOF = 0x01 + + // The server sends this event to notify the client + // that there is no more data on the stream. If the + // server does not detect any message for a time + // period, it can notify the subscribed clients + // that the stream is dry. The 4 bytes of event + // data represent the stream ID of the dry stream. + EventTypeStreamDry = 0x02 + + // The client sends this event to inform the server + // of the buffer size (in milliseconds) that is + // used to buffer any data coming over a stream. + // This event is sent before the server starts + // processing the stream. The first 4 bytes of the + // event data represent the stream ID and the next + // 4 bytes represent the buffer length, in + // milliseconds. + EventTypeSetBufferLength = 0x03 // 8bytes event-data + + // The server sends this event to notify the client + // that the stream is a recorded stream. The + // 4 bytes event data represent the stream ID of + // The recorded stream. + EventTypeStreamIsRecorded = 0x04 + + // The server sends this event to test whether the + // client is reachable. Event data is a 4-byte + // timestamp, representing the local server time + // When the server dispatched the command. The + // client responds with kMsgPingResponse on + // receiving kMsgPingRequest. + EventTypePingRequest = 0x06 + + // The client sends this event to the server in + // Response to the ping request. The event data is + // a 4-byte timestamp, which was received with the + // kMsgPingRequest request. + EventTypePingResponse = 0x07 + + // For PCUC size=3, for example the payload is "00 1A 01", + // it's a FMS control event, where the event type is 0x001a and event data is 0x01, + // please notice that the event data is only 1 byte for this event. + EventTypeFmsEvent0 = 0x1a +) + +// Please read @doc rtmp_specification_1.0.pdf, @page 32, @5.4. User Control Message (4) +// The client or the server sends this message to notify the peer about the user control events. +// This message carries Event type and Event data. +type UserControl struct { + // Event type is followed by Event data. + // @see: SrcPCUCEventType + EventType EventType + // The event data generally in 4bytes. + // @remark for event type is 0x001a, only 1bytes. + // @see SrsPCUCFmsEvent0 + EventData int32 + // 4bytes if event_type is SetBufferLength; otherwise 0. + ExtraData int32 +} + +func NewUserControl() *UserControl { + return &UserControl{} +} + +func (v *UserControl) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *UserControl) Type() MessageType { + return MessageTypeUserControl +} + +func (v *UserControl) Size() int { + size := 2 + + if v.EventType == EventTypeFmsEvent0 { + size += 1 + } else { + size += 4 + } + + if v.EventType == EventTypeSetBufferLength { + size += 4 + } + + return size +} + +func (v *UserControl) UnmarshalBinary(data []byte) (err error) { + if len(data) < 3 { + return errors.Errorf("requires 5 only %v bytes, %x", len(data), data) + } + + v.EventType = EventType(binary.BigEndian.Uint16(data)) + if len(data) < v.Size() { + return errors.Errorf("requires %v only %v bytes, %x", v.Size(), len(data), data) + } + + if v.EventType == EventTypeFmsEvent0 { + v.EventData = int32(uint8(data[2])) + } else { + v.EventData = int32(binary.BigEndian.Uint32(data[2:])) + } + + if v.EventType == EventTypeSetBufferLength { + v.ExtraData = int32(binary.BigEndian.Uint32(data[6:])) + } + + return +} + +func (v *UserControl) MarshalBinary() (data []byte, err error) { + data = make([]byte, v.Size()) + binary.BigEndian.PutUint16(data, uint16(v.EventType)) + + if v.EventType == EventTypeFmsEvent0 { + data[2] = uint8(v.EventData) + } else { + binary.BigEndian.PutUint32(data[2:], uint32(v.EventData)) + } + + if v.EventType == EventTypeSetBufferLength { + binary.BigEndian.PutUint32(data[6:], uint32(v.ExtraData)) + } + + return +} diff --git a/proxy/signal.go b/proxy/signal.go new file mode 100644 index 0000000000..367543f4a7 --- /dev/null +++ b/proxy/signal.go @@ -0,0 +1,44 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "os" + "os/signal" + "syscall" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +func installSignals(ctx context.Context, cancel context.CancelFunc) { + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) + + go func() { + for s := range sc { + logger.Df(ctx, "Got signal %v", s) + cancel() + } + }() +} + +func installForceQuit(ctx context.Context) error { + var forceTimeout time.Duration + if t, err := time.ParseDuration(envForceQuitTimeout()); err != nil { + return errors.Wrapf(err, "parse force timeout %v", envForceQuitTimeout()) + } else { + forceTimeout = t + } + + go func() { + <-ctx.Done() + time.Sleep(forceTimeout) + logger.Wf(ctx, "Force to exit by timeout") + os.Exit(1) + }() + return nil +} diff --git a/proxy/srs.go b/proxy/srs.go new file mode 100644 index 0000000000..d05a39c610 --- /dev/null +++ b/proxy/srs.go @@ -0,0 +1,553 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "os" + "strconv" + "strings" + "time" + + // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+ + "github.com/go-redis/redis/v8" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/sync" +) + +// If server heartbeat in this duration, it's alive. +const srsServerAliveDuration = 300 * time.Second + +// If HLS streaming update in this duration, it's alive. +const srsHLSAliveDuration = 120 * time.Second + +// If WebRTC streaming update in this duration, it's alive. +const srsRTCAliveDuration = 120 * time.Second + +type SRSServer struct { + // The server IP. + IP string `json:"ip,omitempty"` + // The server device ID, configured by user. + DeviceID string `json:"device_id,omitempty"` + // The server id of SRS, store in file, may not change, mandatory. + ServerID string `json:"server_id,omitempty"` + // The service id of SRS, always change when restarted, mandatory. + ServiceID string `json:"service_id,omitempty"` + // The process id of SRS, always change when restarted, mandatory. + PID string `json:"pid,omitempty"` + // The RTMP listen endpoints. + RTMP []string `json:"rtmp,omitempty"` + // The HTTP Stream listen endpoints. + HTTP []string `json:"http,omitempty"` + // The HTTP API listen endpoints. + API []string `json:"api,omitempty"` + // The SRT server listen endpoints. + SRT []string `json:"srt,omitempty"` + // The RTC server listen endpoints. + RTC []string `json:"rtc,omitempty"` + // Last update time. + UpdatedAt time.Time `json:"update_at,omitempty"` +} + +func (v *SRSServer) ID() string { + return fmt.Sprintf("%v-%v-%v", v.ServerID, v.ServiceID, v.PID) +} + +func (v *SRSServer) String() string { + return fmt.Sprintf("%v", v) +} + +func (v *SRSServer) Format(f fmt.State, c rune) { + switch c { + case 'v', 's': + if f.Flag('+') { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("pid=%v, server=%v, service=%v", v.PID, v.ServerID, v.ServiceID)) + if v.DeviceID != "" { + sb.WriteString(fmt.Sprintf(", device=%v", v.DeviceID)) + } + if len(v.RTMP) > 0 { + sb.WriteString(fmt.Sprintf(", rtmp=[%v]", strings.Join(v.RTMP, ","))) + } + if len(v.HTTP) > 0 { + sb.WriteString(fmt.Sprintf(", http=[%v]", strings.Join(v.HTTP, ","))) + } + if len(v.API) > 0 { + sb.WriteString(fmt.Sprintf(", api=[%v]", strings.Join(v.API, ","))) + } + if len(v.SRT) > 0 { + sb.WriteString(fmt.Sprintf(", srt=[%v]", strings.Join(v.SRT, ","))) + } + if len(v.RTC) > 0 { + sb.WriteString(fmt.Sprintf(", rtc=[%v]", strings.Join(v.RTC, ","))) + } + sb.WriteString(fmt.Sprintf(", update=%v", v.UpdatedAt.Format("2006-01-02 15:04:05.999"))) + fmt.Fprintf(f, "SRS ip=%v, id=%v, %v", v.IP, v.ID(), sb.String()) + } else { + fmt.Fprintf(f, "SRS ip=%v, id=%v", v.IP, v.ID()) + } + default: + fmt.Fprintf(f, "%v, fmt=%%%c", v, c) + } +} + +func NewSRSServer(opts ...func(*SRSServer)) *SRSServer { + v := &SRSServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +// NewDefaultSRSForDebugging initialize the default SRS media server, for debugging only. +func NewDefaultSRSForDebugging() (*SRSServer, error) { + if envDefaultBackendEnabled() != "on" { + return nil, nil + } + + if envDefaultBackendIP() == "" { + return nil, fmt.Errorf("empty default backend ip") + } + if envDefaultBackendRTMP() == "" { + return nil, fmt.Errorf("empty default backend rtmp") + } + + server := NewSRSServer(func(srs *SRSServer) { + srs.IP = envDefaultBackendIP() + srs.RTMP = []string{envDefaultBackendRTMP()} + srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID()) + srs.ServiceID = logger.GenerateContextID() + srs.PID = fmt.Sprintf("%v", os.Getpid()) + srs.UpdatedAt = time.Now() + }) + + if envDefaultBackendHttp() != "" { + server.HTTP = []string{envDefaultBackendHttp()} + } + if envDefaultBackendAPI() != "" { + server.API = []string{envDefaultBackendAPI()} + } + if envDefaultBackendRTC() != "" { + server.RTC = []string{envDefaultBackendRTC()} + } + if envDefaultBackendSRT() != "" { + server.SRT = []string{envDefaultBackendSRT()} + } + return server, nil +} + +// SRSLoadBalancer is the interface to load balance the SRS servers. +type SRSLoadBalancer interface { + // Initialize the load balancer. + Initialize(ctx context.Context) error + // Update the backer server. + Update(ctx context.Context, server *SRSServer) error + // Pick a backend server for the specified stream URL. + Pick(ctx context.Context, streamURL string) (*SRSServer, error) + // Load or store the HLS streaming for the specified stream URL. + LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) + // Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID. + LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) + // Store the WebRTC streaming for the specified stream URL. + StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error + // Load the WebRTC streaming by ufrag, the ICE username. + LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) +} + +// srsLoadBalancer is the global SRS load balancer. +var srsLoadBalancer SRSLoadBalancer + +// srsMemoryLoadBalancer stores state in memory. +type srsMemoryLoadBalancer struct { + // All available SRS servers, key is server ID. + servers sync.Map[string, *SRSServer] + // The picked server to servce client by specified stream URL, key is stream url. + picked sync.Map[string, *SRSServer] + // The HLS streaming, key is stream URL. + hlsStreamURL sync.Map[string, *HLSPlayStream] + // The HLS streaming, key is SPBHID. + hlsSPBHID sync.Map[string, *HLSPlayStream] + // The WebRTC streaming, key is stream URL. + rtcStreamURL sync.Map[string, *RTCConnection] + // The WebRTC streaming, key is ufrag. + rtcUfrag sync.Map[string, *RTCConnection] +} + +func NewMemoryLoadBalancer() SRSLoadBalancer { + return &srsMemoryLoadBalancer{} +} + +func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error { + if server, err := NewDefaultSRSForDebugging(); err != nil { + return errors.Wrapf(err, "initialize default SRS") + } else if server != nil { + if err := v.Update(ctx, server); err != nil { + return errors.Wrapf(err, "update default SRS %+v", server) + } + + // Keep alive. + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(30 * time.Second): + if err := v.Update(ctx, server); err != nil { + logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err) + } + } + } + }() + logger.Df(ctx, "MemoryLB: Initialize default SRS media server, %+v", server) + } + return nil +} + +func (v *srsMemoryLoadBalancer) Update(ctx context.Context, server *SRSServer) error { + v.servers.Store(server.ID(), server) + return nil +} + +func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) { + // Always proxy to the same server for the same stream URL. + if server, ok := v.picked.Load(streamURL); ok { + return server, nil + } + + // Gather all servers that were alive within the last few seconds. + var servers []*SRSServer + v.servers.Range(func(key string, server *SRSServer) bool { + if time.Since(server.UpdatedAt) < srsServerAliveDuration { + servers = append(servers, server) + } + return true + }) + + // If no servers available, use all possible servers. + if len(servers) == 0 { + v.servers.Range(func(key string, server *SRSServer) bool { + servers = append(servers, server) + return true + }) + } + + // No server found, failed. + if len(servers) == 0 { + return nil, fmt.Errorf("no server available for %v", streamURL) + } + + // Pick a server randomly from servers. + server := servers[rand.Intn(len(servers))] + v.picked.Store(streamURL, server) + return server, nil +} + +func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) { + // Load the HLS streaming for the SPBHID, for TS files. + if actual, ok := v.hlsSPBHID.Load(spbhid); !ok { + return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid) + } else { + return actual, nil + } +} + +func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) { + // Update the HLS streaming for the stream URL, for M3u8. + actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value) + if actual == nil { + return nil, errors.Errorf("load or store HLS streaming for %v failed", streamURL) + } + + // Update the HLS streaming for the SPBHID, for TS files. + v.hlsSPBHID.Store(value.SRSProxyBackendHLSID, actual) + + return actual, nil +} + +func (v *srsMemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error { + // Update the WebRTC streaming for the stream URL. + v.rtcStreamURL.Store(streamURL, value) + + // Update the WebRTC streaming for the ufrag. + v.rtcUfrag.Store(value.Ufrag, value) + return nil +} + +func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { + if actual, ok := v.rtcUfrag.Load(ufrag); !ok { + return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag) + } else { + return actual, nil + } +} + +type srsRedisLoadBalancer struct { + // The redis client sdk. + rdb *redis.Client +} + +func NewRedisLoadBalancer() SRSLoadBalancer { + return &srsRedisLoadBalancer{} +} + +func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error { + redisDatabase, err := strconv.Atoi(envRedisDB()) + if err != nil { + return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", envRedisDB()) + } + + rdb := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%v:%v", envRedisHost(), envRedisPort()), + Password: envRedisPassword(), + DB: redisDatabase, + }) + v.rdb = rdb + + if err := rdb.Ping(ctx).Err(); err != nil { + return errors.Wrapf(err, "unable to connect to redis %v", rdb.String()) + } + logger.Df(ctx, "RedisLB: connected to redis %v ok", rdb.String()) + + if server, err := NewDefaultSRSForDebugging(); err != nil { + return errors.Wrapf(err, "initialize default SRS") + } else if server != nil { + if err := v.Update(ctx, server); err != nil { + return errors.Wrapf(err, "update default SRS %+v", server) + } + + // Keep alive. + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(30 * time.Second): + if err := v.Update(ctx, server); err != nil { + logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err) + } + } + } + }() + logger.Df(ctx, "RedisLB: Initialize default SRS media server, %+v", server) + } + return nil +} + +func (v *srsRedisLoadBalancer) Update(ctx context.Context, server *SRSServer) error { + b, err := json.Marshal(server) + if err != nil { + return errors.Wrapf(err, "marshal server %+v", server) + } + + key := v.redisKeyServer(server.ID()) + if err = v.rdb.Set(ctx, key, b, srsServerAliveDuration).Err(); err != nil { + return errors.Wrapf(err, "set key=%v server %+v", key, server) + } + + // Query all servers from redis, in json string. + var serverKeys []string + if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil { + if err := json.Unmarshal(b, &serverKeys); err != nil { + return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b)) + } + } + + // Check each server expiration, if not exists in redis, remove from servers. + for i := len(serverKeys) - 1; i >= 0; i-- { + if _, err := v.rdb.Get(ctx, serverKeys[i]).Bytes(); err != nil { + serverKeys = append(serverKeys[:i], serverKeys[i+1:]...) + } + } + + // Add server to servers if not exists. + var found bool + for _, serverKey := range serverKeys { + if serverKey == key { + found = true + break + } + } + if !found { + serverKeys = append(serverKeys, key) + } + + // Update all servers to redis. + b, err = json.Marshal(serverKeys) + if err != nil { + return errors.Wrapf(err, "marshal servers %+v", serverKeys) + } + if err = v.rdb.Set(ctx, v.redisKeyServers(), b, 0).Err(); err != nil { + return errors.Wrapf(err, "set key=%v servers %+v", v.redisKeyServers(), serverKeys) + } + + return nil +} + +func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) { + key := fmt.Sprintf("srs-proxy-url:%v", streamURL) + + // Always proxy to the same server for the same stream URL. + if serverKey, err := v.rdb.Get(ctx, key).Result(); err == nil { + // If server not exists, ignore and pick another server for the stream URL. + if b, err := v.rdb.Get(ctx, serverKey).Bytes(); err == nil && len(b) > 0 { + var server SRSServer + if err := json.Unmarshal(b, &server); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v server %v", key, string(b)) + } + + // TODO: If server fail, we should migrate the streams to another server. + return &server, nil + } + } + + // Query all servers from redis, in json string. + var serverKeys []string + if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil { + if err := json.Unmarshal(b, &serverKeys); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b)) + } + } + + // No server found, failed. + if len(serverKeys) == 0 { + return nil, fmt.Errorf("no server available for %v", streamURL) + } + + // All server should be alive, if not, should have been removed by redis. So we only + // random pick one that is always available. + var serverKey string + var server SRSServer + for i := 0; i < 3; i++ { + tryServerKey := serverKeys[rand.Intn(len(serverKeys))] + b, err := v.rdb.Get(ctx, tryServerKey).Bytes() + if err == nil && len(b) > 0 { + if err := json.Unmarshal(b, &server); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v server %v", serverKey, string(b)) + } + + serverKey = tryServerKey + break + } + } + if serverKey == "" { + return nil, errors.Errorf("no server available in %v for %v", serverKeys, streamURL) + } + + // Update the picked server for the stream URL. + if err := v.rdb.Set(ctx, key, []byte(serverKey), 0).Err(); err != nil { + return nil, errors.Wrapf(err, "set key=%v server %v", key, serverKey) + } + + return &server, nil +} + +func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) { + key := v.redisKeySPBHID(spbhid) + + b, err := v.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, errors.Wrapf(err, "get key=%v HLS", key) + } + + var actual HLSPlayStream + if err := json.Unmarshal(b, &actual); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b)) + } + + return &actual, nil +} + +func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) { + b, err := json.Marshal(value) + if err != nil { + return nil, errors.Wrapf(err, "marshal HLS %v", value) + } + + key := v.redisKeyHLS(streamURL) + if err = v.rdb.Set(ctx, key, b, srsHLSAliveDuration).Err(); err != nil { + return nil, errors.Wrapf(err, "set key=%v HLS %v", key, value) + } + + key2 := v.redisKeySPBHID(value.SRSProxyBackendHLSID) + if err := v.rdb.Set(ctx, key2, b, srsHLSAliveDuration).Err(); err != nil { + return nil, errors.Wrapf(err, "set key=%v HLS %v", key2, value) + } + + // Query the HLS streaming from redis. + b2, err := v.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, errors.Wrapf(err, "get key=%v HLS", key) + } + + var actual HLSPlayStream + if err := json.Unmarshal(b2, &actual); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b2)) + } + + return &actual, nil +} + +func (v *srsRedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error { + b, err := json.Marshal(value) + if err != nil { + return errors.Wrapf(err, "marshal WebRTC %v", value) + } + + key := v.redisKeyRTC(streamURL) + if err = v.rdb.Set(ctx, key, b, srsRTCAliveDuration).Err(); err != nil { + return errors.Wrapf(err, "set key=%v WebRTC %v", key, value) + } + + key2 := v.redisKeyUfrag(value.Ufrag) + if err := v.rdb.Set(ctx, key2, b, srsRTCAliveDuration).Err(); err != nil { + return errors.Wrapf(err, "set key=%v WebRTC %v", key2, value) + } + + return nil +} + +func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { + key := v.redisKeyUfrag(ufrag) + + b, err := v.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, errors.Wrapf(err, "get key=%v WebRTC", key) + } + + var actual RTCConnection + if err := json.Unmarshal(b, &actual); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v WebRTC %v", key, string(b)) + } + + return &actual, nil +} + +func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string { + return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag) +} + +func (v *srsRedisLoadBalancer) redisKeyRTC(streamURL string) string { + return fmt.Sprintf("srs-proxy-rtc:%v", streamURL) +} + +func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string { + return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid) +} + +func (v *srsRedisLoadBalancer) redisKeyHLS(streamURL string) string { + return fmt.Sprintf("srs-proxy-hls:%v", streamURL) +} + +func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string { + return fmt.Sprintf("srs-proxy-server:%v", serverID) +} + +func (v *srsRedisLoadBalancer) redisKeyServers() string { + return fmt.Sprintf("srs-proxy-all-servers") +} diff --git a/proxy/srt.go b/proxy/srt.go new file mode 100644 index 0000000000..e4c629af8d --- /dev/null +++ b/proxy/srt.go @@ -0,0 +1,574 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "net" + "strings" + stdSync "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/sync" +) + +// srsSRTServer is the proxy for SRS server via SRT. It will figure out which backend server to +// proxy to. It only parses the SRT handshake messages, parses the stream id, and proxy to the +// backend server. +type srsSRTServer struct { + // The UDP listener for SRT server. + listener *net.UDPConn + + // The SRT connections, identify by the socket ID. + sockets sync.Map[uint32, *SRTConnection] + // The system start time. + start time.Time + + // The wait group for server. + wg stdSync.WaitGroup +} + +func NewSRSSRTServer(opts ...func(*srsSRTServer)) *srsSRTServer { + v := &srsSRTServer{ + start: time.Now(), + } + + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsSRTServer) Close() error { + if v.listener != nil { + v.listener.Close() + } + + v.wg.Wait() + return nil +} + +func (v *srsSRTServer) Run(ctx context.Context) error { + // Parse address to listen. + endpoint := envSRTServer() + if !strings.Contains(endpoint, ":") { + endpoint = ":" + endpoint + } + + saddr, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + return errors.Wrapf(err, "resolve udp addr %v", endpoint) + } + + listener, err := net.ListenUDP("udp", saddr) + if err != nil { + return errors.Wrapf(err, "listen udp %v", saddr) + } + v.listener = listener + logger.Df(ctx, "SRT server listen at %v", saddr) + + // Consume all messages from UDP media transport. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, caddr, err := v.listener.ReadFromUDP(buf) + if err != nil { + // TODO: If SRT server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "read from udp failed, err=%+v", err) + continue + } + + if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil { + logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err) + } + } + }() + + return nil +} + +func (v *srsSRTServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { + socketID := srtParseSocketID(data) + + var pkt *SRTHandshakePacket + if srtIsHandshake(data) { + pkt = &SRTHandshakePacket{} + if err := pkt.UnmarshalBinary(data); err != nil { + return err + } + + if socketID == 0 { + socketID = pkt.SRTSocketID + } + } + + conn, ok := v.sockets.LoadOrStore(socketID, NewSRTConnection(func(c *SRTConnection) { + c.ctx = logger.WithContext(ctx) + c.listenerUDP, c.socketID = v.listener, socketID + c.start = v.start + })) + + ctx = conn.ctx + if !ok { + logger.Df(ctx, "Create new SRT connection skt=%v", socketID) + } + + if newSocketID, err := conn.HandlePacket(pkt, addr, data); err != nil { + return errors.Wrapf(err, "handle packet") + } else if newSocketID != 0 && newSocketID != socketID { + // The connection may use a new socket ID. + // TODO: FIXME: Should cleanup the dead SRT connection. + v.sockets.Store(newSocketID, conn) + } + + return nil +} + +// SRTConnection is an SRT connection proxy, for both caller and listener. It represents an SRT +// connection, identify by the socket ID. +// +// It's similar to RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is in +// the client request. The SRTConnection is stateless, and no need to sync between proxy servers. +// +// Unlike the WebRTC connection, SRTConnection does not support address changes. This means the +// client should never switch to another network or port. If this occurs, the client may be served +// by a different proxy server and fail because the other proxy server cannot identify the client. +type SRTConnection struct { + // The stream context for SRT connection. + ctx context.Context + + // The current socket ID. + socketID uint32 + + // The UDP connection proxy to backend. + backendUDP *net.UDPConn + // The listener UDP connection, used to send messages to client. + listenerUDP *net.UDPConn + + // Listener start time. + start time.Time + + // Handshake packets with client. + handshake0 *SRTHandshakePacket + handshake1 *SRTHandshakePacket + handshake2 *SRTHandshakePacket + handshake3 *SRTHandshakePacket +} + +func NewSRTConnection(opts ...func(*SRTConnection)) *SRTConnection { + v := &SRTConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *SRTConnection) HandlePacket(pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) (uint32, error) { + ctx := v.ctx + + // If not handshake, try to proxy to backend directly. + if pkt == nil { + // Proxy client message to backend. + if v.backendUDP != nil { + if _, err := v.backendUDP.Write(data); err != nil { + return v.socketID, errors.Wrapf(err, "write to backend") + } + } + + return v.socketID, nil + } + + // Handle handshake messages. + if err := v.handleHandshake(ctx, pkt, addr, data); err != nil { + return v.socketID, errors.Wrapf(err, "handle handshake %v", pkt) + } + + return v.socketID, nil +} + +func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) error { + // Handle handshake 0 and 1 messages. + if pkt.SynCookie == 0 { + // Save handshake 0 packet. + v.handshake0 = pkt + logger.Df(ctx, "SRT Handshake 0: %v", v.handshake0) + + // Response handshake 1. + v.handshake1 = &SRTHandshakePacket{ + ControlFlag: pkt.ControlFlag, + ControlType: 0, + SubType: 0, + AdditionalInfo: 0, + Timestamp: uint32(time.Since(v.start).Microseconds()), + SocketID: pkt.SRTSocketID, + Version: 5, + EncryptionField: 0, + ExtensionField: 0x4A17, + InitSequence: pkt.InitSequence, + MTU: pkt.MTU, + FlowWindow: pkt.FlowWindow, + HandshakeType: 1, + SRTSocketID: pkt.SRTSocketID, + SynCookie: 0x418d5e4e, + PeerIP: net.ParseIP("127.0.0.1"), + } + logger.Df(ctx, "SRT Handshake 1: %v", v.handshake1) + + if b, err := v.handshake1.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 1") + } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { + return errors.Wrapf(err, "write handshake 1") + } + + return nil + } + + // Handle handshake 2 and 3 messages. + // Parse stream id from packet. + streamID, err := pkt.StreamID() + if err != nil { + return errors.Wrapf(err, "parse stream id") + } + + // Save handshake packet. + v.handshake2 = pkt + logger.Df(ctx, "SRT Handshake 2: %v, sid=%v", v.handshake2, streamID) + + // Start the UDP proxy to backend. + if err := v.connectBackend(ctx, streamID); err != nil { + return errors.Wrapf(err, "connect backend for %v", streamID) + } + + // Proxy client message to backend. + if v.backendUDP == nil { + return errors.Errorf("no backend for %v", streamID) + } + + // Proxy handshake 0 to backend server. + if b, err := v.handshake0.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 0") + } else if _, err = v.backendUDP.Write(b); err != nil { + return errors.Wrapf(err, "write handshake 0") + } + logger.Df(ctx, "Proxy send handshake 0: %v", v.handshake0) + + // Read handshake 1 from backend server. + b := make([]byte, 4096) + handshake1p := &SRTHandshakePacket{} + if nn, err := v.backendUDP.Read(b); err != nil { + return errors.Wrapf(err, "read handshake 1") + } else if err := handshake1p.UnmarshalBinary(b[:nn]); err != nil { + return errors.Wrapf(err, "unmarshal handshake 1") + } + logger.Df(ctx, "Proxy got handshake 1: %v", handshake1p) + + // Proxy handshake 2 to backend server. + handshake2p := *v.handshake2 + handshake2p.SynCookie = handshake1p.SynCookie + if b, err := handshake2p.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 2") + } else if _, err = v.backendUDP.Write(b); err != nil { + return errors.Wrapf(err, "write handshake 2") + } + logger.Df(ctx, "Proxy send handshake 2: %v", handshake2p) + + // Read handshake 3 from backend server. + handshake3p := &SRTHandshakePacket{} + if nn, err := v.backendUDP.Read(b); err != nil { + return errors.Wrapf(err, "read handshake 3") + } else if err := handshake3p.UnmarshalBinary(b[:nn]); err != nil { + return errors.Wrapf(err, "unmarshal handshake 3") + } + logger.Df(ctx, "Proxy got handshake 3: %v", handshake3p) + + // Response handshake 3 to client. + v.handshake3 = &*handshake3p + v.handshake3.SynCookie = v.handshake1.SynCookie + v.socketID = handshake3p.SRTSocketID + logger.Df(ctx, "Handshake 3: %v", v.handshake3) + + if b, err := v.handshake3.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 3") + } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { + return errors.Wrapf(err, "write handshake 3") + } + + // Start a goroutine to proxy message from backend to client. + // TODO: FIXME: Support close the connection when timeout or client disconnected. + go func() { + for ctx.Err() == nil { + nn, err := v.backendUDP.Read(b) + if err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "read from backend failed, err=%v", err) + return + } + if _, err = v.listenerUDP.WriteToUDP(b[:nn], addr); err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "write to client failed, err=%v", err) + return + } + } + }() + return nil +} + +func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) error { + if v.backendUDP != nil { + return nil + } + + // Parse stream id to host and resource. + host, resource, err := parseSRTStreamID(streamID) + if err != nil { + return errors.Wrapf(err, "parse stream id %v", streamID) + } + + if host == "" { + host = "localhost" + } + + streamURL, err := buildStreamURL(fmt.Sprintf("srt://%v/%v", host, resource)) + if err != nil { + return errors.Wrapf(err, "build stream url %v", streamID) + } + + // Pick a backend SRS server to proxy the SRT stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + // Parse UDP port from backend. + if len(backend.SRT) == 0 { + return errors.Errorf("no udp server %v for %v", backend, streamURL) + } + + _, _, udpPort, err := parseListenEndpoint(backend.SRT[0]) + if err != nil { + return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.SRT[0], backend, streamURL) + } + + // Connect to backend SRS server via UDP client. + // TODO: FIXME: Support close the connection when timeout or client disconnected. + backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)} + if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { + return errors.Wrapf(err, "dial udp to %v of %v for %v", backendAddr, backend, streamURL) + } else { + v.backendUDP = backendUDP + } + + return nil +} + +// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2 +// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2.1 +type SRTHandshakePacket struct { + // F: 1 bit. Packet Type Flag. The control packet has this flag set to + // "1". The data packet has this flag set to "0". + ControlFlag uint8 + // Control Type: 15 bits. Control Packet Type. The use of these bits + // is determined by the control packet type definition. + // Handshake control packets (Control Type = 0x0000) are used to + // exchange peer configurations, to agree on connection parameters, and + // to establish a connection. + ControlType uint16 + // Subtype: 16 bits. This field specifies an additional subtype for + // specific packets. + SubType uint16 + // Type-specific Information: 32 bits. The use of this field depends on + // the particular control packet type. Handshake packets do not use + // this field. + AdditionalInfo uint32 + // Timestamp: 32 bits. + Timestamp uint32 + // Destination Socket ID: 32 bits. + SocketID uint32 + + // Version: 32 bits. A base protocol version number. Currently used + // values are 4 and 5. Values greater than 5 are reserved for future + // use. + Version uint32 + // Encryption Field: 16 bits. Block cipher family and key size. The + // values of this field are described in Table 2. The default value + // is AES-128. + // 0 | No Encryption Advertised + // 2 | AES-128 + // 3 | AES-192 + // 4 | AES-256 + EncryptionField uint16 + // Extension Field: 16 bits. This field is message specific extension + // related to Handshake Type field. The value MUST be set to 0 + // except for the following cases. (1) If the handshake control + // packet is the INDUCTION message, this field is sent back by the + // Listener. (2) In the case of a CONCLUSION message, this field + // value should contain a combination of Extension Type values. + // 0x00000001 | HSREQ + // 0x00000002 | KMREQ + // 0x00000004 | CONFIG + // 0x4A17 if HandshakeType is INDUCTION, see https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-4.3.1.1 + ExtensionField uint16 + // Initial Packet Sequence Number: 32 bits. The sequence number of the + // very first data packet to be sent. + InitSequence uint32 + // Maximum Transmission Unit Size: 32 bits. This value is typically set + // to 1500, which is the default Maximum Transmission Unit (MTU) size + // for Ethernet, but can be less. + MTU uint32 + // Maximum Flow Window Size: 32 bits. The value of this field is the + // maximum number of data packets allowed to be "in flight" (i.e. the + // number of sent packets for which an ACK control packet has not yet + // been received). + FlowWindow uint32 + // Handshake Type: 32 bits. This field indicates the handshake packet + // type. + // 0xFFFFFFFD | DONE + // 0xFFFFFFFE | AGREEMENT + // 0xFFFFFFFF | CONCLUSION + // 0x00000000 | WAVEHAND + // 0x00000001 | INDUCTION + HandshakeType uint32 + // SRT Socket ID: 32 bits. This field holds the ID of the source SRT + // socket from which a handshake packet is issued. + SRTSocketID uint32 + // SYN Cookie: 32 bits. Randomized value for processing a handshake. + // The value of this field is specified by the handshake message + // type. + SynCookie uint32 + // Peer IP Address: 128 bits. IPv4 or IPv6 address of the packet's + // sender. The value consists of four 32-bit fields. + PeerIP net.IP + // Extensions. + // Extension Type: 16 bits. The value of this field is used to process + // an integrated handshake. Each extension can have a pair of + // request and response types. + // Extension Length: 16 bits. The length of the Extension Contents + // field in four-byte blocks. + // Extension Contents: variable length. The payload of the extension. + ExtraData []byte +} + +func (v *SRTHandshakePacket) IsData() bool { + return v.ControlFlag == 0x00 +} + +func (v *SRTHandshakePacket) IsControl() bool { + return v.ControlFlag == 0x80 +} + +func (v *SRTHandshakePacket) IsHandshake() bool { + return v.IsControl() && v.ControlType == 0x00 && v.SubType == 0x00 +} + +func (v *SRTHandshakePacket) StreamID() (string, error) { + p := v.ExtraData + for { + if len(p) < 2 { + return "", errors.Errorf("Require 2 bytes, actual=%v, extra=%v", len(p), len(v.ExtraData)) + } + + extType := binary.BigEndian.Uint16(p) + extSize := binary.BigEndian.Uint16(p[2:]) + p = p[4:] + + if len(p) < int(extSize*4) { + return "", errors.Errorf("Require %v bytes, actual=%v, extra=%v", extSize*4, len(p), len(v.ExtraData)) + } + + // Ignore other packets except stream id. + if extType != 0x05 { + p = p[extSize*4:] + continue + } + + // We must copy it, because we will decode the stream id. + data := append([]byte{}, p[:extSize*4]...) + + // Reverse the stream id encoded in little-endian to big-endian. + for i := 0; i < len(data); i += 4 { + value := binary.LittleEndian.Uint32(data[i:]) + binary.BigEndian.PutUint32(data[i:], value) + } + + // Trim the trailing zero bytes. + data = bytes.TrimRight(data, "\x00") + return string(data), nil + } +} + +func (v *SRTHandshakePacket) String() string { + return fmt.Sprintf("Control=%v, CType=%v, SType=%v, Timestamp=%v, SocketID=%v, Version=%v, Encrypt=%v, Extension=%v, InitSequence=%v, MTU=%v, FlowWnd=%v, HSType=%v, SRTSocketID=%v, Cookie=%v, Peer=%vB, Extra=%vB", + v.IsControl(), v.ControlType, v.SubType, v.Timestamp, v.SocketID, v.Version, v.EncryptionField, v.ExtensionField, v.InitSequence, v.MTU, v.FlowWindow, v.HandshakeType, v.SRTSocketID, v.SynCookie, len(v.PeerIP), len(v.ExtraData)) +} + +func (v *SRTHandshakePacket) UnmarshalBinary(b []byte) error { + if len(b) < 4 { + return errors.Errorf("Invalid packet length %v", len(b)) + } + v.ControlFlag = b[0] & 0x80 + v.ControlType = binary.BigEndian.Uint16(b[0:2]) & 0x7fff + v.SubType = binary.BigEndian.Uint16(b[2:4]) + + if len(b) < 64 { + return errors.Errorf("Invalid packet length %v", len(b)) + } + v.AdditionalInfo = binary.BigEndian.Uint32(b[4:]) + v.Timestamp = binary.BigEndian.Uint32(b[8:]) + v.SocketID = binary.BigEndian.Uint32(b[12:]) + v.Version = binary.BigEndian.Uint32(b[16:]) + v.EncryptionField = binary.BigEndian.Uint16(b[20:]) + v.ExtensionField = binary.BigEndian.Uint16(b[22:]) + v.InitSequence = binary.BigEndian.Uint32(b[24:]) + v.MTU = binary.BigEndian.Uint32(b[28:]) + v.FlowWindow = binary.BigEndian.Uint32(b[32:]) + v.HandshakeType = binary.BigEndian.Uint32(b[36:]) + v.SRTSocketID = binary.BigEndian.Uint32(b[40:]) + v.SynCookie = binary.BigEndian.Uint32(b[44:]) + + // Only support IPv4. + v.PeerIP = net.IPv4(b[51], b[50], b[49], b[48]) + + v.ExtraData = b[64:] + + return nil +} + +func (v *SRTHandshakePacket) MarshalBinary() ([]byte, error) { + b := make([]byte, 64+len(v.ExtraData)) + binary.BigEndian.PutUint16(b, uint16(v.ControlFlag)<<8|v.ControlType) + binary.BigEndian.PutUint16(b[2:], v.SubType) + binary.BigEndian.PutUint32(b[4:], v.AdditionalInfo) + binary.BigEndian.PutUint32(b[8:], v.Timestamp) + binary.BigEndian.PutUint32(b[12:], v.SocketID) + binary.BigEndian.PutUint32(b[16:], v.Version) + binary.BigEndian.PutUint16(b[20:], v.EncryptionField) + binary.BigEndian.PutUint16(b[22:], v.ExtensionField) + binary.BigEndian.PutUint32(b[24:], v.InitSequence) + binary.BigEndian.PutUint32(b[28:], v.MTU) + binary.BigEndian.PutUint32(b[32:], v.FlowWindow) + binary.BigEndian.PutUint32(b[36:], v.HandshakeType) + binary.BigEndian.PutUint32(b[40:], v.SRTSocketID) + binary.BigEndian.PutUint32(b[44:], v.SynCookie) + + // Only support IPv4. + ip := v.PeerIP.To4() + b[48] = ip[3] + b[49] = ip[2] + b[50] = ip[1] + b[51] = ip[0] + + if len(v.ExtraData) > 0 { + copy(b[64:], v.ExtraData) + } + + return b, nil +} diff --git a/proxy/sync/map.go b/proxy/sync/map.go new file mode 100644 index 0000000000..75db12f9a9 --- /dev/null +++ b/proxy/sync/map.go @@ -0,0 +1,45 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package sync + +import "sync" + +type Map[K comparable, V any] struct { + m sync.Map +} + +func (m *Map[K, V]) Delete(key K) { + m.m.Delete(key) +} + +func (m *Map[K, V]) Load(key K) (value V, ok bool) { + v, ok := m.m.Load(key) + if !ok { + return value, ok + } + return v.(V), ok +} + +func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { + v, loaded := m.m.LoadAndDelete(key) + if !loaded { + return value, loaded + } + return v.(V), loaded +} + +func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + a, loaded := m.m.LoadOrStore(key, value) + return a.(V), loaded +} + +func (m *Map[K, V]) Range(f func(key K, value V) bool) { + m.m.Range(func(key, value any) bool { + return f(key.(K), value.(V)) + }) +} + +func (m *Map[K, V]) Store(key K, value V) { + m.m.Store(key, value) +} diff --git a/proxy/utils.go b/proxy/utils.go new file mode 100644 index 0000000000..f3c3930762 --- /dev/null +++ b/proxy/utils.go @@ -0,0 +1,276 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "encoding/binary" + "encoding/json" + stdErr "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "os" + "path" + "reflect" + "regexp" + "strconv" + "strings" + "syscall" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) { + w.Header().Set("Server", fmt.Sprintf("%v/%v", Signature(), Version())) + + b, err := json.Marshal(data) + if err != nil { + apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data)) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(b) +} + +func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) { + logger.Wf(ctx, "HTTP API error %+v", err) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, fmt.Sprintf("%v", err)) +} + +func apiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool { + // Always support CORS. Note that browser may send origin header for m3u8, but no origin header + // for ts. So we always response CORS header. + if true { + // SRS does not need cookie or credentials, so we disable CORS credentials, and use * for CORS origin, + // headers, expose headers and methods. + w.Header().Set("Access-Control-Allow-Origin", "*") + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + w.Header().Set("Access-Control-Allow-Headers", "*") + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + w.Header().Set("Access-Control-Allow-Methods", "*") + } + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return true + } + + return false +} + +func parseGracefullyQuitTimeout() (time.Duration, error) { + if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil { + return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout()) + } else { + return t, nil + } +} + +// ParseBody read the body from r, and unmarshal JSON to v. +func ParseBody(r io.ReadCloser, v interface{}) error { + b, err := ioutil.ReadAll(r) + if err != nil { + return errors.Wrapf(err, "read body") + } + defer r.Close() + + if len(b) == 0 { + return nil + } + + if err := json.Unmarshal(b, v); err != nil { + return errors.Wrapf(err, "json unmarshal %v", string(b)) + } + + return nil +} + +// buildStreamURL build as vhost/app/stream for stream URL r. +func buildStreamURL(r string) (string, error) { + u, err := url.Parse(r) + if err != nil { + return "", errors.Wrapf(err, "parse url %v", r) + } + + // If not domain or ip in hostname, it's __defaultVhost__. + defaultVhost := !strings.Contains(u.Hostname(), ".") + + // If hostname is actually an IP address, it's __defaultVhost__. + if ip := net.ParseIP(u.Hostname()); ip.To4() != nil { + defaultVhost = true + } + + if defaultVhost { + return fmt.Sprintf("__defaultVhost__%v", u.Path), nil + } + + // Ignore port, only use hostname as vhost. + return fmt.Sprintf("%v%v", u.Hostname(), u.Path), nil +} + +// isPeerClosedError indicates whether peer object closed the connection. +func isPeerClosedError(err error) bool { + causeErr := errors.Cause(err) + + if stdErr.Is(causeErr, io.EOF) { + return true + } + + if stdErr.Is(causeErr, syscall.EPIPE) { + return true + } + + if netErr, ok := causeErr.(*net.OpError); ok { + if sysErr, ok := netErr.Err.(*os.SyscallError); ok { + if stdErr.Is(sysErr.Err, syscall.ECONNRESET) { + return true + } + } + } + + return false +} + +// convertURLToStreamURL convert the URL in HTTP request to special URLs. The unifiedURL is the URL +// in unified, foramt as scheme://vhost/app/stream without extensions. While the fullURL is the unifiedURL +// with extension. +func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + hostname := "__defaultVhost__" + if strings.Contains(r.Host, ":") { + if v, _, err := net.SplitHostPort(r.Host); err == nil { + hostname = v + } + } + + var appStream, streamExt string + + // Parse app/stream from query string. + q := r.URL.Query() + if app := q.Get("app"); app != "" { + appStream = "/" + app + } + if stream := q.Get("stream"); stream != "" { + appStream = fmt.Sprintf("%v/%v", appStream, stream) + } + + // Parse app/stream from path. + if appStream == "" { + streamExt = path.Ext(r.URL.Path) + appStream = strings.TrimSuffix(r.URL.Path, streamExt) + } + + unifiedURL = fmt.Sprintf("%v://%v%v", scheme, hostname, appStream) + fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt) + return +} + +// rtcIsSTUN returns true if data of UDP payload is a STUN packet. +func rtcIsSTUN(data []byte) bool { + return len(data) > 0 && (data[0] == 0 || data[0] == 1) +} + +// rtcIsRTPOrRTCP returns true if data of UDP payload is a RTP or RTCP packet. +func rtcIsRTPOrRTCP(data []byte) bool { + return len(data) >= 12 && (data[0]&0xC0) == 0x80 +} + +// srtIsHandshake returns true if data of UDP payload is a SRT handshake packet. +func srtIsHandshake(data []byte) bool { + return len(data) >= 4 && binary.BigEndian.Uint32(data) == 0x80000000 +} + +// srtParseSocketID parse the socket id from the SRT packet. +func srtParseSocketID(data []byte) uint32 { + if len(data) >= 16 { + return binary.BigEndian.Uint32(data[12:]) + } + return 0 +} + +// parseIceUfragPwd parse the ice-ufrag and ice-pwd from the SDP. +func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) { + if true { + ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`) + ufragMatch := ufragRe.FindStringSubmatch(sdp) + if len(ufragMatch) <= 1 { + return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp) + } + ufrag = ufragMatch[1] + } + + if true { + pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`) + pwdMatch := pwdRe.FindStringSubmatch(sdp) + if len(pwdMatch) <= 1 { + return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp) + } + pwd = pwdMatch[1] + } + + return ufrag, pwd, nil +} + +// parseSRTStreamID parse the SRT stream id to host(optional) and resource(required). +// See https://ossrs.io/lts/en-us/docs/v7/doc/srt#srt-url +func parseSRTStreamID(sid string) (host, resource string, err error) { + if true { + hostRe := regexp.MustCompile(`h=([^,]+)`) + hostMatch := hostRe.FindStringSubmatch(sid) + if len(hostMatch) > 1 { + host = hostMatch[1] + } + } + + if true { + resourceRe := regexp.MustCompile(`r=([^,]+)`) + resourceMatch := resourceRe.FindStringSubmatch(sid) + if len(resourceMatch) <= 1 { + return "", "", errors.Errorf("no resource in sid %v", sid) + } + resource = resourceMatch[1] + } + + return host, resource, nil +} + +// parseListenEndpoint parse the listen endpoint as: +// port The tcp listen port, like 1935. +// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935 +func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) { + // If no colon in ep, it's port in string. + if !strings.Contains(ep, ":") { + if p, err := strconv.Atoi(ep); err != nil { + return "", nil, 0, errors.Wrapf(err, "parse port %v", ep) + } else { + return "tcp", nil, uint16(p), nil + } + } + + // Must be protocol://ip:port schema. + parts := strings.Split(ep, ":") + if len(parts) != 3 { + return "", nil, 0, errors.Errorf("invalid endpoint %v", ep) + } + + if p, err := strconv.Atoi(parts[2]); err != nil { + return "", nil, 0, errors.Wrapf(err, "parse port %v", parts[2]) + } else { + return parts[0], net.ParseIP(parts[1]), uint16(p), nil + } +} diff --git a/proxy/version.go b/proxy/version.go new file mode 100644 index 0000000000..94f668f96e --- /dev/null +++ b/proxy/version.go @@ -0,0 +1,27 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import "fmt" + +func VersionMajor() int { + return 1 +} + +// VersionMinor specifies the typical version of SRS we adapt to. +func VersionMinor() int { + return 5 +} + +func VersionRevision() int { + return 0 +} + +func Version() string { + return fmt.Sprintf("%v.%v.%v", VersionMajor(), VersionMinor(), VersionRevision()) +} + +func Signature() string { + return "SRSProxy" +} diff --git a/trunk/conf/origin1-for-proxy.conf b/trunk/conf/origin1-for-proxy.conf new file mode 100644 index 0000000000..baca5c9f40 --- /dev/null +++ b/trunk/conf/origin1-for-proxy.conf @@ -0,0 +1,57 @@ + +listen 19351; +max_connections 1000; +pid objs/origin1.pid; +daemon off; +srs_log_tank console; +http_server { + enabled on; + listen 8081; + dir ./objs/nginx/html; +} +http_api { + enabled on; + listen 19851; +} +rtc_server { + enabled on; + listen 8001; # UDP port + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate + candidate $CANDIDATE; +} +srt_server { + enabled on; + listen 10081; + tsbpdmode off; + tlpktdrop off; +} +heartbeat { + enabled on; + interval 9; + url http://127.0.0.1:12025/api/v1/srs/register; + device_id origin1; + ports on; +} +vhost __defaultVhost__ { + http_remux { + enabled on; + mount [vhost]/[app]/[stream].flv; + } + hls { + enabled on; + hls_path ./objs/nginx/html; + hls_fragment 10; + hls_window 60; + } + rtc { + enabled on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc + rtmp_to_rtc on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp + rtc_to_rtmp on; + } + srt { + enabled on; + srt_to_rtmp on; + } +} diff --git a/trunk/conf/origin2-for-proxy.conf b/trunk/conf/origin2-for-proxy.conf new file mode 100644 index 0000000000..48f6398930 --- /dev/null +++ b/trunk/conf/origin2-for-proxy.conf @@ -0,0 +1,57 @@ + +listen 19352; +max_connections 1000; +pid objs/origin2.pid; +daemon off; +srs_log_tank console; +http_server { + enabled on; + listen 8082; + dir ./objs/nginx/html; +} +http_api { + enabled on; + listen 19853; +} +rtc_server { + enabled on; + listen 8001; # UDP port + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate + candidate $CANDIDATE; +} +srt_server { + enabled on; + listen 10082; + tsbpdmode off; + tlpktdrop off; +} +heartbeat { + enabled on; + interval 9; + url http://127.0.0.1:12025/api/v1/srs/register; + device_id origin2; + ports on; +} +vhost __defaultVhost__ { + http_remux { + enabled on; + mount [vhost]/[app]/[stream].flv; + } + hls { + enabled on; + hls_path ./objs/nginx/html; + hls_fragment 10; + hls_window 60; + } + rtc { + enabled on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc + rtmp_to_rtc on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp + rtc_to_rtmp on; + } + srt { + enabled on; + srt_to_rtmp on; + } +} diff --git a/trunk/conf/origin3-for-proxy.conf b/trunk/conf/origin3-for-proxy.conf new file mode 100644 index 0000000000..95624fb773 --- /dev/null +++ b/trunk/conf/origin3-for-proxy.conf @@ -0,0 +1,57 @@ + +listen 19353; +max_connections 1000; +pid objs/origin3.pid; +daemon off; +srs_log_tank console; +http_server { + enabled on; + listen 8083; + dir ./objs/nginx/html; +} +http_api { + enabled on; + listen 19852; +} +rtc_server { + enabled on; + listen 8001; # UDP port + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate + candidate $CANDIDATE; +} +srt_server { + enabled on; + listen 10083; + tsbpdmode off; + tlpktdrop off; +} +heartbeat { + enabled on; + interval 9; + url http://127.0.0.1:12025/api/v1/srs/register; + device_id origin3; + ports on; +} +vhost __defaultVhost__ { + http_remux { + enabled on; + mount [vhost]/[app]/[stream].flv; + } + hls { + enabled on; + hls_path ./objs/nginx/html; + hls_fragment 10; + hls_window 60; + } + rtc { + enabled on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc + rtmp_to_rtc on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp + rtc_to_rtmp on; + } + srt { + enabled on; + srt_to_rtmp on; + } +} diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index 2772c0bf21..9e676930f2 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 7.0 Changelog +* v7.0, 2024-09-09, Merge [#4158](https://github.com/ossrs/srs/pull/4158): Proxy: Support proxy server for SRS. v7.0.16 (#4158) * v7.0, 2024-09-09, Merge [#4171](https://github.com/ossrs/srs/pull/4171): Heartbeat: Report ports for proxy server. v7.0.15 (#4171) * v7.0, 2024-09-01, Merge [#4165](https://github.com/ossrs/srs/pull/4165): FLV: Refine source and http handler. v7.0.14 (#4165) * v7.0, 2024-09-01, Merge [#4166](https://github.com/ossrs/srs/pull/4166): Edge: Fix flv edge crash when http unmount. v7.0.13 (#4166) diff --git a/trunk/src/app/srs_app_st.cpp b/trunk/src/app/srs_app_st.cpp index 3e21e468cd..466cbe068f 100755 --- a/trunk/src/app/srs_app_st.cpp +++ b/trunk/src/app/srs_app_st.cpp @@ -342,7 +342,12 @@ SrsWaitGroup::SrsWaitGroup() SrsWaitGroup::~SrsWaitGroup() { - wait(); + // In the destructor, we should NOT wait for all coroutines to be done, because user should decide + // to wait or not. Similar to the Go's sync.WaitGroup, it also requires user to wait explicitly. For + // some special use scenarios, such as error handling, for example, if we started three servers with + // wait group, and one of them failed, user may want to return error and quit directly, without wait + // for other running servers to be done. If we wait in the destructor, it will continue to run without + // some servers, in unknown behaviors. srs_cond_destroy(done_); } diff --git a/trunk/src/core/srs_core_version7.hpp b/trunk/src/core/srs_core_version7.hpp index fed95c499b..458a6c3d84 100644 --- a/trunk/src/core/srs_core_version7.hpp +++ b/trunk/src/core/srs_core_version7.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 7 #define VERSION_MINOR 0 -#define VERSION_REVISION 15 +#define VERSION_REVISION 16 #endif \ No newline at end of file