-
Notifications
You must be signed in to change notification settings - Fork 12
Implementation of intersection algorithm for regular grammars #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
6aac117
Implementation of intersection algo for regular grammar
bahbyega 5d07997
Add missing marks
bahbyega f6871b0
Moved algo to the right location, changed it to be multiple source
bahbyega 7b397e7
New approach of traversing
bahbyega 729e461
MSMatrixBfsAlgo updated
bahbyega 17fee55
Updated tests for MSMatrixBfsAlgo
bahbyega File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,8 @@ | |
from pygraphblas.types import BOOL | ||
from pygraphblas.matrix import Matrix | ||
from pygraphblas.vector import Vector | ||
from pygraphblas import semiring, descriptor | ||
from pygraphblas import descriptor | ||
from pygraphblas import Accum, binaryop | ||
|
||
from src.graph.graph import Graph | ||
from src.problems.MultipleSource.algo.matrix_bfs.reg_automaton import RegAutomaton | ||
|
@@ -106,7 +107,17 @@ def create_diag_matrices(self) -> Dict[str, Matrix]: | |
|
||
return diag_matrices | ||
|
||
def intersect_bfs(self) -> EpsilonNFA: | ||
def create_masks_matrix(self) -> Matrix: | ||
num_vert_graph = self.graph.get_number_of_vertices() | ||
num_vert_regex = self.regular_automaton.num_states | ||
num_verts_diag = num_vert_graph + num_vert_regex | ||
|
||
mask_matrix = Matrix.identity(BOOL, num_vert_regex, value=True) | ||
mask_matrix.resize(num_vert_regex, num_verts_diag) | ||
|
||
return mask_matrix | ||
|
||
def intersect_bfs(self, src_verts) -> EpsilonNFA: | ||
""" | ||
Intersection implementation with synchronous breadth first traversal | ||
of a graph and regular grammar represented in automata | ||
|
@@ -123,75 +134,67 @@ def intersect_bfs(self) -> EpsilonNFA: | |
regex_start_states = self.regular_automaton.start_states | ||
|
||
diag_matrices = self.create_diag_matrices() | ||
graph_start_states = [0] # should be an argument | ||
|
||
result = Matrix.sparse(BOOL, num_vert_graph, num_vert_graph) | ||
|
||
# create a mask of source vertices vector | ||
m_src_v = Vector.from_lists(src_verts, [True for _ in range(len(src_verts))], size=num_vert_graph) | ||
|
||
# initialize matrices for multiple source bfs | ||
found = Matrix.sparse(BOOL, len(regex_start_states), num_verts_diag) | ||
vect = Matrix.sparse(BOOL, len(regex_start_states), num_verts_diag) | ||
ident = self.create_masks_matrix() | ||
vect = ident.dup() | ||
found = ident.dup() | ||
|
||
# fill start states | ||
for start_state in self.regular_automaton.start_states: | ||
found[ | ||
start_state % len(regex_start_states), | ||
start_state, | ||
] = True | ||
for start_state in graph_start_states: | ||
found[ | ||
len(regex_start_states) - 1 + start_state, | ||
num_vert_regex + start_state, | ||
] = True | ||
|
||
# initialize matrix which stores front nodes found on each iteration for every symbol | ||
iter_found = found.dup() | ||
for reg_start_state in regex_start_states: | ||
for gr_start_state in src_verts: | ||
found[reg_start_state, num_vert_regex + gr_start_state] = True | ||
|
||
# matrix which contains newly found nodes on each iteration | ||
found_on_iter = found.dup() | ||
|
||
# Algo's body | ||
not_empty = True | ||
level = 0 | ||
|
||
while not_empty and level < num_verts_inter: | ||
# for each symbol we are going to store if any new nodes were found during traversal | ||
# for each symbol we are going to store if any new nodes were found during traversal. | ||
# if none are found, then 'not_empty' flag turns False, which means that no matrices change anymore | ||
# and we can stop the traversal | ||
not_empty_for_at_least_one_symbol = False | ||
|
||
vect.assign_scalar(True, mask=iter_found) | ||
vect.assign_matrix(found_on_iter, mask=vect, desc=descriptor.RC) | ||
vect.assign_scalar(True, mask=ident) | ||
|
||
# stores found nodes for each symbol | ||
found_on_iter.assign_matrix(ident) | ||
|
||
for symbol in regex: | ||
if symbol in graph: | ||
with semiring.ANY_PAIR_BOOL: | ||
found = vect.mxm( | ||
diag_matrices[symbol], mask=vect, desc=descriptor.RC | ||
) | ||
|
||
# append newly found nodes | ||
iter_found.assign_scalar(True, mask=found, desc=descriptor.S) | ||
with BOOL.ANY_PAIR: | ||
found = vect.mxm(diag_matrices[symbol]) | ||
|
||
# the problem now is that I'm not sure how to store bfs' front edges | ||
# of intersection automata, since 'found' matrix doesn't store | ||
# information about which pair of nodes is in front at the moment | ||
|
||
# TODO: then here I should assign matrix elements to True | ||
with Accum(binaryop.MAX_BOOL): | ||
# extract left (grammar) part of the masks matrix and rearrange rows | ||
i_x, i_y, _ = found.extract_matrix(col_index=slice(0, num_vert_regex - 1)).to_lists() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can |
||
for i in range(len(i_y)): | ||
found_on_iter.assign_row(i_y[i], found.extract_row(i_x[i])) | ||
|
||
# check if new nodes were found. if positive, switch the flag | ||
if found.reduce_bool(): | ||
if not found_on_iter.iseq(vect): | ||
not_empty_for_at_least_one_symbol = True | ||
|
||
not_empty = not_empty_for_at_least_one_symbol | ||
level += 1 | ||
# extract right (graph) part of the masks matrix and get a row of reachable nodes in a graph | ||
reachable = found_on_iter.extract_matrix( | ||
col_index=slice(num_vert_regex, num_verts_diag - 1) | ||
).T.reduce_vector(BOOL.ANY_MONOID) # reduce by columns | ||
|
||
return self.__to_automaton__() | ||
# update graph boolean matrix for every source vertex | ||
# result matrix contains reachability for every symbol combined | ||
with Accum(binaryop.MAX_BOOL): | ||
for st_v in src_verts: | ||
result.assign_row(st_v, reachable, mask=m_src_v, desc=descriptor.C) | ||
|
||
# For testing purposes | ||
def intersect_kron(self) -> EpsilonNFA: | ||
""" | ||
Intersection implementation with kronecker product | ||
""" | ||
regex = self.regular_automaton.matrices | ||
graph = self.graph | ||
|
||
for symbol in regex: | ||
if symbol in graph: | ||
self.intersection_matrices[symbol] = regex[symbol].kronecker( | ||
graph[symbol] | ||
) | ||
not_empty = not_empty_for_at_least_one_symbol | ||
level += 1 | ||
|
||
return self.__to_automaton__() | ||
return result |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,16 +18,11 @@ def test_case_regular_cycle(): | |
) | ||
|
||
intersection = Intersection(graph, grammar) | ||
intersect_bfs = intersection.intersect_bfs() | ||
intersect_kron = intersection.intersect_kron() | ||
|
||
source_verts = [0] | ||
result = intersection.intersect_bfs(source_verts) | ||
|
||
assert intersect_bfs.accepts(["a", "a"]) | ||
assert intersect_bfs.accepts(["a", "a", "a"]) | ||
|
||
assert not intersect_bfs.accepts(["a", "b"]) | ||
assert not intersect_bfs.accepts(["b"]) | ||
|
||
assert intersect_kron.is_equivalent_to(intersect_bfs) | ||
assert result.nvals == 2 * len(source_verts) | ||
|
||
|
||
@pytest.mark.CI | ||
|
@@ -40,21 +35,11 @@ def test_case_regular_disconnected(): | |
) | ||
|
||
intersection = Intersection(graph, grammar) | ||
intersect_bfs = intersection.intersect_bfs() | ||
intersect_kron = intersection.intersect_kron() | ||
|
||
assert intersect_bfs.accepts(["a", "b"]) | ||
assert intersect_bfs.accepts(["b", "a"]) | ||
assert intersect_bfs.accepts(["a", "a", "b"]) | ||
assert intersect_bfs.accepts(["a", "b", "a"]) | ||
assert intersect_bfs.accepts(["b", "a", "b"]) | ||
|
||
source_verts = [0, 3] | ||
result = intersection.intersect_bfs(source_verts) | ||
|
||
assert not intersect_bfs.accepts(["a"]) | ||
assert not intersect_bfs.accepts(["c"]) | ||
assert not intersect_bfs.accepts(["c", "b"]) | ||
assert not intersect_bfs.accepts(["c", "a"]) | ||
|
||
assert intersect_kron.is_equivalent_to(intersect_bfs) | ||
assert result.nvals == 2 * len(source_verts) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we check not only size of result, but also its content? |
||
|
||
|
||
@pytest.mark.CI | ||
|
@@ -67,17 +52,11 @@ def test_case_regular_loop(): | |
) | ||
|
||
intersection = Intersection(graph, grammar) | ||
intersect_bfs = intersection.intersect_bfs() | ||
intersect_kron = intersection.intersect_kron() | ||
|
||
assert intersect_bfs.accepts(["a"]) | ||
assert intersect_bfs.accepts(["a" for _ in range(10)]) | ||
|
||
source_verts = [0, 2] | ||
result = intersection.intersect_bfs(source_verts) | ||
|
||
assert not intersect_bfs.accepts(["b"]) | ||
assert not intersect_bfs.accepts(["c"]) | ||
assert not intersect_bfs.accepts(["epsilon"]) | ||
|
||
assert intersect_kron.is_equivalent_to(intersect_bfs) | ||
assert result.nvals == 0 * len(source_verts) | ||
|
||
|
||
@pytest.mark.CI | ||
|
@@ -90,18 +69,11 @@ def test_case_regular_midsymbol(): | |
) | ||
|
||
intersection = Intersection(graph, grammar) | ||
intersect_bfs = intersection.intersect_bfs() | ||
intersect_kron = intersection.intersect_kron() | ||
|
||
assert intersect_bfs.accepts(["b"]) | ||
assert intersect_bfs.accepts(["a", "b", "c"]) | ||
assert intersect_bfs.accepts(["a", "a", "b", "c", "c"]) | ||
|
||
assert not intersect_bfs.accepts(["a"]) | ||
assert not intersect_bfs.accepts(["c"]) | ||
assert not intersect_bfs.accepts(["b", "b"]) | ||
|
||
source_verts = [0] | ||
result = intersection.intersect_bfs(source_verts) | ||
|
||
assert intersect_kron.is_equivalent_to(intersect_bfs) | ||
assert result.nvals == 1 * len(source_verts) | ||
|
||
|
||
@pytest.mark.CI | ||
|
@@ -114,13 +86,8 @@ def test_case_regular_two_cycles(): | |
) | ||
|
||
intersection = Intersection(graph, grammar) | ||
intersect_bfs = intersection.intersect_bfs() | ||
intersect_kron = intersection.intersect_kron() | ||
|
||
assert intersect_bfs.accepts(["a"]) | ||
assert intersect_bfs.accepts(["a", "a"]) | ||
|
||
assert not intersect_bfs.accepts(["b"]) | ||
assert not intersect_bfs.accepts(["c"]) | ||
|
||
source_verts = [0, 3] | ||
result = intersection.intersect_bfs(source_verts) | ||
|
||
assert intersect_kron.is_equivalent_to(intersect_bfs) | ||
assert result.nvals == 2 * len(source_verts) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like this function solves reachability problem, not automata intersection. Isn't it?