Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge 0f509a7 into c499a48
Browse files Browse the repository at this point in the history
  • Loading branch information
jeevb authored Jun 3, 2023
2 parents c499a48 + 0f509a7 commit fa0f9fd
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 6 deletions.
10 changes: 7 additions & 3 deletions go/tasks/pluginmachinery/google/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ package google
type TokenSourceFactoryType = string

const (
TokenSourceTypeDefault = "default"
TokenSourceTypeDefault = "default"
TokenSourceTypeGkeTaskWorkloadIdentity = "gke-task-workload-identity" // #nosec
)

type TokenSourceFactoryConfig struct {
// Type is type of TokenSourceFactory, possible values are 'default' or 'gke'.
// Type is type of TokenSourceFactory, possible values are 'default' or 'gke-task-workload-identity'.
// - 'default' uses default credentials, see https://cloud.google.com/iam/docs/service-accounts#default
Type TokenSourceFactoryType `json:"type" pflag:",Defines type of TokenSourceFactory, possible values are 'default'"`
Type TokenSourceFactoryType `json:"type" pflag:",Defines type of TokenSourceFactory, possible values are 'default' and 'gke-task-workload-identity'"`

// Configuration for GKE task workload identity token source factory
GkeTaskWorkloadIdentityTokenSourceFactoryConfig GkeTaskWorkloadIdentityTokenSourceFactoryConfig `json:"gke-task-workload-identity" pflag:"Extra configuration for GKE task workload identity token source factory"`
}

