Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions experimental/ssh/cmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ the SSH server and handling the connection proxy.
var liteswap string
var skipSettingsCheck bool
var environmentVersion int
var noConfig bool
var multiplex bool

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)")
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects")
Expand Down Expand Up @@ -71,6 +73,9 @@ the SSH server and handling the connection proxy.
cmd.Flags().IntVar(&environmentVersion, "environment-version", defaultEnvironmentVersion, "Environment version for serverless compute")
cmd.Flags().MarkHidden("environment-version")

cmd.Flags().BoolVar(&noConfig, "no-config", false, "Do not write SSH config entry (disables scp/rsync support)")
cmd.Flags().BoolVar(&multiplex, "multiplex", false, "Enable SSH connection multiplexing (ControlMaster) for faster scp/rsync")

cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
// CLI in the proxy mode is executed by the ssh client and can't prompt for input
if proxyMode {
Expand Down Expand Up @@ -109,6 +114,8 @@ the SSH server and handling the connection proxy.
Liteswap: liteswap,
SkipSettingsCheck: skipSettingsCheck,
EnvironmentVersion: environmentVersion,
SkipConfigWrite: noConfig,
Multiplex: multiplex,
AdditionalArgs: args,
}
if err := opts.Validate(); err != nil {
Expand Down
3 changes: 3 additions & 0 deletions experimental/ssh/cmd/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ an SSH host configuration to your SSH config file.
var sshConfigPath string
var shutdownDelay time.Duration
var autoStartCluster bool
var multiplex bool

cmd.Flags().StringVar(&hostName, "name", "", "Host name to use in SSH config")
cmd.MarkFlagRequired("name")
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster when establishing the ssh connection")
cmd.Flags().StringVar(&sshConfigPath, "ssh-config", "", "Path to SSH config file (default ~/.ssh/config)")
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "SSH server will terminate after this delay if there are no active connections")
cmd.Flags().BoolVar(&multiplex, "multiplex", false, "Enable SSH connection multiplexing (ControlMaster) for faster scp/rsync")

cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
// We want to avoid the situation where the setup command works because it pulls the auth config from a bundle,
Expand All @@ -53,6 +55,7 @@ an SSH host configuration to your SSH config file.
SSHConfigPath: sshConfigPath,
ShutdownDelay: shutdownDelay,
Profile: wsClient.Config.Profile,
Multiplex: multiplex,
}
clientOpts := client.ClientOptions{
ClusterID: setupOpts.ClusterID,
Expand Down
100 changes: 90 additions & 10 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"os/signal"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"syscall"
Expand Down Expand Up @@ -99,6 +100,10 @@ type ClientOptions struct {
SkipSettingsCheck bool
// Environment version for serverless compute.
EnvironmentVersion int
// If true, skip writing the SSH config entry in terminal mode.
SkipConfigWrite bool
// If true, enable SSH ControlMaster multiplexing for connection reuse.
Multiplex bool
}

func (o *ClientOptions) Validate() error {
Expand Down Expand Up @@ -207,14 +212,19 @@ func (o *ClientOptions) ToProxyCommand() (string, error) {
}

func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions) error {
// In proxy mode, the CLI runs as a ProxyCommand subprocess of ssh/scp/rsync.
// Suppress all user-facing output so it doesn't interfere with the parent tool.
if opts.ProxyMode {
ctx = cmdio.MockDiscard(ctx)
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()

sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGHUP, syscall.SIGTERM)
go func() {
<-sigCh
cmdio.LogString(ctx, "Received termination signal, cleaning up...")
cancel()
}()

Expand Down Expand Up @@ -350,10 +360,24 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
}

if opts.ProxyMode {
return runSSHProxy(ctx, client, serverPort, clusterID, opts)
err := runSSHProxy(ctx, client, serverPort, clusterID, opts)
// context.Canceled is the normal exit path when the SSH client (scp/rsync) disconnects.
if errors.Is(err, context.Canceled) {
return nil
}
return err
} else if opts.IDE != "" {
return runIDE(ctx, client, userName, keyPath, serverPort, clusterID, opts)
} else {
hostName := opts.SessionIdentifier()
if !opts.SkipConfigWrite {
if err := writeSSHConfigForConnect(ctx, hostName, userName, keyPath, opts); err != nil {
// Non-fatal: log and continue with the SSH session
log.Warnf(ctx, "Failed to write SSH config entry: %v", err)
} else {
printSSHToolHints(ctx, hostName)
}
}
log.Infof(ctx, "Additional SSH arguments: %v", opts.AdditionalArgs)
return spawnSSHClient(ctx, userName, keyPath, serverPort, clusterID, opts)
}
Expand All @@ -377,31 +401,44 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k
return fmt.Errorf("failed to get SSH config path: %w", err)
}

