diff --git a/go.mod b/go.mod index f5cb6b34b..44eebf0be 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,6 @@ require ( github.com/tweekmonster/luser v0.0.0-20161003172636-3fa38070dbd7 github.com/wk8/go-ordered-map/v2 v2.0.0 github.com/writeas/go-strip-markdown v2.0.1+incompatible - golang.org/x/crypto v0.52.0 golang.org/x/text v0.37.0 k8s.io/cli-runtime v0.31.1 ) @@ -101,6 +100,7 @@ require ( github.com/x448/float16 v0.8.4 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect golang.org/x/arch v0.8.0 // indirect + golang.org/x/crypto v0.52.0 // indirect golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect golang.org/x/sync v0.20.0 // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index e74b67c23..95db2ddd8 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -59,7 +59,6 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/upgrade" "github.com/brevdev/brev-cli/pkg/cmd/version" "github.com/brevdev/brev-cli/pkg/cmd/workspacegroups" - "github.com/brevdev/brev-cli/pkg/cmd/writeconnectionevent" "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" "github.com/brevdev/brev-cli/pkg/featureflag" @@ -343,7 +342,6 @@ func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *stor cmd.AddCommand(setupworkspace.NewCmdSetupWorkspace(noLoginCmdStore)) cmd.AddCommand(recreate.NewCmdRecreate(t, loginCmdStore)) - cmd.AddCommand(writeconnectionevent.NewCmdwriteConnectionEvent(t, loginCmdStore)) cmd.AddCommand(updatemodel.NewCmdupdatemodel(t, loginCmdStore)) cmd.AddCommand(feedback.NewCmdFeedback(t, noLoginCmdStore)) } diff --git a/pkg/cmd/copy/copy.go b/pkg/cmd/copy/copy.go index c35bd3b47..bebc1e26c 100644 --- a/pkg/cmd/copy/copy.go +++ b/pkg/cmd/copy/copy.go @@ -18,7 +18,6 @@ import ( breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/store" "github.com/brevdev/brev-cli/pkg/terminal" - "github.com/brevdev/brev-cli/pkg/writeconnectionevent" "github.com/briandowns/spinner" "github.com/spf13/cobra" @@ -99,8 +98,6 @@ func runCopyCommand(t *terminal.Terminal, cstore CopyStore, source, dest string, return breverrors.WrapAndTrace(err) } - _ = writeconnectionevent.WriteWCEOnEnv(cstore, workspace.DNS) - err = runSCP(t, sshName, localPath, remotePath, isUpload) if err != nil { return breverrors.WrapAndTrace(err) diff --git a/pkg/cmd/exec/exec.go b/pkg/cmd/exec/exec.go index 3e6d2ad82..e7f0cd924 100644 --- a/pkg/cmd/exec/exec.go +++ b/pkg/cmd/exec/exec.go @@ -16,7 +16,6 @@ import ( breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/store" "github.com/brevdev/brev-cli/pkg/terminal" - "github.com/brevdev/brev-cli/pkg/writeconnectionevent" "github.com/hashicorp/go-multierror" "github.com/spf13/cobra" @@ -229,7 +228,6 @@ func runExecCommand(t *terminal.Terminal, sstore ExecStore, workspaceNameOrID st if err != nil { return breverrors.WrapAndTrace(err) } - _ = writeconnectionevent.WriteWCEOnEnv(sstore, workspace.DNS) err = runSSH(sshName, command) if err != nil { return breverrors.WrapAndTrace(err) @@ -263,7 +261,6 @@ func runExecCommand(t *terminal.Terminal, sstore ExecStore, workspaceNameOrID st "could not connect to instance %q: %w\nPlease check with: brev ls", workspaceNameOrID, err)) } - _ = writeconnectionevent.WriteWCEOnEnv(sstore, workspace.DNS) err = runSSH(sshName, command) if err != nil { return breverrors.WrapAndTrace(err) diff --git a/pkg/cmd/open/open.go b/pkg/cmd/open/open.go index 62399e9e7..f29a402fc 100644 --- a/pkg/cmd/open/open.go +++ b/pkg/cmd/open/open.go @@ -25,7 +25,6 @@ import ( "github.com/brevdev/brev-cli/pkg/store" "github.com/brevdev/brev-cli/pkg/terminal" uutil "github.com/brevdev/brev-cli/pkg/util" - "github.com/brevdev/brev-cli/pkg/writeconnectionevent" "github.com/briandowns/spinner" "github.com/hashicorp/go-multierror" "github.com/samber/mo" @@ -341,10 +340,6 @@ func runOpenCommand(t *terminal.Terminal, tstore OpenStore, wsIDOrName string, s if err != nil { return breverrors.WrapAndTrace(err) } - // we don't care about the error here but should log with sentry - // legacy environments wont support this and cause errrors, - // but we don't want to block the user from using vscode - _ = writeconnectionevent.WriteWCEOnEnv(tstore, string(localIdentifier)) err = openEditorWithSSH(t, string(localIdentifier), projPath, tstore, setupDoneString, editorType) if err != nil { if strings.Contains(err.Error(), `"code": executable file not found in $PATH`) { diff --git a/pkg/cmd/shell/shell.go b/pkg/cmd/shell/shell.go index 47ee39254..85ba0a428 100644 --- a/pkg/cmd/shell/shell.go +++ b/pkg/cmd/shell/shell.go @@ -20,8 +20,6 @@ import ( breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/store" "github.com/brevdev/brev-cli/pkg/terminal" - "github.com/brevdev/brev-cli/pkg/writeconnectionevent" - "github.com/spf13/cobra" ) @@ -100,10 +98,26 @@ func runShellCommand(t *terminal.Terminal, sstore ShellStore, workspaceNameOrID return breverrors.WrapAndTrace(err) } } - err = util.PollUntil(s, workspace.ID, "RUNNING", sstore, " waiting for instance to be ready...", pollTimeout) + if workspace.Status != "RUNNING" { + err = util.PollUntil(s, workspace.ID, "RUNNING", sstore, " waiting for instance to be ready...", pollTimeout) + } if err != nil { return breverrors.WrapAndTrace(err) } + + localIdentifier := workspace.GetLocalIdentifier() + if host { + localIdentifier = workspace.GetHostIdentifier() + } + sshName := string(localIdentifier) + + err = runSSHWithOptions(sshName, host, false) + if err == nil { + trackShellAnalytics(sstore, workspace) + return nil + } + _, _ = fmt.Fprintln(os.Stderr, "\nConnection failed, refreshing SSH config and retrying...") + refreshRes := refresh.RunRefreshAsync(sstore) workspace, err = util.GetUserWorkspaceByNameOrIDErr(sstore, workspaceNameOrID) @@ -114,13 +128,6 @@ func runShellCommand(t *terminal.Terminal, sstore ShellStore, workspaceNameOrID return breverrors.New("Instance is not running") } - localIdentifier := workspace.GetLocalIdentifier() - if host { - localIdentifier = workspace.GetHostIdentifier() - } - - sshName := string(localIdentifier) - err = refreshRes.Await() if err != nil { return breverrors.WrapAndTrace(err) @@ -129,15 +136,16 @@ func runShellCommand(t *terminal.Terminal, sstore ShellStore, workspaceNameOrID if err != nil { return breverrors.WrapAndTrace(err) } - // we don't care about the error here but should log with sentry - // legacy environments wont support this and cause errrors, - // but we don't want to block the user from using the shell - _ = writeconnectionevent.WriteWCEOnEnv(sstore, workspace.DNS) err = runSSH(sshName, host) if err != nil { return breverrors.WrapAndTrace(err) } - // Call analytics for shell + trackShellAnalytics(sstore, workspace) + + return nil +} + +func trackShellAnalytics(sstore ShellStore, workspace *entity.Workspace) { userID := "" user, err := sstore.GetCurrentUser() if err != nil { @@ -146,15 +154,13 @@ func runShellCommand(t *terminal.Terminal, sstore ShellStore, workspaceNameOrID userID = user.ID } data := analytics.EventData{ - EventName: "Brev Open", + EventName: "Brev Shell", UserID: userID, Properties: map[string]string{ "instanceId": workspace.ID, }, } _ = analytics.TrackEvent(data) - - return nil } func shellIntoExternalNode(t *terminal.Terminal, sstore ShellStore, node *nodev1.ExternalNode) error { @@ -180,7 +186,7 @@ func shellIntoExternalNode(t *terminal.Terminal, sstore ShellStore, node *nodev1 } func runSSHWithPort(target string, port int32, identityFile string) error { - sshAgentEval := "eval $(ssh-agent -s)" + sshAgentEval := `if [ -z "$SSH_AUTH_SOCK" ]; then eval $(ssh-agent -s) > /dev/null; fi` cmd := fmt.Sprintf("%s && ssh -i %q -o StrictHostKeyChecking=no -p %d %s", sshAgentEval, identityFile, port, target) sshCmd := exec.Command("bash", "-c", cmd) //nolint:gosec //cmd is constructed from API data @@ -201,13 +207,17 @@ func runSSHWithPort(target string, port int32, identityFile string) error { } func runSSH(sshAlias string, host bool) error { - sshAgentEval := "eval $(ssh-agent -s)" + return runSSHWithOptions(sshAlias, host, true) +} + +func runSSHWithOptions(sshAlias string, host bool, printFailureAdvice bool) error { + sshAgentEval := `if [ -z "$SSH_AUTH_SOCK" ]; then eval $(ssh-agent -s) > /dev/null; fi` var cmd string if host { - cmd = fmt.Sprintf("%s && ssh %s", sshAgentEval, sshAlias) + cmd = fmt.Sprintf("%s && ssh -o ConnectTimeout=5 %s", sshAgentEval, sshAlias) } else { // SSH into VM and respect container WORKDIR if containerized, otherwise use default directory - cmd = fmt.Sprintf("%s && ssh -t %s 'DIR=$(readlink -f /proc/1/cwd 2>/dev/null || pwd); cd \"$DIR\" || echo \"Warning: Could not access container directory\" >&2; exec -l ${SHELL:-/bin/sh}'", sshAgentEval, sshAlias) + cmd = fmt.Sprintf("%s && ssh -t -o ConnectTimeout=5 %s 'DIR=$(readlink -f /proc/1/cwd 2>/dev/null || pwd); cd \"$DIR\" || echo \"Warning: Could not access container directory\" >&2; exec -l ${SHELL:-/bin/sh}'", sshAgentEval, sshAlias) } var stderrBuf bytes.Buffer @@ -223,6 +233,9 @@ func runSSH(sshAlias string, host bool) error { err = sshCmd.Run() if err != nil { + if !printFailureAdvice { + return breverrors.WrapAndTrace(err) + } stderrStr := stderrBuf.String() if strings.Contains(stderrStr, "unix_listener") || strings.Contains(stderrStr, "path too long") { fmt.Fprintf(os.Stderr, "\nbrev shell failed: SSH ControlPath socket path is too long for this system.\n") diff --git a/pkg/cmd/util/ssh.go b/pkg/cmd/util/ssh.go index 65ba85598..2e6a77824 100644 --- a/pkg/cmd/util/ssh.go +++ b/pkg/cmd/util/ssh.go @@ -1,6 +1,7 @@ package util import ( + "context" "errors" "fmt" "os/exec" @@ -14,6 +15,14 @@ import ( "github.com/briandowns/spinner" ) +var ( + sshAvailabilityConnectTimeoutSeconds = 3 + sshAvailabilityAttemptTimeout = 5 * time.Second + sshAvailabilityWaitDelay = time.Second + sshAvailabilityRetrySleep = time.Second + sshAvailabilityMaxAttempts = 20 +) + // WorkspacePollingStore is the minimal interface needed for polling workspace state type WorkspacePollingStore interface { GetWorkspace(workspaceID string) (*entity.Workspace, error) @@ -56,27 +65,45 @@ func WaitForSSHToBeAvailable(sshAlias string, s *spinner.Spinner) error { s.Suffix = " waiting for SSH connection to be available" s.Start() for { - cmd := exec.Command("ssh", "-o", "ConnectTimeout=10", sshAlias, "echo", " ") + attempt := counter + 1 + ctx, cancel := context.WithTimeout(context.Background(), sshAvailabilityAttemptTimeout) + cmd := exec.CommandContext(ctx, "ssh", + "-T", + "-o", fmt.Sprintf("ConnectTimeout=%d", sshAvailabilityConnectTimeoutSeconds), + "-o", "ConnectionAttempts=1", + "-o", "BatchMode=yes", + "-o", "NumberOfPasswordPrompts=0", + "-o", "RequestTTY=no", + "-o", "LogLevel=ERROR", + sshAlias, + "true", + ) + cmd.WaitDelay = sshAvailabilityWaitDelay out, err := cmd.CombinedOutput() + timedOut := ctx.Err() == context.DeadlineExceeded + cancel() if err == nil { s.Stop() return nil } - outputStr := string(out) - lines := strings.Split(outputStr, "\n") - stdErr := outputStr - if len(lines) > 1 { - stdErr = lines[1] + stdErr := strings.TrimSpace(string(out)) + if timedOut { + stdErr = fmt.Sprintf("SSH attempt %d timed out after %s", attempt, sshAvailabilityAttemptTimeout) + } else if stdErr == "" { + stdErr = err.Error() } - if counter == 40 || !store.SatisfactorySSHErrMessage(stdErr) { + if counter == sshAvailabilityMaxAttempts || (!timedOut && !store.SatisfactorySSHErrMessage(stdErr)) { s.Stop() return breverrors.WrapAndTrace(errors.New("\n" + stdErr)) } + s.Stop() + _, _ = fmt.Fprintf(s.Writer, "still waiting for SSH connection (attempt %d failed; retrying)\n", attempt) counter++ - time.Sleep(1 * time.Second) + time.Sleep(sshAvailabilityRetrySleep) + s.Start() } } diff --git a/pkg/cmd/util/ssh_test.go b/pkg/cmd/util/ssh_test.go new file mode 100644 index 000000000..3646e34b2 --- /dev/null +++ b/pkg/cmd/util/ssh_test.go @@ -0,0 +1,51 @@ +package util + +import ( + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/briandowns/spinner" +) + +func TestWaitForSSHToBeAvailableTimesOutStuckSSHAttempt(t *testing.T) { + dir := t.TempDir() + fakeSSH := filepath.Join(dir, "ssh") + err := os.WriteFile(fakeSSH, []byte("#!/bin/sh\nexec sleep 5\n"), 0o755) + if err != nil { + t.Fatalf("write fake ssh: %v", err) + } + t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH")) + + originalAttemptTimeout := sshAvailabilityAttemptTimeout + originalWaitDelay := sshAvailabilityWaitDelay + originalRetrySleep := sshAvailabilityRetrySleep + originalMaxAttempts := sshAvailabilityMaxAttempts + sshAvailabilityAttemptTimeout = 50 * time.Millisecond + sshAvailabilityWaitDelay = 50 * time.Millisecond + sshAvailabilityRetrySleep = 0 + sshAvailabilityMaxAttempts = 0 + t.Cleanup(func() { + sshAvailabilityAttemptTimeout = originalAttemptTimeout + sshAvailabilityWaitDelay = originalWaitDelay + sshAvailabilityRetrySleep = originalRetrySleep + sshAvailabilityMaxAttempts = originalMaxAttempts + }) + + s := spinner.New(spinner.CharSets[9], 100*time.Millisecond) + start := time.Now() + err = WaitForSSHToBeAvailable("slow-host", s) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected stuck ssh attempt to fail") + } + if elapsed > 500*time.Millisecond { + t.Fatalf("expected stuck ssh attempt to be killed quickly, took %v", elapsed) + } + if !strings.Contains(err.Error(), "timed out after") { + t.Fatalf("expected timeout error, got %v", err) + } +} diff --git a/pkg/cmd/writeconnectionevent/writeconnectionevent.go b/pkg/cmd/writeconnectionevent/writeconnectionevent.go deleted file mode 100644 index e7440ccf7..000000000 --- a/pkg/cmd/writeconnectionevent/writeconnectionevent.go +++ /dev/null @@ -1,44 +0,0 @@ -package writeconnectionevent - -import ( - "github.com/spf13/cobra" - - breverrors "github.com/brevdev/brev-cli/pkg/errors" - "github.com/brevdev/brev-cli/pkg/terminal" -) - -var ( - short = "TODO" - long = "TODO" - example = "TODO" -) - -type writeConnectionEventStore interface { - WriteConnectionEvent() error -} - -func NewCmdwriteConnectionEvent(t *terminal.Terminal, store writeConnectionEventStore) *cobra.Command { - cmd := &cobra.Command{ - Use: "write-connection-event", - DisableFlagsInUseLine: true, - Short: short, - Long: long, - Example: example, - RunE: func(cmd *cobra.Command, args []string) error { - err := RunWriteConnectionEvent(t, args, store) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil - }, - } - return cmd -} - -func RunWriteConnectionEvent(_ *terminal.Terminal, _ []string, store writeConnectionEventStore) error { - err := store.WriteConnectionEvent() - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} diff --git a/pkg/cmd/writeconnectionevent/writeconnectionevent_test.go b/pkg/cmd/writeconnectionevent/writeconnectionevent_test.go deleted file mode 100644 index 4d2784e3e..000000000 --- a/pkg/cmd/writeconnectionevent/writeconnectionevent_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package writeconnectionevent - -import ( - "testing" - - "github.com/brevdev/brev-cli/pkg/store" - "github.com/brevdev/brev-cli/pkg/terminal" - "github.com/spf13/afero" -) - -func TestRunWriteConnectionEvent(t *testing.T) { - fs := afero.NewMemMapFs() - type args struct { - in0 *terminal.Terminal - in1 []string - store writeConnectionEventStore - } - tests := []struct { - name string - args args - wantErr bool - }{ - // TODO: Add test cases. - { - name: "write connection event", - args: args{ - nil, - []string{}, - store.NewBasicStore().WithFileSystem(fs), - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := RunWriteConnectionEvent(tt.args.in0, tt.args.in1, tt.args.store); (err != nil) != tt.wantErr { - t.Errorf("RunWriteConnectionEvent() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} diff --git a/pkg/writeconnectionevent/writeconnectionevent.go b/pkg/writeconnectionevent/writeconnectionevent.go deleted file mode 100644 index 568420072..000000000 --- a/pkg/writeconnectionevent/writeconnectionevent.go +++ /dev/null @@ -1,70 +0,0 @@ -package writeconnectionevent - -import ( - "net" - "time" - - "github.com/brevdev/brev-cli/pkg/entity" - breverrors "github.com/brevdev/brev-cli/pkg/errors" - "golang.org/x/crypto/ssh" -) - -func runCMDonEnv(privateKey, host, cmd string) error { - signer, err := ssh.ParsePrivateKey([]byte(privateKey)) - if err != nil { - return breverrors.Wrap(err, "unable to parse private key") - } - config := &ssh.ClientConfig{ - User: "ubuntu", - Auth: []ssh.AuthMethod{ - // Use the PublicKeys method for remote authentication. - ssh.PublicKeys(signer), - }, - HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { - // use OpenSSH's known_hosts file if you care about host validation - return nil - }, - Timeout: 5 * time.Second, - } - - // Connect to the remote server and perform the SSH handshake. - client, err := ssh.Dial("tcp", host+":22", config) - if err != nil { - return breverrors.Wrap(err, "unable to connect") - } - session, err := client.NewSession() - if err != nil { - return breverrors.Wrap(err, "unable to create session: %v") - } - defer session.Close() //nolint:errcheck // defer - out, err := session.CombinedOutput(cmd) - if err != nil { - return breverrors.Wrap(err, "unable to run: %v \n %v"+cmd+string(out)) - } - err = client.Close() - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} - -type wce interface { - GetCurrentUserKeys() (*entity.UserKeys, error) -} - -func WriteWCEOnEnv(store wce, name string) error { - keys, err := store.GetCurrentUserKeys() - if err != nil { - return breverrors.WrapAndTrace(err) - } - err = runCMDonEnv( - keys.PrivateKey, - name, - "sudo brev write-connection-event", - ) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - return nil -}