Skip to content

Commit

Permalink
Merge pull request #7023 from dolthub/macneale4/fix_dolt_fetch_pull
Browse files Browse the repository at this point in the history
fix dolt fetch and dolt pull commands to properly authenticate
  • Loading branch information
macneale4 authored Nov 21, 2023
2 parents d09c29c + aacfd0a commit dbb5ce3
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 83 deletions.
56 changes: 47 additions & 9 deletions go/cmd/dolt/commands/sqlserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import (

const (
LocalConnectionUser = "__dolt_local_user__"
ApiSqleContextKey = "__sqle_context__"
)

// ExternalDisableUsers is called by implementing applications to disable users. This is not used by Dolt itself,
Expand Down Expand Up @@ -384,7 +385,7 @@ func Serve(
}

ctxFactory := func() (*sql.Context, error) { return sqlEngine.NewDefaultContext(ctx) }
authenticator := newAuthenticator(ctxFactory, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb)
authenticator := newAccessController(ctxFactory, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb)
args = sqle.WithUserPasswordAuth(args, authenticator)
args.TLSConfig = serverConf.TLSConfig

Expand Down Expand Up @@ -587,29 +588,66 @@ func acquireGlobalSqlServerLock(port int, dEnv *env.DoltEnv) (*env.DBLock, error
return &lck, nil
}

// remotesapiAuth facilitates the implementation remotesrv.AccessControl for the remotesapi server.
type remotesapiAuth struct {
// ctxFactory is a function that returns a new sql.Context. This will create a new conext every time it is called,
// so it should be called once per API request.
ctxFactory func() (*sql.Context, error)
rawDb *mysql_db.MySQLDb
}

func newAuthenticator(ctxFactory func() (*sql.Context, error), rawDb *mysql_db.MySQLDb) remotesrv.Authenticator {
func newAccessController(ctxFactory func() (*sql.Context, error), rawDb *mysql_db.MySQLDb) remotesrv.AccessControl {
return &remotesapiAuth{ctxFactory, rawDb}
}

func (r *remotesapiAuth) Authenticate(creds *remotesrv.RequestCredentials) bool {
err := commands.ValidatePasswordWithAuthResponse(r.rawDb, creds.Username, creds.Password)
// ApiAuthenticate checks the provided credentials against the database and return a SQL context if the credentials are
// valid. If the credentials are invalid, then a nil context is returned. Failures to authenticate are logged.
func (r *remotesapiAuth) ApiAuthenticate(ctx context.Context) (context.Context, error) {
creds, err := remotesrv.ExtractBasicAuthCreds(ctx)
if err != nil {
return false
return nil, err
}

ctx, err := r.ctxFactory()
err = commands.ValidatePasswordWithAuthResponse(r.rawDb, creds.Username, creds.Password)
if err != nil {
return false
return nil, fmt.Errorf("API Authentication Failure: %v", err)
}

address := creds.Address
if strings.Index(address, ":") > 0 {
address, _, err = net.SplitHostPort(creds.Address)
if err != nil {
return nil, fmt.Errorf("Invlaid Host string for authentication: %s", creds.Address)
}
}

sqlCtx, err := r.ctxFactory()
if err != nil {
return nil, fmt.Errorf("API Runtime error: %v", err)
}

sqlCtx.Session.SetClient(sql.Client{User: creds.Username, Address: address, Capabilities: 0})

updatedCtx := context.WithValue(ctx, ApiSqleContextKey, sqlCtx)

return updatedCtx, nil
}

func (r *remotesapiAuth) ApiAuthorize(ctx context.Context) (bool, error) {
sqlCtx, ok := ctx.Value(ApiSqleContextKey).(*sql.Context)
if !ok {
return false, fmt.Errorf("Runtime error: could not get SQL context from context")
}
ctx.Session.SetClient(sql.Client{User: creds.Username, Address: creds.Address, Capabilities: 0})

privOp := sql.NewDynamicPrivilegedOperation(plan.DynamicPrivilege_CloneAdmin)
return r.rawDb.UserHasPrivileges(ctx, privOp)

authorized := r.rawDb.UserHasPrivileges(sqlCtx, privOp)

if !authorized {
return false, fmt.Errorf("API Authorization Failure: %s has not been granted CLONE_ADMIN access", sqlCtx.Session.Client().User)

}
return true, nil
}

func LoadClusterTLSConfig(cfg cluster.Config) (*tls.Config, error) {
Expand Down
12 changes: 9 additions & 3 deletions go/libraries/doltcore/dbfactory/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
)

var GRPCDialProviderParam = "__DOLT__grpc_dial_provider"
var GRPCUsernameAuthParam = "__DOLT__grpc_username"

type GRPCRemoteConfig struct {
Endpoint string
Expand Down Expand Up @@ -100,10 +101,15 @@ func (fact DoltRemoteFactory) CreateDB(ctx context.Context, nbf *types.NomsBinFo
var NoCachingParameter = "__dolt__NO_CACHING"

func (fact DoltRemoteFactory) newChunkStore(ctx context.Context, nbf *types.NomsBinFormat, urlObj *url.URL, params map[string]interface{}, dp GRPCDialProvider) (chunks.ChunkStore, error) {
var user string
if userParam := params[GRPCUsernameAuthParam]; userParam != nil {
user = userParam.(string)
}
cfg, err := dp.GetGRPCDialParams(grpcendpoint.Config{
Endpoint: urlObj.Host,
Insecure: fact.insecure,
WithEnvCreds: true,
Endpoint: urlObj.Host,
Insecure: fact.insecure,
UserIdForOsEnvAuth: user,
WithEnvCreds: true,
})
if err != nil {
return nil, err
Expand Down
37 changes: 34 additions & 3 deletions go/libraries/doltcore/env/grpc_dial_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,20 @@ package env

import (
"crypto/tls"
"errors"
"net"
"net/http"
"os"
"runtime"
"strings"
"unicode"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials"

"github.com/dolthub/dolt/go/libraries/doltcore/creds"
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/doltcore/dconfig"
"github.com/dolthub/dolt/go/libraries/doltcore/grpcendpoint"
)

Expand Down Expand Up @@ -88,9 +92,18 @@ func (p GRPCDialProvider) GetGRPCDialParams(config grpcendpoint.Config) (dbfacto
if config.Creds != nil {
opts = append(opts, grpc.WithPerRPCCredentials(config.Creds))
} else if config.WithEnvCreds {
rpcCreds, err := p.getRPCCreds(endpoint)
if err != nil {
return dbfactory.GRPCRemoteConfig{}, err
var rpcCreds credentials.PerRPCCredentials
var err error
if config.UserIdForOsEnvAuth != "" {
rpcCreds, err = p.getRPCCredsFromOSEnv(config.UserIdForOsEnvAuth)
if err != nil {
return dbfactory.GRPCRemoteConfig{}, err
}
} else {
rpcCreds, err = p.getRPCCreds(endpoint)
if err != nil {
return dbfactory.GRPCRemoteConfig{}, err
}
}
if rpcCreds != nil {
opts = append(opts, grpc.WithPerRPCCredentials(rpcCreds))
Expand All @@ -103,6 +116,24 @@ func (p GRPCDialProvider) GetGRPCDialParams(config grpcendpoint.Config) (dbfacto
}, nil
}

// getRPCCredsFromOSEnv returns RPC Credentials for the specified username, using the DOLT_REMOTE_PASSWORD
func (p GRPCDialProvider) getRPCCredsFromOSEnv(username string) (credentials.PerRPCCredentials, error) {
if username == "" {
return nil, errors.New("Runtime error: username must be provided to getRPCCredsFromOSEnv")
}

pass, found := os.LookupEnv(dconfig.EnvDoltRemotePassword)
if !found {
return nil, errors.New("error: must set DOLT_REMOTE_PASSWORD environment variable to use --user param")
}
c := creds.DoltCredsForPass{
Username: username,
Password: pass,
}

return c.RPCCreds(), nil
}

// getRPCCreds returns any RPC credentials available to this dial provider. If a DoltEnv has been configured
// in this dial provider, it will be used to load custom user credentials, otherwise nil will be returned.
func (p GRPCDialProvider) getRPCCreds(endpoint string) (credentials.PerRPCCredentials, error) {
Expand Down
10 changes: 10 additions & 0 deletions go/libraries/doltcore/env/remotes.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,16 @@ func (r *Remote) GetRemoteDBWithoutCaching(ctx context.Context, nbf *types.NomsB
return doltdb.LoadDoltDBWithParams(ctx, nbf, r.Url, filesys2.LocalFS, params)
}

func (r Remote) WithParams(params map[string]string) Remote {
fetchSpecs := make([]string, len(r.FetchSpecs))
copy(fetchSpecs, r.FetchSpecs)
for k, v := range r.Params {
params[k] = v
}
r.Params = params
return r
}

// PushOptions contains information needed for push for
// one or more branches or a tag for a specific remote database.
type PushOptions struct {
Expand Down
6 changes: 6 additions & 0 deletions go/libraries/doltcore/grpcendpoint/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ type Config struct {
Creds credentials.PerRPCCredentials
WithEnvCreds bool

// If this is non-empty, and WithEnvCreds is true, then the caller is
// requesting to use username/password authentication instead of JWT
// authentication against the gRPC endpoint. Currently, the password
// comes from the OS environment variable DOLT_REMOTE_PASSWORD.
UserIdForOsEnvAuth string

// If non-nil, this is used for transport level security in the dial
// options, instead of a default option based on `Insecure`.
TLSConfig *tls.Config
Expand Down
39 changes: 39 additions & 0 deletions go/libraries/doltcore/remotesrv/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import (
"strings"

"github.com/sirupsen/logrus"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"

"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/store/hash"
Expand Down Expand Up @@ -397,3 +399,40 @@ func getFileReaderAt(path string, offset int64, length int64) (io.ReadCloser, in
r := closerReaderWrapper{io.LimitReader(f, length), f}
return r, fSize, nil
}

// ExtractBasicAuthCreds extracts the username and password from the incoming request. It returns RequestCredentials
// populated with necessary information to authenticate the request. nil and an error will be returned if any error
// occurs.
func ExtractBasicAuthCreds(ctx context.Context) (*RequestCredentials, error) {
if md, ok := metadata.FromIncomingContext(ctx); !ok {
return nil, errors.New("no metadata in context")
} else {
var username string
var password string

auths := md.Get("authorization")
if len(auths) != 1 {
username = "root"
password = ""
} else {
auth := auths[0]
if !strings.HasPrefix(auth, "Basic ") {
return nil, fmt.Errorf("bad request: authorization header did not start with 'Basic '")
}
authTrim := strings.TrimPrefix(auth, "Basic ")
uDec, err := base64.URLEncoding.DecodeString(authTrim)
if err != nil {
return nil, fmt.Errorf("incoming request authorization header failed to decode: %v", err)
}
userPass := strings.Split(string(uDec), ":")
username = userPass[0]
password = userPass[1]
}
addr, ok := peer.FromContext(ctx)
if !ok {
return nil, errors.New("incoming request had no peer")
}

return &RequestCredentials{Username: username, Password: password, Address: addr.Addr.String()}, nil
}
}
67 changes: 27 additions & 40 deletions go/libraries/doltcore/remotesrv/interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,10 @@ package remotesrv

import (
"context"
"encoding/base64"
"strings"

"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)

Expand All @@ -34,12 +30,20 @@ type RequestCredentials struct {
}

type ServerInterceptor struct {
Lgr *logrus.Entry
Authenticator Authenticator
Lgr *logrus.Entry
AccessController AccessControl
}

type Authenticator interface {
Authenticate(creds *RequestCredentials) bool
// AccessControl is an interface that provides authentication and authorization for the gRPC server.
type AccessControl interface {
// ApiAuthenticate checks the incoming request for authentication credentials and validates them. If the user's
// identity checks out, the returned context will have the sqlContext within it, which contains the user's ID.
// If the user is not legitimate, an error is returned.
ApiAuthenticate(ctx context.Context) (context.Context, error)
// ApiAuthorize checks that the authenticated user has sufficient privileges to perform the requested action.
// Currently, CLONE_ADMIN is required. True and a nil error returned if the user is authorized, otherwise false
// with an error.
ApiAuthorize(ctx context.Context) (bool, error)
}

func (si *ServerInterceptor) Stream() grpc.StreamServerInterceptor {
Expand Down Expand Up @@ -69,40 +73,23 @@ func (si *ServerInterceptor) Options() []grpc.ServerOption {
}
}

// authenticate checks the incoming request for authentication credentials and validates them. If the user is
// legitimate, an authorization check is performed. If no error is returned, the user should be allowed to proceed.
func (si *ServerInterceptor) authenticate(ctx context.Context) error {
if md, ok := metadata.FromIncomingContext(ctx); ok {
var username string
var password string
ctx, err := si.AccessController.ApiAuthenticate(ctx)
if err != nil {
si.Lgr.Warnf("authentication failed: %s", err.Error())
status.Error(codes.Unauthenticated, "unauthenticated")
return err
}

auths := md.Get("authorization")
if len(auths) != 1 {
username = "root"
} else {
auth := auths[0]
if !strings.HasPrefix(auth, "Basic ") {
si.Lgr.Info("incoming request had malformed authentication header")
return status.Error(codes.Unauthenticated, "unauthenticated")
}
authTrim := strings.TrimPrefix(auth, "Basic ")
uDec, err := base64.URLEncoding.DecodeString(authTrim)
if err != nil {
si.Lgr.Infof("incoming request authorization header failed to decode: %v", err)
return status.Error(codes.Unauthenticated, "unauthenticated")
}
userPass := strings.Split(string(uDec), ":")
username = userPass[0]
password = userPass[1]
}
addr, ok := peer.FromContext(ctx)
if !ok {
si.Lgr.Info("incoming request had no peer")
return status.Error(codes.Unauthenticated, "unauthenticated")
}
if authed := si.Authenticator.Authenticate(&RequestCredentials{Username: username, Password: password, Address: addr.Addr.String()}); !authed {
return status.Error(codes.Unauthenticated, "unauthenticated")
}
return nil
// Have a valid user in the context. Check authorization.
if authorized, err := si.AccessController.ApiAuthorize(ctx); !authorized {
si.Lgr.Warnf("authorization failed: %s", err.Error())
status.Error(codes.PermissionDenied, "unauthorized")
return err
}

return status.Error(codes.Unauthenticated, "unauthenticated 1")
// Access Granted.
return nil
}
7 changes: 7 additions & 0 deletions go/libraries/doltcore/sqle/dprocedures/dolt_fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
Expand Down Expand Up @@ -73,6 +74,12 @@ func doDoltFetch(ctx *sql.Context, args []string) (int, error) {
return cmdFailure, err
}

if user, hasUser := apr.GetValue(cli.UserFlag); hasUser {
remote = remote.WithParams(map[string]string{
dbfactory.GRPCUsernameAuthParam: user,
})
}

srcDB, err := sess.Provider().GetRemoteDB(ctx, dbData.Ddb.ValueReadWriter().Format(), remote, false)
if err != nil {
return 1, err
Expand Down
Loading

0 comments on commit dbb5ce3

Please sign in to comment.