err = ensureSSHConfigEntry(ctx, configPath, connectionName, userName, keyPath, serverPort, clusterID, opts)
err = ensureSSHConfigEntry(ctx, configPath, connectionName, userName, keyPath, opts)
if err != nil {
return fmt.Errorf("failed to ensure SSH config entry: %w", err)
}

return vscode.LaunchIDE(ctx, opts.IDE, connectionName, userName, currentUser.UserName)
}

func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error {
func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, opts ClientOptions) error {
// Ensure the Include directive exists in the main SSH config
err := sshconfig.EnsureIncludeDirective(ctx, configPath)
if err != nil {
return err
}

// Generate ProxyCommand with server metadata
optsWithMetadata := opts
optsWithMetadata.ServerMetadata = FormatMetadata(userName, serverPort, clusterID)

proxyCommand, err := optsWithMetadata.ToProxyCommand()
// Generate ProxyCommand without metadata so the config is resilient to server restarts.
// The inline SSH invocation passes metadata separately for fast first-connection.
proxyCommand, err := opts.ToProxyCommand()
if err != nil {
return fmt.Errorf("failed to generate ProxyCommand: %w", err)
}

hostConfig := sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand)
configOpts := sshconfig.HostConfigOptions{
HostName: hostName,
UserName: userName,
IdentityFile: keyPath,
ProxyCommand: proxyCommand,
}

if opts.Multiplex {
controlPath, cpErr := controlSocketPath(ctx)
if cpErr != nil {
return cpErr
}
configOpts.ControlPath = controlPath
}

hostConfig := sshconfig.GenerateHostConfig(configOpts)

_, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true)
if err != nil {
Expand All @@ -412,6 +449,39 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k
return nil
}

// writeSSHConfigForConnect writes an SSH config entry so that SSH-based tools
// (scp, rsync, sftp) can connect using the same hostname.
func writeSSHConfigForConnect(ctx context.Context, hostName, userName, keyPath string, opts ClientOptions) error {
configPath, err := sshconfig.GetMainConfigPath(ctx)
if err != nil {
return err
}

if opts.Multiplex {
if err := sshconfig.EnsureSocketsDir(ctx); err != nil {
return err
}
}

return ensureSSHConfigEntry(ctx, configPath, hostName, userName, keyPath, opts)
}

// controlSocketPath returns the ControlPath pattern for SSH multiplexing.
func controlSocketPath(ctx context.Context) (string, error) {
socketsDir, err := sshconfig.GetSocketsDir(ctx)
if err != nil {
return "", err
}
return filepath.ToSlash(filepath.Join(socketsDir, "%h")), nil
}

func printSSHToolHints(ctx context.Context, hostName string) {
cmdio.LogString(ctx, fmt.Sprintf("SSH config written for '%s'. You can now use SSH tools in another terminal:", hostName))
cmdio.LogString(ctx, fmt.Sprintf(" scp %s:remote-file local-file", hostName))
cmdio.LogString(ctx, fmt.Sprintf(" rsync -avz %s:remote-dir/ local-dir/", hostName))
cmdio.LogString(ctx, " sftp "+hostName)
}

// getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy.
// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless).
// For dedicated clusters, clusterID should be the same as sessionID.
Expand Down Expand Up @@ -580,6 +650,16 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server
if opts.UserKnownHostsFile != "" {
sshArgs = append(sshArgs, "-o", "UserKnownHostsFile="+opts.UserKnownHostsFile)
}
if opts.Multiplex && runtime.GOOS != "windows" {
cp, cpErr := controlSocketPath(ctx)
if cpErr == nil {
sshArgs = append(sshArgs,
"-o", "ControlMaster=auto",
"-o", "ControlPath="+cp,
"-o", "ControlPersist=10m",
)
}
}
sshArgs = append(sshArgs, hostName)
sshArgs = append(sshArgs, opts.AdditionalArgs...)

Expand Down
27 changes: 25 additions & 2 deletions experimental/ssh/internal/setup/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"path/filepath"
"time"

"github.com/databricks/cli/experimental/ssh/internal/keys"
Expand All @@ -30,6 +31,8 @@ type SetupOptions struct {
Profile string
// Proxy command to use for the SSH connection
ProxyCommand string
// If true, enable SSH ControlMaster multiplexing for connection reuse by scp/rsync/sftp.
Multiplex bool
}

func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClient, clusterID string) error {
Expand All @@ -49,8 +52,22 @@ func generateHostConfig(ctx context.Context, opts SetupOptions) (string, error)
return "", fmt.Errorf("failed to get local keys folder: %w", err)
}

hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, opts.ProxyCommand)
return hostConfig, nil
configOpts := sshconfig.HostConfigOptions{
HostName: opts.HostName,
UserName: "root",
IdentityFile: identityFilePath,
ProxyCommand: opts.ProxyCommand,
}

if opts.Multiplex {
socketsDir, sockErr := sshconfig.GetSocketsDir(ctx)
if sockErr != nil {
return "", sockErr
}
configOpts.ControlPath = filepath.ToSlash(filepath.Join(socketsDir, "%h"))
}

return sshconfig.GenerateHostConfig(configOpts), nil
}

func clusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) {
Expand Down Expand Up @@ -100,6 +117,12 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp
return err
}

if opts.Multiplex {
if err := sshconfig.EnsureSocketsDir(ctx); err != nil {
return err
}
}

hostConfig, err := generateHostConfig(ctx, opts)
if err != nil {
return err
Expand Down
38 changes: 38 additions & 0 deletions experimental/ssh/internal/setup/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"time"

Expand Down Expand Up @@ -201,6 +202,43 @@ func TestGenerateHostConfig_PathEscaping(t *testing.T) {
assert.Contains(t, result, fmt.Sprintf(`IdentityFile %q`, expectedPath))
}

func TestGenerateHostConfig_WithMultiplex(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)
t.Setenv("USERPROFILE", tmpDir)

clientOpts := client.ClientOptions{
ClusterID: "cluster-123",
AutoStartCluster: true,
ShutdownDelay: 30 * time.Second,
}
proxyCommand, err := clientOpts.ToProxyCommand()
require.NoError(t, err)

opts := SetupOptions{
HostName: "test-host",
ClusterID: "cluster-123",
SSHKeysDir: tmpDir,
ShutdownDelay: 30 * time.Second,
ProxyCommand: proxyCommand,
Multiplex: true,
}

result, err := generateHostConfig(t.Context(), opts)
require.NoError(t, err)

assert.Contains(t, result, "Host test-host")
assert.Contains(t, result, "--cluster=cluster-123")

if runtime.GOOS == "windows" {
assert.NotContains(t, result, "ControlMaster")
} else {
assert.Contains(t, result, "ControlMaster auto")
assert.Contains(t, result, "ControlPath")
assert.Contains(t, result, "ControlPersist 10m")
}
}

func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) {
ctx := cmdio.MockDiscard(t.Context())
tmpDir := t.TempDir()
Expand Down
Loading
Loading