Skip to content

Implementation of intersection algorithm for regular grammars #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 54 additions & 51 deletions src/problems/MultipleSource/algo/matrix_bfs/intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from pygraphblas.types import BOOL
from pygraphblas.matrix import Matrix
from pygraphblas.vector import Vector
from pygraphblas import semiring, descriptor
from pygraphblas import descriptor
from pygraphblas import Accum, binaryop

from src.graph.graph import Graph
from src.problems.MultipleSource.algo.matrix_bfs.reg_automaton import RegAutomaton
Expand Down Expand Up @@ -106,7 +107,17 @@ def create_diag_matrices(self) -> Dict[str, Matrix]:

return diag_matrices

def intersect_bfs(self) -> EpsilonNFA:
def create_masks_matrix(self) -> Matrix:
    """
    Build the boolean masks matrix used as the starting point of the traversal

    Starts from an identity matrix with one row and one column per state of
    the regular automaton, then resizes it so the column count also covers
    every graph vertex (automaton states + graph vertices)
    """
    graph_size = self.graph.get_number_of_vertices()
    regex_size = self.regular_automaton.num_states

    # square identity over the automaton states, widened to the combined width
    masks = Matrix.identity(BOOL, regex_size, value=True)
    masks.resize(regex_size, graph_size + regex_size)

    return masks

