@@ -1,11 +1,14 @@ package toolsdk import ( "bytes" "context" "errors" "fmt" "io" "strings" "sync" "time" gossh "golang.org/x/crypto/ssh" "golang.org/x/xerrors" Expand All @@ -20,6 +23,7 @@ import ( type WorkspaceBashArgs struct { Workspace string `json:"workspace"` Command string `json:"command"` TimeoutMs int `json:"timeout_ms,omitempty"` } type WorkspaceBashResult struct { Expand All @@ -43,9 +47,12 @@ The workspace parameter supports various formats: - workspace.agent (specific agent) - owner/workspace.agent The timeout_ms parameter specifies the command timeout in milliseconds (defaults to 60000ms, maximum of 300000ms). If the command times out, all output captured up to that point is returned with a cancellation message. Examples: - workspace: "my-workspace", command: "ls -la" - workspace: "john/dev-env", command: "git status" - workspace: "john/dev-env", command: "git status", timeout_ms: 30000 - workspace: "my-workspace.main", command: "docker ps"`, Schema: aisdk.Schema{ Properties: map[string]any{ Expand All @@ -57,18 +64,27 @@ Examples: "type": "string", "description": "The bash command to execute in the workspace.", }, "timeout_ms": map[string]any{ "type": "integer", "description": "Command timeout in milliseconds. Defaults to 60000ms (60 seconds) if not specified.", "default": 60000, "minimum": 1, }, }, Required: []string{"workspace", "command"}, }, }, Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (WorkspaceBashResult, error) { Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (res WorkspaceBashResult, err error) { if args.Workspace == "" { return WorkspaceBashResult{}, xerrors.New("workspace name cannot be empty") } if args.Command == "" { return WorkspaceBashResult{}, xerrors.New("command cannot be empty") } ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, xerrors.New("MCP handler timeout after 5 min")) defer cancel() // Normalize workspace input to handle various formats workspaceName := NormalizeWorkspaceInput(args.Workspace) Expand Down Expand Up @@ -119,23 +135,42 @@ Examples: } defer session.Close() // Execute command and capture output output, err := session.CombinedOutput(args.Command) // Set default timeout if not specified (60 seconds) timeoutMs := args.TimeoutMs if timeoutMs <= 0 { timeoutMs = 60000 } // Create context with timeout ctx, cancel = context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond) defer cancel() // Execute command with timeout handling output, err := executeCommandWithTimeout(ctx, session, args.Command) outputStr := strings.TrimSpace(string(output)) // Handle command execution results if err != nil { // Check ifit's an SSH exit error to get theexit code var exitErr *gossh.ExitError if errors.As(err, &exitErr) { // Check if thecommand timed out if errors.Is(context.Cause(ctx), context.DeadlineExceeded) { outputStr += "\nCommand canceled due to timeout" return WorkspaceBashResult{ Output: outputStr, ExitCode:exitErr.ExitStatus() , ExitCode:124 , }, nil } // For other errors, return exit code 1 // Extract exit code from SSH error if available exitCode := 1 var exitErr *gossh.ExitError if errors.As(err, &exitErr) { exitCode = exitErr.ExitStatus() } // For other errors, use standard timeout or generic error code return WorkspaceBashResult{ Output: outputStr, ExitCode:1 , ExitCode:exitCode , }, nil } Expand Down Expand Up @@ -292,3 +327,99 @@ func NormalizeWorkspaceInput(input string) string { return normalized } // executeCommandWithTimeout executes a command with timeout support func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, command string) ([]byte, error) { // Set up pipes to capture output stdoutPipe, err := session.StdoutPipe() if err != nil { return nil, xerrors.Errorf("failed to create stdout pipe: %w", err) } stderrPipe, err := session.StderrPipe() if err != nil { return nil, xerrors.Errorf("failed to create stderr pipe: %w", err) } // Start the command if err := session.Start(command); err != nil { return nil, xerrors.Errorf("failed to start command: %w", err) } // Create a thread-safe buffer for combined output var output bytes.Buffer var mu sync.Mutex safeWriter := &syncWriter{w: &output, mu: &mu} // Use io.MultiWriter to combine stdout and stderr multiWriter := io.MultiWriter(safeWriter) // Channel to signal when command completes done := make(chan error, 1) // Start goroutine to copy output and wait for completion go func() { // Copy stdout and stderr concurrently var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() _, _ = io.Copy(multiWriter, stdoutPipe) }() go func() { defer wg.Done() _, _ = io.Copy(multiWriter, stderrPipe) }() // Wait for all output to be copied wg.Wait() // Wait for the command to complete done <- session.Wait() }() // Wait for either completion or context cancellation select { case err := <-done: // Command completed normally return safeWriter.Bytes(), err case <-ctx.Done(): // Context was canceled (timeout or other cancellation) // Close the session to stop the command _ = session.Close() // Give a brief moment to collect any remaining output timer := time.NewTimer(50 * time.Millisecond) defer timer.Stop() select { case <-timer.C: // Timer expired, return what we have case err := <-done: // Command finished during grace period return safeWriter.Bytes(), err } return safeWriter.Bytes(), context.Cause(ctx) } } // syncWriter is a thread-safe writer type syncWriter struct { w *bytes.Buffer mu *sync.Mutex } func (sw *syncWriter) Write(p []byte) (n int, err error) { sw.mu.Lock() defer sw.mu.Unlock() return sw.w.Write(p) } func (sw *syncWriter) Bytes() []byte { sw.mu.Lock() defer sw.mu.Unlock() return sw.w.Bytes() }