1
+ import os
1
2
import sys
2
3
from contextlib import asynccontextmanager
3
4
9
10
10
11
from mcp_python .types import JSONRPCMessage
11
12
13
+ # Environment variables to inherit by default
14
+ DEFAULT_INHERITED_ENV_VARS = (
15
+ ["APPDATA" , "HOMEDRIVE" , "HOMEPATH" , "LOCALAPPDATA" , "PATH" ,
16
+ "PROCESSOR_ARCHITECTURE" , "SYSTEMDRIVE" , "SYSTEMROOT" , "TEMP" ,
17
+ "USERNAME" , "USERPROFILE" ]
18
+ if sys .platform == "win32"
19
+ else ["HOME" , "LOGNAME" , "PATH" , "SHELL" , "TERM" , "USER" ]
20
+ )
21
+
22
+
23
+ def get_default_environment () -> dict [str , str ]:
24
+ """
25
+ Returns a default environment object including only environment variables deemed
26
+ safe to inherit.
27
+ """
28
+ env : dict [str , str ] = {}
29
+
30
+ for key in DEFAULT_INHERITED_ENV_VARS :
31
+ value = os .environ .get (key )
32
+ if value is None :
33
+ continue
34
+
35
+ if value .startswith ("()" ):
36
+ # Skip functions, which are a security risk
37
+ continue
38
+
39
+ env [key ] = value
40
+
41
+ return env
42
+
12
43
13
44
class StdioServerParameters (BaseModel ):
14
45
command : str
@@ -17,11 +48,11 @@ class StdioServerParameters(BaseModel):
17
48
args : list [str ] = Field (default_factory = list )
18
49
"""Command line arguments to pass to the executable."""
19
50
20
- env : dict [str , str ] = Field ( default_factory = dict )
51
+ env : dict [str , str ] | None = None
21
52
"""
22
53
The environment to use when spawning the process.
23
54
24
- The environment is NOT inherited from the parent process by default .
55
+ If not specified, the result of get_default_environment() will be used .
25
56
"""
26
57
27
58
@@ -41,7 +72,9 @@ async def stdio_client(server: StdioServerParameters):
41
72
write_stream , write_stream_reader = anyio .create_memory_object_stream (0 )
42
73
43
74
process = await anyio .open_process (
44
- [server .command , * server .args ], env = server .env , stderr = sys .stderr
75
+ [server .command , * server .args ],
76
+ env = server .env if server .env is not None else get_default_environment (),
77
+ stderr = sys .stderr
45
78
)
46
79
47
80
async def stdout_reader ():
0 commit comments