Cleanup rule survey code #1923

Merged
merged 12 commits on Sep 6, 2022
Changes from 3 commits
12 changes: 9 additions & 3 deletions detection_rules/devtools.py
@@ -901,6 +901,7 @@ def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_coun
from eql.table import Table
from kibana.resources import Signal
from .main import search_rules
# from .eswrap import parse_unique_field_results

survey_results = []
start_time, end_time = date_range
@@ -916,15 +917,20 @@ def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_coun
click.echo(f'Saving detailed dump to: {dump_file}')

collector = CollectEvents(elasticsearch_client)
details = collector.search_from_rule(*rules, start_time=start_time, end_time=end_time)
counts = collector.count_from_rule(*rules, start_time=start_time, end_time=end_time)
details = collector.search_from_rule(rules, start_time=start_time, end_time=end_time)
counts = collector.count_from_rule(rules, start_time=start_time, end_time=end_time)

# add alerts
with kibana_client:
range_dsl = {'query': {'bool': {'filter': []}}}
add_range_to_dsl(range_dsl['query']['bool']['filter'], start_time, end_time)
alerts = {a['_source']['signal']['rule']['rule_id']: a['_source']
for a in Signal.search(range_dsl)['hits']['hits']}
for a in Signal.search(range_dsl, size=10000)['hits']['hits']}

# for alert in alerts:
# rule_id = alert['signal']['rule']['rule_id']
# rule = rules.id_map[rule_id]
# unique_results = parse_unique_field_results(rule.contents.data.type, rule.contents.data.unique_fields, alert)

for rule_id, count in counts.items():
alert_count = len(alerts.get(rule_id, []))
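The devtools change above passes the `RuleCollection` straight through to the collector instead of unpacking it, and caps the signals search at 10,000 hits. A minimal usage sketch of the updated signatures, assuming an already-configured Elasticsearch client (`es_client` is a placeholder name, not part of this PR):

```python
from detection_rules.eswrap import CollectEvents
from detection_rules.rule_loader import RuleCollection

rules = RuleCollection.default()        # load the default rule set
collector = CollectEvents(es_client)    # es_client: an existing Elasticsearch client (placeholder)

# Both helpers now accept the collection itself rather than unpacked rules.
details = collector.search_from_rule(rules, start_time='now-7d', end_time='now')
counts = collector.count_from_rule(rules, start_time='now-7d', end_time='now')
```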
115 changes: 64 additions & 51 deletions detection_rules/eswrap.py
@@ -8,7 +8,7 @@
import os
import time
from collections import defaultdict
from typing import Union
from typing import List, Union

import click
import elasticsearch
@@ -17,7 +17,7 @@

import kql
from .main import root
from .misc import add_params, client_error, elasticsearch_options, get_elasticsearch_client
from .misc import add_params, client_error, elasticsearch_options, get_elasticsearch_client, nested_get
from .rule import TOMLRule
from .rule_loader import rta_mappings, RuleCollection
from .utils import format_command_options, normalize_timing_and_sort, unix_time_to_formatted, get_path
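The newly imported `nested_get` is what the helper added below uses to walk dotted field paths into `_source` documents. A tiny sketch of the assumed behavior (the helper itself lives in `detection_rules/misc.py` and is not shown in this diff):

```python
from detection_rules.misc import nested_get

doc = {'host': {'name': 'web-01'}}                  # contrived example document
assert nested_get(doc, 'host.name') == 'web-01'     # dotted path resolves through nested dicts
assert nested_get(doc, 'process.name') is None      # missing path falls back to None (assumed default)
```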
@@ -33,7 +33,23 @@ def add_range_to_dsl(dsl_filter, start_time, end_time='now'):
)


class RtaEvents(object):
def parse_unique_field_results(rule_type: str, unique_fields: List[str], search_results: dict):
parsed_results = defaultdict(lambda: defaultdict(int))
hits = search_results['hits']
hits = hits['hits'] if rule_type != 'eql' else hits.get('events') or hits.get('sequences', [])
for hit in hits:
for field in unique_fields:
match = nested_get(hit['_source'], field)
if not match:
continue

match = ','.join(sorted(match)) if isinstance(match, list) else match
parsed_results[field][match] += 1
# if rule.type == eql, structure is different
return {'results': parsed_results} if parsed_results else {}
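
To make the aggregation concrete, here is an illustrative call with a fabricated search response (field names and counts are invented for the example):

```python
sample_response = {
    'hits': {
        'hits': [
            {'_source': {'host': {'name': 'web-01'}, 'user': {'name': 'svc'}}},
            {'_source': {'host': {'name': 'web-01'}}},
        ]
    }
}

results = parse_unique_field_results('query', ['host.name', 'user.name'], sample_response)
# results == {'results': {'host.name': {'web-01': 2}, 'user.name': {'svc': 1}}}
```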


class RtaEvents:
"""Events collected from Elasticsearch."""

def __init__(self, events):
@@ -64,7 +80,7 @@ def evaluate_against_rule_and_update_mapping(self, rule_id, rta_name, verbose=Tr
"""Evaluate a rule against collected events and update mapping."""
from .utils import combine_sources, evaluate

rule = next((rule for rule in RuleCollection.default() if rule.id == rule_id), None)
rule = RuleCollection.default().id_map.get(rule_id)
assert rule is not None, f"Unable to find rule with ID {rule_id}"
merged_events = combine_sources(*self.events.values())
filtered = evaluate(rule, merged_events)
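The lookup swap above leans on `RuleCollection.id_map` (rule ID to `TOMLRule`), so a dict access replaces the previous linear scan. A small sketch with a placeholder rule ID:

```python
from detection_rules.rule_loader import RuleCollection

rules = RuleCollection.default()
rule = rules.id_map.get("00000000-0000-0000-0000-000000000000")  # placeholder ID; returns None if unknown
print(rule)
```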
@@ -112,7 +128,7 @@ def _build_timestamp_map(self, index_str):

def _get_last_event_time(self, index_str, dsl=None):
"""Get timestamp of most recent event."""
last_event = self.client.search(dsl, index_str, size=1, sort='@timestamp:desc')['hits']['hits']
last_event = self.client.search(query=dsl, index=index_str, size=1, sort='@timestamp:desc')['hits']['hits']
if not last_event:
return

@@ -146,7 +162,7 @@ def _prep_query(query, language, index, start_time=None, end_time=None):
elif language == 'dsl':
formatted_dsl = {'query': query}
else:
raise ValueError('Unknown search language')
raise ValueError(f'Unknown search language: {language}')

if start_time or end_time:
end_time = end_time or 'now'
@@ -172,84 +188,79 @@ def search(self, query, language, index: Union[str, list] = '*', start_time=None

return results

def search_from_rule(self, *rules: TOMLRule, start_time=None, end_time='now', size=None):
def search_from_rule(self, rules: RuleCollection, start_time=None, end_time='now', size=None):
"""Search an elasticsearch instance using a rule."""
from .misc import nested_get

async_client = AsyncSearchClient(self.client)
survey_results = {}

def parse_unique_field_results(rule_type, unique_fields, search_results):
parsed_results = defaultdict(lambda: defaultdict(int))
hits = search_results['hits']
hits = hits['hits'] if rule_type != 'eql' else hits.get('events') or hits.get('sequences', [])
for hit in hits:
for field in unique_fields:
match = nested_get(hit['_source'], field)
match = ','.join(sorted(match)) if isinstance(match, list) else match
parsed_results[field][match] += 1
# if rule.type == eql, structure is different
return {'results': parsed_results} if parsed_results else {}

multi_search = []
multi_search_rules = []
async_searches = {}
eql_searches = {}
async_searches = []
eql_searches = []

for rule in rules:
if not rule.query:
if not rule.contents.data.get('query'):
continue

index_str, formatted_dsl, lucene_query = self._prep_query(query=rule.query,
language=rule.contents.get('language'),
index=rule.contents.get('index', '*'),
language = rule.contents.data.get('language')
query = rule.contents.data.query
rule_type = rule.contents.data.type
index_str, formatted_dsl, lucene_query = self._prep_query(query=query,
language=language,
index=rule.contents.data.get('index', '*'),
start_time=start_time,
end_time=end_time)
formatted_dsl.update(size=size or self.max_events)

# prep for searches: msearch for kql | async search for lucene | eql client search for eql
if rule.contents['language'] == 'kuery':
if language == 'kuery':
multi_search_rules.append(rule)
multi_search.append(json.dumps(
{'index': index_str, 'allow_no_indices': 'true', 'ignore_unavailable': 'true'}))
multi_search.append(json.dumps(formatted_dsl))
elif rule.contents['language'] == 'lucene':
multi_search.append({'index': index_str, 'allow_no_indices': 'true', 'ignore_unavailable': 'true'})
multi_search.append(formatted_dsl)
elif language == 'lucene':
# wait for 0 to try and force async with no immediate results (not guaranteed)
result = async_client.submit(body=formatted_dsl, q=rule.query, index=index_str,
result = async_client.submit(body=formatted_dsl, q=query, index=index_str,
allow_no_indices=True, ignore_unavailable=True,
wait_for_completion_timeout=0)
if result['is_running'] is True:
async_searches[rule] = result['id']
async_searches.append((rule, result['id']))
else:
survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields,
survey_results[rule.id] = parse_unique_field_results(rule_type, rule.contents.data.unique_fields,
result['response'])
elif rule.contents['language'] == 'eql':
elif language == 'eql':
eql_body = {
'index': index_str,
'params': {'ignore_unavailable': 'true', 'allow_no_indices': 'true'},
'body': {'query': rule.query, 'filter': formatted_dsl['filter']}
'body': {'query': query, 'filter': formatted_dsl['filter']}
}
eql_searches[rule] = eql_body
eql_searches.append((rule, eql_body))

# assemble search results
multi_search_results = self.client.msearch('\n'.join(multi_search) + '\n')
multi_search_results = self.client.msearch(searches=multi_search)
for index, result in enumerate(multi_search_results['responses']):
try:
rule = multi_search_rules[index]
survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields, result)
survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type,
rule.contents.data.unique_fields, result)
except KeyError:
survey_results[multi_search_rules[index].id] = {'error_retrieving_results': True}

for rule, search_args in eql_searches.items():
for entry in eql_searches:
rule: TOMLRule
search_args: dict
rule, search_args = entry
try:
result = self.client.eql.search(**search_args)
survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields, result)
survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type,
rule.contents.data.unique_fields, result)
except (elasticsearch.NotFoundError, elasticsearch.RequestError) as e:
survey_results[rule.id] = {'error_retrieving_results': True, 'error': e.info['error']['reason']}

for rule, async_id in async_searches.items():
result = async_client.get(async_id)['response']
survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields, result)
for entry in async_searches:
rule: TOMLRule
rule, async_id = entry
result = async_client.get(id=async_id)['response']
survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type,
rule.contents.data.unique_fields, result)

return survey_results
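For reference, the per-rule entries assembled above take one of two shapes: parsed unique-field counts on success, or an error marker when a search fails. The values below are invented purely to illustrate the structure:

```python
survey_results = {
    # success: counts of unique values per field, from parse_unique_field_results
    'example-rule-id-1': {'results': {'host.name': {'web-01': 12, 'web-02': 3}}},
    # failure: the index was missing or the request was rejected
    'example-rule-id-2': {'error_retrieving_results': True, 'error': 'index_not_found_exception'},
}
```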

Expand All @@ -267,19 +278,21 @@ def count(self, query, language, index: Union[str, list], start_time=None, end_t
return self.client.count(body=formatted_dsl, index=index_str, q=lucene_query, allow_no_indices=True,
ignore_unavailable=True)['count']

def count_from_rule(self, *rules, start_time=None, end_time='now'):
def count_from_rule(self, rules: RuleCollection, start_time=None, end_time='now'):
"""Get a count of documents from elasticsearch using a rule."""
survey_results = {}

for rule in rules:
for rule in rules.rules:
rule_results = {'rule_id': rule.id, 'name': rule.name}

if not rule.query:
if not rule.contents.data.get('query'):
continue

try:
rule_results['search_count'] = self.count(query=rule.query, language=rule.contents.get('language'),
index=rule.contents.get('index', '*'), start_time=start_time,
rule_results['search_count'] = self.count(query=rule.contents.data.query,
language=rule.contents.data.language,
index=rule.contents.get('index', '*'),
start_time=start_time,
end_time=end_time)
except (elasticsearch.NotFoundError, elasticsearch.RequestError):
rule_results['search_count'] = -1
5 changes: 3 additions & 2 deletions detection_rules/kbwrap.py
@@ -82,8 +82,9 @@ def upload_rule(ctx, rules, replace_id):
@click.option('--date-range', '-d', type=(str, str), default=('now-7d', 'now'), help='Date range to scope search')
@click.option('--columns', '-c', multiple=True, help='Columns to display in table')
@click.option('--extend', '-e', is_flag=True, help='If columns are specified, extend the original columns')
@click.option('--max-count', '-m', default=100, help='The max number of alerts to return')
@click.pass_context
def search_alerts(ctx, query, date_range, columns, extend):
def search_alerts(ctx, query, date_range, columns, extend, max_count):
"""Search detection engine alerts with KQL."""
from eql.table import Table
from .eswrap import MATCH_ALL, add_range_to_dsl
@@ -94,7 +95,7 @@ def search_alerts(ctx, query, date_range, columns, extend):
add_range_to_dsl(kql_query['bool'].setdefault('filter', []), start_time, end_time)

with kibana:
alerts = [a['_source'] for a in Signal.search({'query': kql_query})['hits']['hits']]
alerts = [a['_source'] for a in Signal.search({'query': kql_query}, size=max_count)['hits']['hits']]

table_columns = ['host.hostname', 'signal.rule.name', 'signal.status', 'signal.original_time']
if columns:
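Assuming the command is registered as `search-alerts` under the CLI's `kibana` group (the registration decorator sits outside this hunk), the new cap can be raised with something like `python -m detection_rules kibana search-alerts -m 500 -d now-7d now`; by default only 100 alerts are returned.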
4 changes: 2 additions & 2 deletions detection_rules/mappings.py
@@ -13,12 +13,12 @@
RTA_DIR = get_path("rta")


class RtaMappings(object):
class RtaMappings:
"""Rta-mapping helper class."""

def __init__(self):
"""Rta-mapping validation and prep."""
self.mapping = load_etc_dump('rule-mapping.yml') # type: dict
self.mapping: dict = load_etc_dump('rule-mapping.yml')
self.validate()

self._rta_mapping = defaultdict(list)
11 changes: 10 additions & 1 deletion detection_rules/rule.py
@@ -207,7 +207,10 @@ class QueryValidator:

@property
def ast(self) -> Any:
raise NotImplementedError
raise NotImplementedError()

def unique_fields(self):
raise NotImplementedError()

def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
raise NotImplementedError()
@@ -240,6 +243,12 @@ def ast(self):
if validator is not None:
return validator.ast

@cached_property
def unique_fields(self):
validator = self.validator
if validator is not None:
return validator.unique_fields


@dataclass(frozen=True)
class MachineLearningRuleData(BaseRuleData):
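The new `unique_fields` pass-through lets callers aggregate on a rule's fields without reaching into the validator directly; it can be `None` when no validator is available. A consumer-side sketch mirroring how the survey code above uses it:

```python
from detection_rules.rule_loader import RuleCollection

for rule in RuleCollection.default():
    if not rule.contents.data.get('query'):             # skip non-query rules, as the survey code does
        continue
    fields = rule.contents.data.unique_fields or []      # may be None when the validator is unavailable
    print(rule.id, rule.name, fields)
```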
11 changes: 6 additions & 5 deletions kibana/resources.py
@@ -4,7 +4,7 @@
# 2.0.

import datetime
from typing import List, Type
from typing import List, Optional, Type

from .connector import Kibana

@@ -150,8 +150,9 @@ def __init__(self):
raise NotImplementedError("Signals can't be instantiated yet")

@classmethod
def search(cls, query_dsl: dict):
return Kibana.current().post(f"{cls.BASE_URI}/search", data=query_dsl)
def search(cls, query_dsl: dict, size: Optional[int] = 10):
payload = dict(size=size, **query_dsl)
return Kibana.current().post(f"{cls.BASE_URI}/search", data=payload)

@classmethod
def last_signal(cls) -> (int, datetime.datetime):
@@ -179,8 +180,8 @@ def last_signal(cls) -> (int, datetime.datetime):
return num_signals, last_seen

@classmethod
def all(cls):
return cls.search({"query": {"bool": {"filter": {"match_all": {}}}}})
def all(cls, size: Optional[int] = 10):
return cls.search({"query": {"bool": {"filter": {"match_all": {}}}}}, size=size)

@classmethod
def set_status_many(cls, signal_ids: List[str], status: str) -> dict:
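Finally, a short sketch of how the new `size` parameter flows into the signals search payload (this assumes an active Kibana connection via `Kibana.current()`, as elsewhere in this module; 500 is an arbitrary example value):

```python
from kibana.resources import Signal

query_dsl = {"query": {"bool": {"filter": {"match_all": {}}}}}
response = Signal.search(query_dsl, size=500)   # POSTs {"size": 500, **query_dsl} to the signals search endpoint
hits = [hit["_source"] for hit in response["hits"]["hits"]]
```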