Skip to content

Commit

Permalink
Respect proxy templates in tsh puttyconfig
Browse files Browse the repository at this point in the history
Updates tsh puttyconfig to resolve hosts appropriately if a match
is found in the users defined proxy templates.

Fixes #45565
  • Loading branch information
rosstimothy committed Sep 5, 2024
1 parent e2385c9 commit 1b4c665
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 51 deletions.
52 changes: 26 additions & 26 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1384,14 +1384,14 @@ func (tc *TeleportClient) RootClusterName(ctx context.Context) (string, error) {
return name, nil
}

type targetNode struct {
hostname string
addr string
type TargetNode struct {
Hostname string
Addr string
}

// getTargetNodes returns a list of node addresses this SSH command needs to
// GetTargetNodes returns a list of node addresses this SSH command needs to
// operate on.
func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUnifiedResourcesClient, options SSHOptions) ([]targetNode, error) {
func (tc *TeleportClient) GetTargetNodes(ctx context.Context, clt client.ListUnifiedResourcesClient, options SSHOptions) ([]TargetNode, error) {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/getTargetNodes",
Expand All @@ -1400,10 +1400,10 @@ func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUni
defer span.End()

if options.HostAddress != "" {
return []targetNode{
return []TargetNode{
{
hostname: options.HostAddress,
addr: options.HostAddress,
Hostname: options.HostAddress,
Addr: options.HostAddress,
},
}, nil
}
Expand All @@ -1422,17 +1422,17 @@ func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUni
return nil, trace.Wrap(err)
}

retval := make([]targetNode, 0, len(nodes))
retval := make([]TargetNode, 0, len(nodes))
for _, resource := range nodes {
server, ok := resource.ResourceWithLabels.(types.Server)
if !ok {
continue
}

// always dial nodes by UUID
retval = append(retval, targetNode{
hostname: server.GetHostname(),
addr: fmt.Sprintf("%s:0", resource.GetName()),
retval = append(retval, TargetNode{
Hostname: server.GetHostname(),
Addr: fmt.Sprintf("%s:0", resource.GetName()),
})
}

Expand All @@ -1447,10 +1447,10 @@ func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUni
}

addr := net.JoinHostPort(tc.Host, strconv.Itoa(tc.HostPort))
return []targetNode{
return []TargetNode{
{
hostname: tc.Host,
addr: addr,
Hostname: tc.Host,
Addr: addr,
},
}, nil
}
Expand Down Expand Up @@ -1716,7 +1716,7 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun
defer clt.Close()

// which nodes are we executing this commands on?
nodeAddrs, err := tc.getTargetNodes(ctx, clt.AuthClient, options)
nodeAddrs, err := tc.GetTargetNodes(ctx, clt.AuthClient, options)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -1727,7 +1727,7 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun
if len(nodeAddrs) > 1 {
return tc.runShellOrCommandOnMultipleNodes(ctx, clt, nodeAddrs, command)
}
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].addr, command, options.LocalCommandExecutor)
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].Addr, command, options.LocalCommandExecutor)
}

// ConnectToNode attempts to establish a connection to the node resolved to by the provided
Expand All @@ -1736,7 +1736,7 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun
// fail the error from the connection attempt with the already provisioned certificates will
// be returned. The client from whichever attempt succeeds first will be returned.
func (tc *TeleportClient) ConnectToNode(ctx context.Context, clt *ClusterClient, nodeDetails NodeDetails, user string) (_ *NodeClient, err error) {
node := nodeName(targetNode{addr: nodeDetails.Addr})
node := nodeName(TargetNode{Addr: nodeDetails.Addr})
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/ConnectToNode",
Expand Down Expand Up @@ -1890,7 +1890,7 @@ func (m MFARequiredUnknownErr) Is(err error) bool {
// if it is required, then the mfa ceremony is attempted. The target host is dialed once the ceremony
// completes and new certificates are retrieved.
func (tc *TeleportClient) connectToNodeWithMFA(ctx context.Context, clt *ClusterClient, nodeDetails NodeDetails, user string) (*NodeClient, error) {
node := nodeName(targetNode{addr: nodeDetails.Addr})
node := nodeName(TargetNode{Addr: nodeDetails.Addr})
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/connectToNodeWithMFA",
Expand Down Expand Up @@ -1999,11 +1999,11 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt
return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, tc.OnChannelRequest, nil))
}