func GetDefaultConfig() TokenSourceFactoryConfig {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ import (

type defaultTokenSource struct{}

func (m *defaultTokenSource) GetTokenSource(ctx context.Context, identity Identity) (oauth2.TokenSource, error) {
func (m *defaultTokenSource) GetTokenSource(
ctx context.Context,
identity Identity,
) (oauth2.TokenSource, error) {
return google.DefaultTokenSource(ctx)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package google

import (
"context"

pluginmachinery "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/pkg/errors"
"golang.org/x/oauth2"
"google.golang.org/api/impersonate"
"google.golang.org/grpc/credentials/oauth"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
)

const (
gcpServiceAccountAnnotationKey = "iam.gke.io/gcp-service-account"
workflowIdentityDocURL = "https://cloud.google.com/kubernetes-engine/docs/how-to/workload-identity"
)

var impersonationScopes = []string{"https://www.googleapis.com/auth/bigquery"}

type GkeTaskWorkloadIdentityTokenSourceFactoryConfig struct {
RemoteClusterConfig pluginmachinery.ClusterConfig `json:"remoteClusterConfig" pflag:"Configuration of remote GKE cluster"`
}

type gkeTaskWorkloadIdentityTokenSourceFactory struct {
kubeClient kubernetes.Interface
}

func (m *gkeTaskWorkloadIdentityTokenSourceFactory) getGcpServiceAccount(
ctx context.Context,
identity Identity,
) (string, error) {
if identity.K8sServiceAccount == "" {
identity.K8sServiceAccount = "default"
}
serviceAccount, err := m.kubeClient.CoreV1().ServiceAccounts(identity.K8sNamespace).Get(
ctx,
identity.K8sServiceAccount,
metav1.GetOptions{},
)
if err != nil {
return "", errors.Wrapf(err, "failed to retrieve task k8s service account")
}

for key, value := range serviceAccount.Annotations {
if key == gcpServiceAccountAnnotationKey {
return value, nil
}
}

return "", errors.Errorf(
"[%v] annotation doesn't exist on k8s service account [%v/%v], read more at %v",
gcpServiceAccountAnnotationKey,
identity.K8sNamespace,
identity.K8sServiceAccount,
workflowIdentityDocURL)
}

func (m *gkeTaskWorkloadIdentityTokenSourceFactory) GetTokenSource(
ctx context.Context,
identity Identity,
) (oauth2.TokenSource, error) {
gcpServiceAccount, err := m.getGcpServiceAccount(ctx, identity)
if err != nil {
return oauth.TokenSource{}, err
}

return impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{
TargetPrincipal: gcpServiceAccount,
Scopes: impersonationScopes,
})
}

func getKubeClient(
config *GkeTaskWorkloadIdentityTokenSourceFactoryConfig,
) (*kubernetes.Clientset, error) {
var kubeCfg *rest.Config
var err error
if config.RemoteClusterConfig.Enabled {
kubeCfg, err = pluginmachinery.KubeClientConfig(
config.RemoteClusterConfig.Endpoint,
config.RemoteClusterConfig.Auth,
)
if err != nil {
return nil, errors.Wrapf(err, "Error building kubeconfig")
}
} else {
kubeCfg, err = rest.InClusterConfig()
if err != nil {
return nil, errors.Wrapf(err, "Cannot get InCluster kubeconfig")
}
}

kubeClient, err := kubernetes.NewForConfig(kubeCfg)
if err != nil {
return nil, errors.Wrapf(err, "Error building kubernetes clientset")
}
return kubeClient, err
}

func NewGkeTaskWorkloadIdentityTokenSourceFactory(
config *GkeTaskWorkloadIdentityTokenSourceFactoryConfig,
) (TokenSourceFactory, error) {
kubeClient, err := getKubeClient(config)
if err != nil {
return nil, err
}
return &gkeTaskWorkloadIdentityTokenSourceFactory{kubeClient: kubeClient}, nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package google

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"
)

func TestGetGcpServiceAccount(t *testing.T) {
ctx := context.TODO()

t.Run("get GCP service account", func(t *testing.T) {
kubeClient := fake.NewSimpleClientset(&corev1.ServiceAccount{
ObjectMeta: v1.ObjectMeta{
Name: "name",
Namespace: "namespace",
Annotations: map[string]string{
"owner": "abc",
"iam.gke.io/gcp-service-account": "gcp-service-account",
},
}})
ts := gkeTaskWorkloadIdentityTokenSourceFactory{kubeClient: kubeClient}
gcpServiceAccount, err := ts.getGcpServiceAccount(ctx, Identity{
K8sNamespace: "namespace",
K8sServiceAccount: "name",
})

assert.NoError(t, err)
assert.Equal(t, "gcp-service-account", gcpServiceAccount)
})

t.Run("no GCP service account", func(t *testing.T) {
kubeClient := fake.NewSimpleClientset()
ts := gkeTaskWorkloadIdentityTokenSourceFactory{kubeClient: kubeClient}
_, err := ts.getGcpServiceAccount(ctx, Identity{
K8sNamespace: "namespace",
K8sServiceAccount: "name",
})

assert.ErrorContains(t, err, "failed to retrieve task k8s service account")
})

t.Run("no GCP service account annotation", func(t *testing.T) {
kubeClient := fake.NewSimpleClientset(&corev1.ServiceAccount{
ObjectMeta: v1.ObjectMeta{
Name: "name",
Namespace: "namespace",
Annotations: map[string]string{
"owner": "abc",
},
}})
ts := gkeTaskWorkloadIdentityTokenSourceFactory{kubeClient: kubeClient}
_, err := ts.getGcpServiceAccount(ctx, Identity{
K8sNamespace: "namespace",
K8sServiceAccount: "name",
})

assert.ErrorContains(t, err, "annotation doesn't exist on k8s service account")
})
}
12 changes: 10 additions & 2 deletions go/tasks/pluginmachinery/google/token_source_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@ type TokenSourceFactory interface {
}

func NewTokenSourceFactory(config TokenSourceFactoryConfig) (TokenSourceFactory, error) {
if config.Type == TokenSourceTypeDefault {
switch config.Type {
case TokenSourceTypeDefault:
return NewDefaultTokenSourceFactory()
case TokenSourceTypeGkeTaskWorkloadIdentity:
return NewGkeTaskWorkloadIdentityTokenSourceFactory(
&config.GkeTaskWorkloadIdentityTokenSourceFactoryConfig,
)
}

return nil, errors.Errorf("unknown token source type [%v], possible values are: 'default'", config.Type)
return nil, errors.Errorf(
"unknown token source type [%v], possible values are: 'default' and 'gke-task-workload-identity'",
config.Type,
)
}

0 comments on commit fa0f9fd

Please sign in to comment.