Skip to content
This repository was archived by the owner on Apr 14, 2022. It is now read-only.

expressions: add constant propagation pass #271

Merged
merged 5 commits into from
Feb 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ lint:
graphql/core/rules.lua \
graphql/core/validate_variables.lua \
graphql/convert_schema/*.lua \
graphql/expressions/*.lua \
graphql/server/*.lua \
test/bench/*.lua \
test/space/*.lua \
Expand Down
7 changes: 7 additions & 0 deletions graphql/accessor_general.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ local db_schema_helpers = require('graphql.db_schema_helpers')
local error_codes = require('graphql.error_codes')
local statistics = require('graphql.statistics')
local expressions = require('graphql.expressions')
local constant_propagation = require('graphql.expressions.constant_propagation')
local find_index = require('graphql.find_index')

local check = utils.check
Expand Down Expand Up @@ -500,6 +501,12 @@ local function prepare_select_internal(self, collection_name, from, filter,
expr = expressions.new(expr)
end

-- propagate constants in the expression
if expr ~= nil then
expr = constant_propagation.transform(expr,
{variables = qcontext.variables})
end

-- read only process_tuple options
local select_opts = {
limit = args.limit,
Expand Down
222 changes: 84 additions & 138 deletions graphql/expressions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,18 @@ local unary_minus = P('-')
local unary_plus = P('+')

-- Possible binary operator patterns:
-- 1) Logical and.
local logic_and = P('&&')
-- 2) logical or.
local logic_or = P('||')
-- 3) +
local addition = P('+')
-- 4) -
local subtraction = P('-')
-- 5) ==
local multiplication = P('*')
local division = P('/')
local modulo = P('%')
local eq = P('==')
-- 6) !=
local not_eq = P('!=')
-- 7) >
local gt = P('>')
-- 8) >=
local ge = P('>=')
-- 9) <
local lt = P('<')
-- 10) <=
local le = P('<=')

-- AST nodes generating functions.
Expand All @@ -79,32 +72,19 @@ local identical_node = identical

local op_name = identical

local function root_expr_node(expr)
return {
kind = 'root_expression',
expr = expr
}
end

-- left associativity
local function bin_op_node(...)
if select('#', ...) == 1 then
local args_cnt = select('#', ...)
assert(args_cnt % 2 == 1)
if args_cnt == 1 then
return select(1, ...)
end
local operators = {}
local operands = {}
for i = 1, select('#', ...) do
local v = select(i, ...)
if i % 2 == 0 then
table.insert(operators, v)
else
table.insert(operands, v)
end
end
return {
kind = 'binary_operations',
operators = operators,
operands = operands
}
return bin_op_node({
kind = 'binary_operation',
op = select(2, ...),
left = select(1, ...),
right = select(3, ...),
}, select(4, ...))
end

local function unary_op_node(unary_operator, operand_1)
Expand Down Expand Up @@ -174,25 +154,29 @@ local _literal = _bool + _number + _string
local _logic_or = logic_or / op_name
local _logic_and = logic_and / op_name
local _comparison_op = (eq + not_eq + ge + gt + le + lt) / op_name
local _arithmetic_op = (addition + subtraction) / op_name
local _arithmetic_sum_op = (addition + subtraction) / op_name
local _arithmetic_mul_op = (multiplication + division + modulo) / op_name
local _unary_op = (negation + unary_minus + unary_plus) / op_name
local _functions = (is_null + is_not_null + regexp) / identical

-- Grammar rules for C-style expressions positioned ascending in
-- terms of priority.
local expression_grammar = P {
'init_expr',
init_expr = V('expr') * eof / root_expr_node,
init_expr = V('expr') * eof / identical_node,
expr = spaces * V('log_expr_or') * spaces / identical_node,

log_expr_or = V('log_expr_and') * (spaces * _logic_or *
spaces * V('log_expr_and')) ^ 0 / bin_op_node,
log_expr_and = V('comparison') * (spaces * _logic_and * spaces *
V('comparison')) ^ 0 / bin_op_node,
comparison = V('arithmetic_expr') * (spaces * _comparison_op * spaces *
V('arithmetic_expr')) ^ 0 / bin_op_node,
arithmetic_expr = V('unary_expr') * (spaces * _arithmetic_op * spaces *
V('unary_expr')) ^ 0 / bin_op_node,
comparison = V('arithmetic_sum_expr') * (spaces * _comparison_op * spaces *
V('arithmetic_sum_expr')) ^ 0 / bin_op_node,
arithmetic_sum_expr = V('arithmetic_mul_expr') * (spaces *
_arithmetic_sum_op * spaces *
V('arithmetic_mul_expr')) ^ 0 / bin_op_node,
arithmetic_mul_expr = V('unary_expr') * (spaces * _arithmetic_mul_op *
spaces * V('unary_expr')) ^ 0 / bin_op_node,

unary_expr = (_unary_op * V('first_prio') / unary_op_node) +
(V('first_prio') / identical_node),
Expand Down Expand Up @@ -250,130 +234,90 @@ local function execute_node(node, context)
if node.kind == 'const' then
if node.value_class == 'string' then
return node.value
end

if node.value_class == 'bool' then
elseif node.value_class == 'bool' then
if node.value == 'false' then
return false
elseif node.value == 'true' then
return true
else
error('Unknown boolean node value: ' .. tostring(node.value))
end
return true
end

if node.value_class == 'number' then
elseif node.value_class == 'number' then
return tonumber(node.value)
else
error('Unknown const class: ' .. tostring(node.value_class))
end
end

if node.kind == 'variable' then
elseif node.kind == 'variable' then
local name = node.name
return context.variables[name]
end

if node.kind == 'object_field' then
elseif node.kind == 'object_field' then
local path = node.path
local field = context.object
local table_path = (path:split('.'))
local table_path = path:split('.')
for i = 1, #table_path do
field = field[table_path[i]]
end
return field
end

if node.kind == 'func' then
-- regexp() implementation.
elseif node.kind == 'func' then
if node.name == 'regexp' then
return utils.regexp(execute_node(node.args[1], context),
execute_node(node.args[2], context))
end

-- is_null() implementation.
if node.name == 'is_null' then
elseif node.name == 'is_null' then
return execute_node(node.args[1], context) == nil
end

-- is_not_null() implementation.
if node.name == 'is_not_null' then
elseif node.name == 'is_not_null' then
return execute_node(node.args[1], context) ~= nil
else
error('Unknown func name: ' .. tostring(node.name))
end
end

if node.kind == 'unary_operation' then
-- Negation.
elseif node.kind == 'unary_operation' then
if node.op == '!' then
return not execute_node(node.node, context)
end

-- Unary '+'.
if node.op == '+' then
elseif node.op == '+' then
return execute_node(node.node, context)
end

-- Unary '-'.
if node.op == '-' then
elseif node.op == '-' then
return -execute_node(node.node, context)
else
error('Unknown unary operation: ' .. tostring(node.op))
end
end

if node.kind == 'binary_operations' then
local prev = execute_node(node.operands[1], context)
for i, op in ipairs(node.operators) do
local second_operand = execute_node(node.operands[i + 1],
context)
-- Sum.
if op == '+' then
prev = sum(prev, second_operand)
end

-- Subtraction.
if op == '-' then
prev = subtract(prev, second_operand)
end

-- Logical and.
if op == '&&' then
prev = prev and second_operand
end

-- Logical or.
if op == '||' then
prev = prev or second_operand
end

-- Equal.
if op == '==' then
prev = prev == second_operand
end

-- Not equal.
if op == '!=' then
prev = prev ~= second_operand
end

-- Greater than.
if op == '>' then
prev = prev > second_operand
end

-- Greater or equal.
if op == '>=' then
prev = prev >= second_operand
end

-- Lower than.
if op == '<' then
prev = prev < second_operand
end

-- Lower or equal.
if op == '<=' then
prev = prev <= second_operand
end
elseif node.kind == 'binary_operation' then
local op = node.op
local left = execute_node(node.left, context)
local right = execute_node(node.right, context)

if op == '+' then
return sum(left, right)
elseif op == '-' then
return subtract(left, right)
elseif op == '*' then
return left * right
elseif op == '/' then
return left / right
elseif op == '%' then
return left % right
elseif op == '&&' then
return left and right
elseif op == '||' then
return left or right
elseif op == '==' then
return left == right
elseif op == '!=' then
return left ~= right
elseif op == '>' then
return left > right
elseif op == '>=' then
return left >= right
elseif op == '<' then
return left < right
elseif op == '<=' then
return left <= right
else
error('Unknown binary operation: ' .. tostring(op))
end
return prev
end

if node.kind == 'root_expression' then
return execute_node(node.expr, context)
elseif node.kind == 'evaluated' then
-- evaluated node can occurs after optimizations
return node.value
else
error('Unknown node kind: ' .. tostring(node.kind))
end
end

Expand Down Expand Up @@ -417,4 +361,6 @@ function expressions.new(str)
})
end

expressions.execute_node = execute_node

return expressions
Loading