Skip to content

Commit efdad0e

Browse files
authored
User-Defined AQL functions (#37)
* Adding user-defined AQL functions * Deterministic test
1 parent fd840a1 commit efdad0e

File tree

4 files changed

+219
-2
lines changed

4 files changed

+219
-2
lines changed

arangoasync/aql.py

+118-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
__all__ = ["AQL", "AQLQueryCache"]
22

33

4-
from typing import Optional
4+
from typing import Optional, cast
55

66
from arangoasync.cursor import Cursor
77
from arangoasync.errno import HTTP_NOT_FOUND
@@ -10,6 +10,9 @@
1010
AQLCacheConfigureError,
1111
AQLCacheEntriesError,
1212
AQLCachePropertiesError,
13+
AQLFunctionCreateError,
14+
AQLFunctionDeleteError,
15+
AQLFunctionListError,
1316
AQLQueryClearError,
1417
AQLQueryExecuteError,
1518
AQLQueryExplainError,
@@ -634,3 +637,117 @@ def response_handler(resp: Response) -> Jsons:
634637
return self.deserializer.loads_many(resp.raw_body)
635638

636639
return await self._executor.execute(request, response_handler)
640+
641+
async def functions(self, namespace: Optional[str] = None) -> Result[Jsons]:
642+
"""List the registered used-defined AQL functions.
643+
644+
Args:
645+
namespace (str | None): Returns all registered AQL user functions from
646+
the specified namespace.
647+
648+
Returns:
649+
list: List of the AQL functions defined in the database.
650+
651+
Raises:
652+
AQLFunctionListError: If retrieval fails.
653+
654+
References:
655+
- `list-the-registered-user-defined-aql-functions <https://docs.arangodb.com/stable/develop/http-api/queries/user-defined-aql-functions/#list-the-registered-user-defined-aql-functions>`__
656+
""" # noqa: E501
657+
params: Json = dict()
658+
if namespace is not None:
659+
params["namespace"] = namespace
660+
request = Request(
661+
method=Method.GET,
662+
endpoint="/_api/aqlfunction",
663+
params=params,
664+
)
665+
666+
def response_handler(resp: Response) -> Jsons:
667+
if not resp.is_success:
668+
raise AQLFunctionListError(resp, request)
669+
result = cast(Jsons, self.deserializer.loads(resp.raw_body).get("result"))
670+
if result is None:
671+
raise AQLFunctionListError(resp, request)
672+
return result
673+
674+
return await self._executor.execute(request, response_handler)
675+
676+
async def create_function(
677+
self,
678+
name: str,
679+
code: str,
680+
is_deterministic: Optional[bool] = None,
681+
) -> Result[Json]:
682+
"""Registers a user-defined AQL function (UDF) written in JavaScript.
683+
684+
Args:
685+
name (str): Name of the function.
686+
code (str): JavaScript code of the function.
687+
is_deterministic (bool | None): If set to `True`, the function is
688+
deterministic.
689+
690+
Returns:
691+
dict: Information about the registered function.
692+
693+
Raises:
694+
AQLFunctionCreateError: If registration fails.
695+
696+
References:
697+
- `create-a-user-defined-aql-function <https://docs.arangodb.com/stable/develop/http-api/queries/user-defined-aql-functions/#create-a-user-defined-aql-function>`__
698+
""" # noqa: E501
699+
request = Request(
700+
method=Method.POST,
701+
endpoint="/_api/aqlfunction",
702+
data=self.serializer.dumps(
703+
dict(name=name, code=code, isDeterministic=is_deterministic)
704+
),
705+
)
706+
707+
def response_handler(resp: Response) -> Json:
708+
if not resp.is_success:
709+
raise AQLFunctionCreateError(resp, request)
710+
return self.deserializer.loads(resp.raw_body)
711+
712+
return await self._executor.execute(request, response_handler)
713+
714+
async def delete_function(
715+
self,
716+
name: str,
717+
group: Optional[bool] = None,
718+
ignore_missing: bool = False,
719+
) -> Result[Json]:
720+
"""Remove a user-defined AQL function.
721+
722+
Args:
723+
name (str): Name of the function.
724+
group (bool | None): If set to `True`, the function name is treated
725+
as a namespace prefix.
726+
ignore_missing (bool): If set to `True`, will not raise an exception
727+
if the function is not found.
728+
729+
Returns:
730+
dict: Information about the removed functions (their count).
731+
732+
Raises:
733+
AQLFunctionDeleteError: If removal fails.
734+
735+
References:
736+
- `remove-a-user-defined-aql-function <https://docs.arangodb.com/stable/develop/http-api/queries/user-defined-aql-functions/#remove-a-user-defined-aql-function>`__
737+
""" # noqa: E501
738+
params: Json = dict()
739+
if group is not None:
740+
params["group"] = group
741+
request = Request(
742+
method=Method.DELETE,
743+
endpoint=f"/_api/aqlfunction/{name}",
744+
params=params,
745+
)
746+
747+
def response_handler(resp: Response) -> Json:
748+
if not resp.is_success:
749+
if not (resp.status_code == HTTP_NOT_FOUND and ignore_missing):
750+
raise AQLFunctionDeleteError(resp, request)
751+
return self.deserializer.loads(resp.raw_body)
752+
753+
return await self._executor.execute(request, response_handler)

arangoasync/exceptions.py

+12
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,18 @@ class AQLCachePropertiesError(ArangoServerError):
8787
"""Failed to retrieve query cache properties."""
8888

8989

90+
class AQLFunctionCreateError(ArangoServerError):
91+
"""Failed to create AQL user function."""
92+
93+
94+
class AQLFunctionDeleteError(ArangoServerError):
95+
"""Failed to delete AQL user function."""
96+
97+
98+
class AQLFunctionListError(ArangoServerError):
99+
"""Failed to retrieve AQL user functions."""
100+
101+
90102
class AQLQueryClearError(ArangoServerError):
91103
"""Failed to clear slow AQL queries."""
92104

tests/test_aql.py

+88-1
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,20 @@
44
import pytest
55
from packaging import version
66

7-
from arangoasync.errno import FORBIDDEN, QUERY_PARSE
7+
from arangoasync.errno import (
8+
FORBIDDEN,
9+
QUERY_FUNCTION_INVALID_CODE,
10+
QUERY_FUNCTION_NOT_FOUND,
11+
QUERY_PARSE,
12+
)
813
from arangoasync.exceptions import (
914
AQLCacheClearError,
1015
AQLCacheConfigureError,
1116
AQLCacheEntriesError,
1217
AQLCachePropertiesError,
18+
AQLFunctionCreateError,
19+
AQLFunctionDeleteError,
20+
AQLFunctionListError,
1321
AQLQueryClearError,
1422
AQLQueryExecuteError,
1523
AQLQueryExplainError,
@@ -276,3 +284,82 @@ async def test_cache_plan_management(db, bad_db, doc_col, docs, db_version):
276284
with pytest.raises(AQLCacheClearError) as err:
277285
await bad_db.aql.cache.clear_plan()
278286
assert err.value.error_code == FORBIDDEN
287+
288+
289+
@pytest.mark.asyncio
290+
async def test_aql_function_management(db, bad_db):
291+
fn_group = "functions::temperature"
292+
fn_name_1 = "functions::temperature::celsius_to_fahrenheit"
293+
fn_body_1 = "function (celsius) { return celsius * 1.8 + 32; }"
294+
fn_name_2 = "functions::temperature::fahrenheit_to_celsius"
295+
fn_body_2 = "function (fahrenheit) { return (fahrenheit - 32) / 1.8; }"
296+
bad_fn_name = "functions::temperature::should_not_exist"
297+
bad_fn_body = "function (celsius) { invalid syntax }"
298+
299+
aql = db.aql
300+
# List AQL functions
301+
assert await aql.functions() == []
302+
303+
# List AQL functions with bad database
304+
with pytest.raises(AQLFunctionListError) as err:
305+
await bad_db.aql.functions()
306+
assert err.value.error_code == FORBIDDEN
307+
308+
# Create invalid AQL function
309+
with pytest.raises(AQLFunctionCreateError) as err:
310+
await aql.create_function(bad_fn_name, bad_fn_body)
311+
assert err.value.error_code == QUERY_FUNCTION_INVALID_CODE
312+
313+
# Create first AQL function
314+
result = await aql.create_function(fn_name_1, fn_body_1, is_deterministic=True)
315+
assert result["isNewlyCreated"] is True
316+
functions = await aql.functions()
317+
assert len(functions) == 1
318+
assert functions[0]["name"] == fn_name_1
319+
assert functions[0]["code"] == fn_body_1
320+
assert functions[0]["isDeterministic"] is True
321+
322+
# Create same AQL function again
323+
result = await aql.create_function(fn_name_1, fn_body_1, is_deterministic=True)
324+
assert result["isNewlyCreated"] is False
325+
functions = await aql.functions()
326+
assert len(functions) == 1
327+
assert functions[0]["name"] == fn_name_1
328+
assert functions[0]["code"] == fn_body_1
329+
assert functions[0]["isDeterministic"] is True
330+
331+
# Create second AQL function
332+
result = await aql.create_function(fn_name_2, fn_body_2, is_deterministic=False)
333+
assert result["isNewlyCreated"] is True
334+
functions = await aql.functions()
335+
assert len(functions) == 2
336+
assert functions[0]["name"] == fn_name_1
337+
assert functions[0]["code"] == fn_body_1
338+
assert functions[0]["isDeterministic"] is True
339+
assert functions[1]["name"] == fn_name_2
340+
assert functions[1]["code"] == fn_body_2
341+
assert functions[1]["isDeterministic"] is False
342+
343+
# Delete first function
344+
result = await aql.delete_function(fn_name_1)
345+
assert result["deletedCount"] == 1
346+
functions = await aql.functions()
347+
assert len(functions) == 1
348+
349+
# Delete missing function
350+
with pytest.raises(AQLFunctionDeleteError) as err:
351+
await aql.delete_function(fn_name_1)
352+
assert err.value.error_code == QUERY_FUNCTION_NOT_FOUND
353+
result = await aql.delete_function(fn_name_1, ignore_missing=True)
354+
assert "deletedCount" not in result
355+
356+
# Delete function from bad db
357+
with pytest.raises(AQLFunctionDeleteError) as err:
358+
await bad_db.aql.delete_function(fn_name_2)
359+
assert err.value.error_code == FORBIDDEN
360+
361+
# Delete function group
362+
result = await aql.delete_function(fn_group, group=True)
363+
assert result["deletedCount"] == 1
364+
functions = await aql.functions()
365+
assert len(functions) == 0

tests/test_cursor.py

+1
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ async def test_cursor_write_query(db, doc_col, docs):
128128
cursor = await aql.execute(
129129
"""
130130
FOR d IN {col} FILTER d.val == @first OR d.val == @second
131+
SORT d.val
131132
UPDATE {{_key: d._key, _val: @val }} IN {col}
132133
RETURN NEW
133134
""".format(

0 commit comments

Comments
 (0)