Skip to content

Commit 06dd109

Browse files
committed
feat: add missing impls
1 parent 8b21264 commit 06dd109

13 files changed

+879
-48
lines changed

Diff for: src/airflow/config.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,34 @@
1-
from typing import Callable
1+
from typing import Any, Dict, List, Optional, Union, Callable
2+
3+
import mcp.types as types
4+
from airflow_client.client.api.config_api import ConfigApi
5+
6+
from src.airflow.airflow_client import api_client
7+
8+
config_api = ConfigApi(api_client)
29

310

411
def get_all_functions() -> list[tuple[Callable, str, str]]:
5-
raise NotImplementedError("Not implemented")
12+
return [
13+
(get_config, "get_config", "Get current configuration"),
14+
(get_value, "get_value", "Get a specific option from configuration"),
15+
]
16+
17+
18+
async def get_config(
19+
section: Optional[str] = None,
20+
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
21+
# Build parameters dictionary
22+
kwargs: Dict[str, Any] = {}
23+
if section is not None:
24+
kwargs["section"] = section
25+
26+
response = config_api.get_config(**kwargs)
27+
return [types.TextContent(type="text", text=str(response.to_dict()))]
28+
29+
30+
async def get_value(
31+
section: str, option: str
32+
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
33+
response = config_api.get_value(section=section, option=option)
34+
return [types.TextContent(type="text", text=str(response.to_dict()))]

Diff for: src/airflow/connection.py

+30
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def get_all_functions() -> list[tuple[Callable, str, str]]:
1616
(get_connection, "get_connection", "Get a connection by ID"),
1717
(update_connection, "update_connection", "Update a connection by ID"),
1818
(delete_connection, "delete_connection", "Delete a connection by ID"),
19+
(test_connection, "test_connection", "Test a connection"),
1920
]
2021

2122

@@ -114,3 +115,32 @@ async def delete_connection(
114115
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
115116
response = connection_api.delete_connection(connection_id=conn_id)
116117
return [types.TextContent(type="text", text=str(response.to_dict()))]
118+
119+
120+
async def test_connection(
121+
conn_type: str,
122+
host: Optional[str] = None,
123+
port: Optional[int] = None,
124+
login: Optional[str] = None,
125+
password: Optional[str] = None,
126+
schema: Optional[str] = None,
127+
extra: Optional[str] = None,
128+
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
129+
connection_request = {
130+
"conn_type": conn_type,
131+
}
132+
if host is not None:
133+
connection_request["host"] = host
134+
if port is not None:
135+
connection_request["port"] = port
136+
if login is not None:
137+
connection_request["login"] = login
138+
if password is not None:
139+
connection_request["password"] = password
140+
if schema is not None:
141+
connection_request["schema"] = schema
142+
if extra is not None:
143+
connection_request["extra"] = extra
144+
145+
response = connection_api.test_connection(connection_request=connection_request)
146+
return [types.TextContent(type="text", text=str(response.to_dict()))]

Diff for: src/airflow/dag.py

+157-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import mcp.types as types
44
from airflow_client.client.api.dag_api import DAGApi
5-
5+
from airflow_client.client.model.clear_task_instances import ClearTaskInstances
6+
from airflow_client.client.model.update_task_instances_state import UpdateTaskInstancesState
7+
from airflow_client.client.model.dag import DAG
68
from src.airflow.airflow_client import api_client
79
from src.envs import AIRFLOW_HOST
810

@@ -11,18 +13,29 @@
1113

1214
def get_all_functions() -> list[tuple[Callable, str, str]]:
1315
return [
14-
(fetch_dags, "fetch_dags", "Fetch all DAGs"),
16+
(get_dags, "fetch_dags", "Fetch all DAGs"),
1517
(get_dag, "get_dag", "Get a DAG by ID"),
18+
(get_dag_details, "get_dag_details", "Get a simplified representation of DAG"),
19+
(get_dag_source, "get_dag_source", "Get a source code"),
1620
(pause_dag, "pause_dag", "Pause a DAG by ID"),
1721
(unpause_dag, "unpause_dag", "Unpause a DAG by ID"),
22+
(get_dag_tasks, "get_dag_tasks", "Get tasks for DAG"),
23+
(get_task, "get_task", "Get a task by ID"),
24+
(get_tasks, "get_tasks", "Get tasks for DAG"),
25+
(patch_dag, "patch_dag", "Update a DAG"),
26+
(patch_dags, "patch_dags", "Update multiple DAGs"),
27+
(delete_dag, "delete_dag", "Delete a DAG"),
28+
(clear_task_instances, "clear_task_instances", "Clear a set of task instances"),
29+
(set_task_instances_state, "set_task_instances_state", "Set a state of task instances"),
30+
(reparse_dag_file, "reparse_dag_file", "Request re-parsing of a DAG file"),
1831
]
1932

2033

2134
def get_dag_url(dag_id: str) -> str:
2235
return f"{AIRFLOW_HOST}/dags/{dag_id}/grid"
2336

2437

25-
async def fetch_dags(
38+
async def get_dags(
2639
limit: Optional[int] = None,
2740
offset: Optional[int] = None,
2841
order_by: Optional[str] = None,
@@ -73,6 +86,21 @@ async def get_dag(dag_id: str) -> List[Union[types.TextContent, types.ImageConte
7386
return [types.TextContent(type="text", text=str(response_dict))]
7487

7588

89+
async def get_dag_details(dag_id: str, fields: Optional[List[str]] = None) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
90+
# Build parameters dictionary
91+
kwargs: Dict[str, Any] = {}
92+
if fields is not None:
93+
kwargs["fields"] = fields
94+
95+
response = dag_api.get_dag_details(dag_id=dag_id, **kwargs)
96+
return [types.TextContent(type="text", text=str(response.to_dict()))]
97+
98+
99+
async def get_dag_source(file_token: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
100+
response = dag_api.get_dag_source(file_token=file_token)
101+
return [types.TextContent(type="text", text=str(response.to_dict()))]
102+
103+
76104
async def pause_dag(dag_id: str) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
77105
response = dag_api.patch_dag(dag_id=dag_id, dag_update_request={"is_paused": True})
78106
return [types.TextContent(type="text", text=str(response.to_dict()))]
@@ -88,16 +116,45 @@ async def get_dag_tasks(dag_id: str) -> List[Union[types.TextContent, types.Imag
88116
return [types.TextContent(type="text", text=str(response.to_dict()))]
89117

90118

91-
async def update_dag(
119+
async def patch_dag(
92120
dag_id: str, is_paused: Optional[bool] = None, tags: Optional[List[str]] = None
93121
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
94122
update_request = {}
123+
update_mask = []
124+
95125
if is_paused is not None:
96126
update_request["is_paused"] = is_paused
127+
update_mask.append("is_paused")
97128
if tags is not None:
98129
update_request["tags"] = tags
130+
update_mask.append("tags")
131+
132+
dag = DAG(**update_request)
133+
134+
response = dag_api.patch_dag(dag_id=dag_id, dag=dag, update_mask=update_mask)
135+
return [types.TextContent(type="text", text=str(response.to_dict()))]
136+
137+
138+
async def patch_dags(
139+
dag_id_pattern: Optional[str] = None, is_paused: Optional[bool] = None, tags: Optional[List[str]] = None,
140+
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
141+
update_request = {}
142+
update_mask = []
99143

100-
response = dag_api.patch_dag(dag_id=dag_id, dag_update_request=update_request)
144+
if is_paused is not None:
145+
update_request["is_paused"] = is_paused
146+
update_mask.append("is_paused")
147+
if tags is not None:
148+
update_request["tags"] = tags
149+
update_mask.append("tags")
150+
151+
dag = DAG(**update_request)
152+
153+
kwargs = {}
154+
if dag_id_pattern is not None:
155+
kwargs["dag_id_pattern"] = dag_id_pattern
156+
157+
response = dag_api.patch_dags(dag_id_pattern=dag_id_pattern, dag=dag, update_mask=update_mask, **kwargs)
101158
return [types.TextContent(type="text", text=str(response.to_dict()))]
102159

103160

@@ -111,3 +168,98 @@ async def get_task(
111168
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
112169
response = dag_api.get_task(dag_id=dag_id, task_id=task_id)
113170
return [types.TextContent(type="text", text=str(response.to_dict()))]
171+
172+
173+
async def get_tasks(dag_id: str, order_by: Optional[str] = None) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
174+
kwargs = {}
175+
if order_by is not None:
176+
kwargs["order_by"] = order_by
177+
178+
response = dag_api.get_tasks(dag_id=dag_id, **kwargs)
179+
return [types.TextContent(type="text", text=str(response.to_dict()))]
180+
181+
182+
async def clear_task_instances(
183+
dag_id: str,
184+
task_ids: Optional[List[str]] = None,
185+
start_date: Optional[str] = None,
186+
end_date: Optional[str] = None,
187+
include_subdags: Optional[bool] = None,
188+
include_parentdag: Optional[bool] = None,
189+
include_upstream: Optional[bool] = None,
190+
include_downstream: Optional[bool] = None,
191+
include_future: Optional[bool] = None,
192+
include_past: Optional[bool] = None,
193+
dry_run: Optional[bool] = None,
194+
reset_dag_runs: Optional[bool] = None,
195+
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
196+
clear_request = {}
197+
if task_ids is not None:
198+
clear_request["task_ids"] = task_ids
199+
if start_date is not None:
200+
clear_request["start_date"] = start_date
201+
if end_date is not None:
202+
clear_request["end_date"] = end_date
203+
if include_subdags is not None:
204+
clear_request["include_subdags"] = include_subdags
205+
if include_parentdag is not None:
206+
clear_request["include_parentdag"] = include_parentdag
207+
if include_upstream is not None:
208+
clear_request["include_upstream"] = include_upstream
209+
if include_downstream is not None:
210+
clear_request["include_downstream"] = include_downstream
211+
if include_future is not None:
212+
clear_request["include_future"] = include_future
213+
if include_past is not None:
214+
clear_request["include_past"] = include_past
215+
if dry_run is not None:
216+
clear_request["dry_run"] = dry_run
217+
if reset_dag_runs is not None:
218+
clear_request["reset_dag_runs"] = reset_dag_runs
219+
220+
clear_task_instances = ClearTaskInstances(**clear_request)
221+
222+
response = dag_api.post_clear_task_instances(dag_id=dag_id, clear_task_instances=clear_task_instances)
223+
return [types.TextContent(type="text", text=str(response.to_dict()))]
224+
225+
226+
async def set_task_instances_state(
227+
dag_id: str,
228+
state: str,
229+
task_ids: Optional[List[str]] = None,
230+
execution_date: Optional[str] = None,
231+
include_upstream: Optional[bool] = None,
232+
include_downstream: Optional[bool] = None,
233+
include_future: Optional[bool] = None,
234+
include_past: Optional[bool] = None,
235+
dry_run: Optional[bool] = None,
236+
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
237+
state_request = {
238+
"state": state
239+
}
240+
if task_ids is not None:
241+
state_request["task_ids"] = task_ids
242+
if execution_date is not None:
243+
state_request["execution_date"] = execution_date
244+
if include_upstream is not None:
245+
state_request["include_upstream"] = include_upstream
246+
if include_downstream is not None:
247+
state_request["include_downstream"] = include_downstream
248+
if include_future is not None:
249+
state_request["include_future"] = include_future
250+
if include_past is not None:
251+
state_request["include_past"] = include_past
252+
if dry_run is not None:
253+
state_request["dry_run"] = dry_run
254+
255+
update_task_instances_state = UpdateTaskInstancesState(**state_request)
256+
257+
response = dag_api.post_set_task_instances_state(dag_id=dag_id, update_task_instances_state=update_task_instances_state)
258+
return [types.TextContent(type="text", text=str(response.to_dict()))]
259+
260+
261+
async def reparse_dag_file(
262+
file_token: str
263+
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
264+
response = dag_api.reparse_dag_file(file_token=file_token)
265+
return [types.TextContent(type="text", text=str(response.to_dict()))]

0 commit comments

Comments
 (0)