func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []targetNode, command []string) error {
func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []TargetNode, command []string) error {
cluster := clt.ClusterName()
nodeAddrs := make([]string, 0, len(nodes))
for _, node := range nodes {
nodeAddrs = append(nodeAddrs, node.addr)
nodeAddrs = append(nodeAddrs, node.Addr)
}
ctx, span := tc.Tracer.Start(
ctx,
Expand Down Expand Up @@ -2699,7 +2699,7 @@ type execResult struct {
}

// runCommandOnNodes executes a given bash command on a bunch of remote nodes.
func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterClient, nodes []targetNode, command []string) error {
func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterClient, nodes []TargetNode, command []string) error {
cluster := clt.ClusterName()
ctx, span := tc.Tracer.Start(
ctx,
Expand All @@ -2717,7 +2717,7 @@ func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterCli
mfaRequiredCheck, err := clt.AuthClient.IsMFARequired(ctx, &proto.IsMFARequiredRequest{
Target: &proto.IsMFARequiredRequest_Node{
Node: &proto.NodeLogin{
Node: nodeName(targetNode{addr: nodes[0].addr}),
Node: nodeName(TargetNode{Addr: nodes[0].Addr}),
Login: tc.Config.HostLogin,
},
},
Expand All @@ -2743,19 +2743,19 @@ func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterCli
gctx,
"teleportClient/executingCommand",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(attribute.String("node", node.addr)),
oteltrace.WithAttributes(attribute.String("node", node.Addr)),
)
defer span.End()

nodeClient, err := tc.ConnectToNode(
ctx,
clt,
NodeDetails{
Addr: node.addr,
Addr: node.Addr,
Namespace: tc.Namespace,
Cluster: cluster,
MFACheck: mfaRequiredCheck,
hostname: node.hostname,
hostname: node.Hostname,
},
tc.Config.HostLogin,
)
Expand Down
14 changes: 7 additions & 7 deletions lib/client/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1309,37 +1309,37 @@ func TestGetTargetNodes(t *testing.T) {
host string
port int
clt fakeResourceClient
expected []targetNode
expected []TargetNode
}{
{
name: "options override",
options: SSHOptions{
HostAddress: "test:1234",
},
expected: []targetNode{{hostname: "test:1234", addr: "test:1234"}},
expected: []TargetNode{{Hostname: "test:1234", Addr: "test:1234"}},
},
{
name: "explicit target",
host: "test",
port: 1234,
expected: []targetNode{{hostname: "test", addr: "test:1234"}},
expected: []TargetNode{{Hostname: "test", Addr: "test:1234"}},
},
{
name: "labels",
labels: map[string]string{"foo": "bar"},
expected: []targetNode{{hostname: "labels", addr: "abcd:0"}},
expected: []TargetNode{{Hostname: "labels", Addr: "abcd:0"}},
clt: fakeResourceClient{nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "labels"}}}},
},
{
name: "search",
search: []string{"foo", "bar"},
expected: []targetNode{{hostname: "search", addr: "abcd:0"}},
expected: []TargetNode{{Hostname: "search", Addr: "abcd:0"}},
clt: fakeResourceClient{nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "search"}}}},
},
{
name: "predicate",
predicate: `resource.spec.hostname == "test"`,
expected: []targetNode{{hostname: "predicate", addr: "abcd:0"}},
expected: []TargetNode{{Hostname: "predicate", Addr: "abcd:0"}},
clt: fakeResourceClient{nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "predicate"}}}},
},
}
Expand All @@ -1357,7 +1357,7 @@ func TestGetTargetNodes(t *testing.T) {
},
}

match, err := clt.getTargetNodes(context.Background(), test.clt, test.options)
match, err := clt.GetTargetNodes(context.Background(), test.clt, test.options)
require.NoError(t, err)
require.EqualValues(t, test.expected, match)
})
Expand Down
12 changes: 6 additions & 6 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,13 @@ func (a sharedAuthClient) Close() error {
}

