Skip to content

Commit aecfa58

Browse files
committed
Merge branch 'develop'
2 parents 9502b46 + 692af6d commit aecfa58

File tree

4 files changed

+140
-14
lines changed

4 files changed

+140
-14
lines changed

src/mcp_shell_server/shell_executor.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import os
3+
import pwd
34
import shlex
45
import time
56
from typing import IO, Any, Dict, List, Optional, Tuple, Union
@@ -14,7 +15,7 @@ def __init__(self):
1415
"""
1516
Initialize the executor.
1617
"""
17-
pass
18+
pass # pragma: no cover
1819

1920
def _get_allowed_commands(self) -> set[str]:
2021
"""Get the set of allowed commands from environment variables"""
@@ -378,6 +379,13 @@ def _preprocess_command(self, command: List[str]) -> List[str]:
378379
preprocessed_command.append(token)
379380
return preprocessed_command
380381

382+
def _get_default_shell(self) -> str:
383+
"""Get the login shell of the current user"""
384+
try:
385+
return pwd.getpwuid(os.getuid()).pw_shell
386+
except (ImportError, KeyError):
387+
return os.environ.get("SHELL", "/bin/sh")
388+
381389
async def execute(
382390
self,
383391
command: List[str],
@@ -517,14 +525,17 @@ async def execute(
517525
except IOError as e:
518526
raise ValueError(f"Failed to open output file: {e}") from e
519527

520-
# Execute the command
528+
# Execute the command with interactive shell
529+
shell = self._get_default_shell()
521530
shell_cmd = self._create_shell_command(cmd)
531+
shell_cmd = f"{shell} -i -c {shlex.quote(shell_cmd)}"
532+
522533
process = await asyncio.create_subprocess_shell(
523534
shell_cmd,
524535
stdin=asyncio.subprocess.PIPE if stdin else None,
525536
stdout=stdout_handle,
526537
stderr=asyncio.subprocess.PIPE,
527-
env={"PATH": os.environ.get("PATH", "")},
538+
env=os.environ, # Use all environment variables
528539
cwd=directory,
529540
)
530541

@@ -642,12 +653,17 @@ async def _execute_pipeline(
642653
for i, cmd in enumerate(parsed_commands):
643654
shell_cmd = self._create_shell_command(cmd)
644655

656+
# Get default shell for the first command and set interactive mode
657+
if i == 0:
658+
shell = self._get_default_shell()
659+
shell_cmd = f"{shell} -i -c {shlex.quote(shell_cmd)}"
660+
645661
process = await asyncio.create_subprocess_shell(
646662
shell_cmd,
647663
stdin=asyncio.subprocess.PIPE if prev_stdout is not None else None,
648664
stdout=asyncio.subprocess.PIPE,
649665
stderr=asyncio.subprocess.PIPE,
650-
env={"PATH": os.environ.get("PATH", "")},
666+
env=os.environ, # Use all environment variables
651667
cwd=directory,
652668
)
653669

tests/test_server.py

+23
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,26 @@ def stdio_server_impl():
322322
await main()
323323

324324
assert str(exc.value) == "Test error"
325+
326+
327+
@pytest.mark.asyncio
328+
async def test_shell_startup(monkeypatch, temp_test_dir):
329+
"""Test shell startup and environment"""
330+
monkeypatch.setenv("ALLOW_COMMANDS", "ps")
331+
result = await call_tool(
332+
"shell_execute",
333+
{"command": ["ps", "-p", "$$", "-o", "command="], "directory": temp_test_dir},
334+
)
335+
assert len(result) == 1
336+
assert result[0].type == "text"
337+
338+
339+
@pytest.mark.asyncio
340+
async def test_environment_variables(monkeypatch, temp_test_dir):
341+
"""Test to check environment variables during test execution"""
342+
monkeypatch.setenv("ALLOW_COMMANDS", "env")
343+
result = await call_tool(
344+
"shell_execute",
345+
{"command": ["env"], "directory": temp_test_dir},
346+
)
347+
assert len(result) == 1

tests/test_shell_executor.py

+42
Original file line numberDiff line numberDiff line change
@@ -600,3 +600,45 @@ def test_preprocess_command_pipeline(executor):
600600
"grep",
601601
"pattern",
602602
]
603+
604+
605+
@pytest.mark.asyncio
606+
async def test_command_cleanup_on_error(executor, temp_test_dir, monkeypatch):
607+
"""Test cleanup of processes when error occurs"""
608+
clear_env(monkeypatch)
609+
monkeypatch.setenv("ALLOW_COMMANDS", "sleep")
610+
611+
async def execute_with_keyboard_interrupt():
612+
# Simulate keyboard interrupt during execution
613+
result = await executor.execute(["sleep", "5"], temp_test_dir, timeout=1)
614+
return result
615+
616+
result = await execute_with_keyboard_interrupt()
617+
assert result["error"] == "Command timed out after 1 seconds"
618+
assert result["status"] == -1
619+
assert "execution_time" in result
620+
621+
622+
@pytest.mark.asyncio
623+
async def test_output_redirection_with_append(executor, temp_test_dir, monkeypatch):
624+
"""Test output redirection with append mode"""
625+
clear_env(monkeypatch)
626+
monkeypatch.setenv("ALLOW_COMMANDS", "echo,cat")
627+
output_file = os.path.join(temp_test_dir, "test.txt")
628+
629+
# Write initial content
630+
await executor.execute(["echo", "hello", ">", output_file], directory=temp_test_dir)
631+
632+
# Append content
633+
result = await executor.execute(
634+
["echo", "world", ">>", output_file], directory=temp_test_dir
635+
)
636+
assert result["error"] is None
637+
assert result["status"] == 0
638+
639+
# Verify contents
640+
result = await executor.execute(["cat", output_file], directory=temp_test_dir)
641+
lines = result["stdout"].strip().split("\n")
642+
assert len(lines) == 2
643+
assert lines[0] == "hello"
644+
assert lines[1] == "world"

tests/test_shell_executor_pipeline.py

+55-10
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,34 @@
1+
"""Test pipeline execution and cleanup scenarios."""
2+
3+
import os
4+
import tempfile
5+
16
import pytest
27

38
from mcp_shell_server.shell_executor import ShellExecutor
49

510

11+
def clear_env(monkeypatch):
12+
monkeypatch.delenv("ALLOW_COMMANDS", raising=False)
13+
monkeypatch.delenv("ALLOWED_COMMANDS", raising=False)
14+
15+
16+
@pytest.fixture
17+
def executor():
18+
return ShellExecutor()
19+
20+
21+
@pytest.fixture
22+
def temp_test_dir():
23+
"""Create a temporary directory for testing"""
24+
with tempfile.TemporaryDirectory() as tmpdirname:
25+
# Return the real path to handle macOS /private/tmp symlink
26+
yield os.path.realpath(tmpdirname)
27+
28+
629
@pytest.mark.asyncio
7-
async def test_pipeline_split():
30+
async def test_pipeline_split(executor):
831
"""Test pipeline command splitting functionality"""
9-
executor = ShellExecutor()
10-
1132
# Test basic pipe command
1233
commands = executor._split_pipe_commands(["echo", "hello", "|", "grep", "h"])
1334
assert len(commands) == 2
@@ -35,19 +56,43 @@ async def test_pipeline_split():
3556

3657

3758
@pytest.mark.asyncio
38-
async def test_pipeline_execution_success():
59+
async def test_pipeline_execution_success(executor, temp_test_dir, monkeypatch):
3960
"""Test successful pipeline execution with proper return value"""
40-
executor = ShellExecutor()
41-
import os
42-
43-
os.environ["ALLOWED_COMMANDS"] = "echo,grep"
61+
clear_env(monkeypatch)
62+
monkeypatch.setenv("ALLOWED_COMMANDS", "echo,grep")
4463

4564
result = await executor.execute(
46-
["echo", "hello world", "|", "grep", "world"], directory="/tmp", timeout=5
65+
["echo", "hello world", "|", "grep", "world"],
66+
directory=temp_test_dir,
67+
timeout=5,
4768
)
4869

4970
assert result["error"] is None
5071
assert result["status"] == 0
5172
assert "world" in result["stdout"]
5273
assert "execution_time" in result
53-
assert result["directory"] == "/tmp"
74+
75+
76+
@pytest.mark.asyncio
77+
async def test_pipeline_cleanup_and_timeouts(executor, temp_test_dir, monkeypatch):
78+
"""Test cleanup of processes in pipelines and timeout handling"""
79+
clear_env(monkeypatch)
80+
monkeypatch.setenv("ALLOW_COMMANDS", "cat,tr,head,sleep")
81+
82+
# Test pipeline with early termination
83+
test_file = os.path.join(temp_test_dir, "test.txt")
84+
with open(test_file, "w") as f:
85+
f.write("test\n" * 1000)
86+
87+
result = await executor.execute(
88+
["cat", test_file, "|", "tr", "[:lower:]", "[:upper:]", "|", "head", "-n", "1"],
89+
temp_test_dir,
90+
timeout=2,
91+
)
92+
assert result["status"] == 0
93+
assert result["stdout"].strip() == "TEST"
94+
95+
# Test timeout handling in pipeline
96+
result = await executor.execute(["sleep", "5"], temp_test_dir, timeout=1)
97+
assert result["status"] == -1
98+
assert "timed out" in result["error"].lower() # タイムアウトエラーの確認

0 commit comments

Comments
 (0)