def intersect_bfs(self, src_verts) -> EpsilonNFA:
"""
Intersection implementation with synchronous breadth first traversal
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this function solves a reachability problem rather than automata intersection, doesn't it?

of a graph and regular grammar represented in automata
Expand All @@ -123,75 +134,67 @@ def intersect_bfs(self) -> EpsilonNFA:
regex_start_states = self.regular_automaton.start_states

diag_matrices = self.create_diag_matrices()
graph_start_states = [0] # should be an argument

result = Matrix.sparse(BOOL, num_vert_graph, num_vert_graph)

# create a mask of source vertices vector
m_src_v = Vector.from_lists(src_verts, [True for _ in range(len(src_verts))], size=num_vert_graph)

# initialize matrices for multiple source bfs
found = Matrix.sparse(BOOL, len(regex_start_states), num_verts_diag)
vect = Matrix.sparse(BOOL, len(regex_start_states), num_verts_diag)
ident = self.create_masks_matrix()
vect = ident.dup()
found = ident.dup()

# fill start states
for start_state in self.regular_automaton.start_states:
found[
start_state % len(regex_start_states),
start_state,
] = True
for start_state in graph_start_states:
found[
len(regex_start_states) - 1 + start_state,
num_vert_regex + start_state,
] = True

# initialize matrix which stores front nodes found on each iteration for every symbol
iter_found = found.dup()
for reg_start_state in regex_start_states:
for gr_start_state in src_verts:
found[reg_start_state, num_vert_regex + gr_start_state] = True

# matrix which contains newly found nodes on each iteration
found_on_iter = found.dup()

# Algo's body
not_empty = True
level = 0

while not_empty and level < num_verts_inter:
# for each symbol we are going to store if any new nodes were found during traversal
# for each symbol we are going to store if any new nodes were found during traversal.
# if none are found, then 'not_empty' flag turns False, which means that no matrices change anymore
# and we can stop the traversal
not_empty_for_at_least_one_symbol = False

vect.assign_scalar(True, mask=iter_found)
vect.assign_matrix(found_on_iter, mask=vect, desc=descriptor.RC)
vect.assign_scalar(True, mask=ident)

# stores found nodes for each symbol
found_on_iter.assign_matrix(ident)

for symbol in regex:
if symbol in graph:
with semiring.ANY_PAIR_BOOL:
found = vect.mxm(
diag_matrices[symbol], mask=vect, desc=descriptor.RC
)

# append newly found nodes
iter_found.assign_scalar(True, mask=found, desc=descriptor.S)
with BOOL.ANY_PAIR:
found = vect.mxm(diag_matrices[symbol])

# the problem now is that I'm not sure how to store bfs' front edges
# of intersection automata, since 'found' matrix doesn't store
# information about which pair of nodes is in front at the moment

# TODO: then here I should assign matrix elements to True
with Accum(binaryop.MAX_BOOL):
# extract left (grammar) part of the masks matrix and rearrange rows
i_x, i_y, _ = found.extract_matrix(col_index=slice(0, num_vert_regex - 1)).to_lists()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can extract_matrix be replaced with Python array slicing?

for i in range(len(i_y)):
found_on_iter.assign_row(i_y[i], found.extract_row(i_x[i]))

# check if new nodes were found. if positive, switch the flag
if found.reduce_bool():
if not found_on_iter.iseq(vect):
not_empty_for_at_least_one_symbol = True

not_empty = not_empty_for_at_least_one_symbol
level += 1
# extract right (graph) part of the masks matrix and get a row of reachable nodes in a graph
reachable = found_on_iter.extract_matrix(
col_index=slice(num_vert_regex, num_verts_diag - 1)
).T.reduce_vector(BOOL.ANY_MONOID) # reduce by columns

return self.__to_automaton__()
# update graph boolean matrix for every source vertex
# result matrix contains reachability for every symbol combined
with Accum(binaryop.MAX_BOOL):
for st_v in src_verts:
result.assign_row(st_v, reachable, mask=m_src_v, desc=descriptor.C)

# For testing purposes
def intersect_kron(self) -> EpsilonNFA:
"""
Intersection implementation with kronecker product
"""
regex = self.regular_automaton.matrices
graph = self.graph

for symbol in regex:
if symbol in graph:
self.intersection_matrices[symbol] = regex[symbol].kronecker(
graph[symbol]
)
not_empty = not_empty_for_at_least_one_symbol
level += 1

return self.__to_automaton__()
return result
73 changes: 20 additions & 53 deletions test/MultipleSource/test_bfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,11 @@ def test_case_regular_cycle():
)

intersection = Intersection(graph, grammar)
intersect_bfs = intersection.intersect_bfs()
intersect_kron = intersection.intersect_kron()

source_verts = [0]
result = intersection.intersect_bfs(source_verts)

assert intersect_bfs.accepts(["a", "a"])
assert intersect_bfs.accepts(["a", "a", "a"])

assert not intersect_bfs.accepts(["a", "b"])
assert not intersect_bfs.accepts(["b"])

assert intersect_kron.is_equivalent_to(intersect_bfs)
assert result.nvals == 2 * len(source_verts)


@pytest.mark.CI
Expand All @@ -40,21 +35,11 @@ def test_case_regular_disconnected():
)

intersection = Intersection(graph, grammar)
intersect_bfs = intersection.intersect_bfs()
intersect_kron = intersection.intersect_kron()

assert intersect_bfs.accepts(["a", "b"])
assert intersect_bfs.accepts(["b", "a"])
assert intersect_bfs.accepts(["a", "a", "b"])
assert intersect_bfs.accepts(["a", "b", "a"])
assert intersect_bfs.accepts(["b", "a", "b"])

source_verts = [0, 3]
result = intersection.intersect_bfs(source_verts)

assert not intersect_bfs.accepts(["a"])
assert not intersect_bfs.accepts(["c"])
assert not intersect_bfs.accepts(["c", "b"])
assert not intersect_bfs.accepts(["c", "a"])

assert intersect_kron.is_equivalent_to(intersect_bfs)
assert result.nvals == 2 * len(source_verts)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we check not only the size of the result, but also its contents?



@pytest.mark.CI
Expand All @@ -67,17 +52,11 @@ def test_case_regular_loop():
)

intersection = Intersection(graph, grammar)
intersect_bfs = intersection.intersect_bfs()
intersect_kron = intersection.intersect_kron()

assert intersect_bfs.accepts(["a"])
assert intersect_bfs.accepts(["a" for _ in range(10)])

source_verts = [0, 2]
result = intersection.intersect_bfs(source_verts)

assert not intersect_bfs.accepts(["b"])
assert not intersect_bfs.accepts(["c"])
assert not intersect_bfs.accepts(["epsilon"])

assert intersect_kron.is_equivalent_to(intersect_bfs)
assert result.nvals == 0 * len(source_verts)


@pytest.mark.CI
Expand All @@ -90,18 +69,11 @@ def test_case_regular_midsymbol():
)

intersection = Intersection(graph, grammar)
intersect_bfs = intersection.intersect_bfs()
intersect_kron = intersection.intersect_kron()

assert intersect_bfs.accepts(["b"])
assert intersect_bfs.accepts(["a", "b", "c"])
assert intersect_bfs.accepts(["a", "a", "b", "c", "c"])

assert not intersect_bfs.accepts(["a"])
assert not intersect_bfs.accepts(["c"])
assert not intersect_bfs.accepts(["b", "b"])

source_verts = [0]
result = intersection.intersect_bfs(source_verts)

assert intersect_kron.is_equivalent_to(intersect_bfs)
assert result.nvals == 1 * len(source_verts)


@pytest.mark.CI
Expand All @@ -114,13 +86,8 @@ def test_case_regular_two_cycles():
)

intersection = Intersection(graph, grammar)
intersect_bfs = intersection.intersect_bfs()
intersect_kron = intersection.intersect_kron()

assert intersect_bfs.accepts(["a"])
assert intersect_bfs.accepts(["a", "a"])

assert not intersect_bfs.accepts(["b"])
assert not intersect_bfs.accepts(["c"])

source_verts = [0, 3]
result = intersection.intersect_bfs(source_verts)

assert intersect_kron.is_equivalent_to(intersect_bfs)
assert result.nvals == 2 * len(source_verts)