// nodeName removes the port number from the hostname, if present
func nodeName(node targetNode) string {
if node.hostname != "" {
return node.hostname
func nodeName(node TargetNode) string {
if node.Hostname != "" {
return node.Hostname
}
n, _, err := net.SplitHostPort(node.addr)
n, _, err := net.SplitHostPort(node.Addr)
if err != nil {
return node.addr
return node.Addr
}
return n
}
Expand All @@ -271,7 +271,7 @@ type NodeDetails struct {

// String returns a user-friendly name
func (n NodeDetails) String() string {
parts := []string{nodeName(targetNode{addr: n.Addr})}
parts := []string{nodeName(TargetNode{Addr: n.Addr})}
if n.Cluster != "" {
parts = append(parts, "on cluster", n.Cluster)
}
Expand Down
6 changes: 3 additions & 3 deletions lib/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ import (
)

func TestHelperFunctions(t *testing.T) {
assert.Equal(t, "one", nodeName(targetNode{addr: "one"}))
assert.Equal(t, "one", nodeName(targetNode{addr: "one:22"}))
assert.Equal(t, "example.com", nodeName(targetNode{addr: "one", hostname: "example.com"}))
assert.Equal(t, "one", nodeName(TargetNode{Addr: "one"}))
assert.Equal(t, "one", nodeName(TargetNode{Addr: "one:22"}))
assert.Equal(t, "example.com", nodeName(TargetNode{Addr: "one", Hostname: "example.com"}))
}

func TestNewSession(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion lib/client/cluster_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (c *ClusterClient) SessionSSHConfig(ctx context.Context, user string, targe
keyRing, err = c.performMFACeremony(ctx,
mfaClt,
ReissueParams{
NodeName: nodeName(targetNode{addr: target.Addr}),
NodeName: nodeName(TargetNode{Addr: target.Addr}),
RouteToCluster: target.Cluster,
MFACheck: target.MFACheck,
},
Expand Down
29 changes: 21 additions & 8 deletions tool/tsh/common/putty_config_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"github.com/gravitational/teleport/api/profile"
"github.com/gravitational/teleport/api/utils/keypaths"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/puttyhosts"
"github.com/gravitational/teleport/lib/utils/registry"
)
Expand Down Expand Up @@ -254,10 +255,29 @@ func onPuttyConfig(cf *CLIConf) error {
return trace.Wrap(err)
}

// connect to proxy to fetch cluster info
clusterClient, err := tc.ConnectToCluster(cf.Context)
if err != nil {
return trace.Wrap(err)
}
defer clusterClient.Close()

matches, err := tc.GetTargetNodes(cf.Context, clusterClient.AuthClient, client.SSHOptions{})
if err != nil {
return trace.Wrap(err)
}

switch len(matches) {
case 0:
return trace.NotFound("no matching hosts")
case 1:
return trace.BadParameter("multiple matching hosts found")
}

// remove any spaces from the provided hostname. if the hostname contains a colon, it will be a
// hostname:port combination so we split it. this is useful as shorthand when adding OpenSSH hosts
// with `tsh puttyconfig user@host:22`, rather than using the longer `tsh puttyconfig --port 22 user@host`
hostname := strings.TrimSpace(tc.Config.Host)
hostname := strings.TrimSpace(matches[0].Hostname)
port := tc.Config.HostPort
if splitHost, splitPort, err := net.SplitHostPort(hostname); err == nil {
hostname = splitHost
Expand All @@ -280,13 +300,6 @@ func onPuttyConfig(cf *CLIConf) error {
userHostString = fmt.Sprintf("%v@%v", login, userHostString)
}

// connect to proxy to fetch cluster info
clusterClient, err := tc.ConnectToCluster(cf.Context)
if err != nil {
return trace.Wrap(err)
}
defer clusterClient.Close()

// parse out proxy details
proxyHost, _, err := net.SplitHostPort(tc.Config.SSHProxyAddr)
if err != nil {
Expand Down

0 comments on commit 1b4c665

Please sign in to comment.