Skip to content

Commit 9d22970

Browse files
authored
Add EQL rules and schema validation (#297)
* Add EQL rules and schema validation * Lint nitpick * Rename get_schema_from_eql * Add EQL default language * Rename parsed_kql to parsed_query * Fix parsed_kql method call in loader * Autopopulate dependent values
1 parent 4041fc8 commit 9d22970

File tree

10 files changed

+207
-38
lines changed

10 files changed

+207
-38
lines changed

detection_rules/beats.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
import os
77

88
import kql
9+
import eql
910
import requests
1011
import yaml
1112

1213
from .semver import Version
13-
from .utils import unzip, load_etc_dump, save_etc_dump, get_etc_path
14+
from .utils import unzip, load_etc_dump, save_etc_dump, get_etc_path, cached
1415

1516

1617
def download_latest_beats_schema():
@@ -129,34 +130,16 @@ def get_beats_sub_schema(schema: dict, beat: str, module: str, *datasets: str):
129130
return {field["name"]: field for field in sorted(flattened, key=lambda f: f["name"])}
130131

131132

132-
SCHEMA = None
133-
134-
133+
@cached
135134
def read_beats_schema():
136-
global SCHEMA
137-
138-
if SCHEMA is None:
139-
beats_schemas = os.listdir(get_etc_path("beats_schemas"))
140-
latest = max(beats_schemas, key=lambda b: Version(b.lstrip("v")))
135+
beats_schemas = os.listdir(get_etc_path("beats_schemas"))
136+
latest = max(beats_schemas, key=lambda b: Version(b.lstrip("v")))
141137

142-
SCHEMA = load_etc_dump("beats_schemas", latest)
138+
return load_etc_dump("beats_schemas", latest)
143139

144-
return SCHEMA
145140

146-
147-
def get_schema_for_query(tree: kql.ast, beats: list) -> dict:
141+
def get_schema_from_datasets(beats, modules, datasets):
148142
filtered = {}
149-
modules = set()
150-
datasets = set()
151-
152-
# extract out event.module and event.dataset from the query's AST
153-
for node in tree:
154-
if isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.module"):
155-
modules.update(child.value for child in node.value if isinstance(child, kql.ast.String))
156-
157-
if isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.dataset"):
158-
datasets.update(child.value for child in node.value if isinstance(child, kql.ast.String))
159-
160143
beats_schema = read_beats_schema()
161144

162145
# infer the module if only a dataset are defined
@@ -173,3 +156,39 @@ def get_schema_for_query(tree: kql.ast, beats: list) -> dict:
173156
filtered.update(get_beats_sub_schema(beats_schema, beat, module, *datasets))
174157

175158
return filtered
159+
160+
161+
def get_schema_from_eql(tree: eql.ast.BaseNode, beats: list) -> dict:
162+
modules = set()
163+
datasets = set()
164+
165+
# extract out event.module and event.dataset from the query's AST
166+
for node in tree:
167+
if isinstance(node, eql.ast.Comparison) and node.comparator == node.EQ and \
168+
isinstance(node.right, eql.ast.String):
169+
if node.left == eql.ast.Field("event", ["module"]):
170+
modules.add(node.right.render())
171+
elif node.left == eql.ast.Field("event", ["dataset"]):
172+
datasets.add(node.right.render())
173+
elif isinstance(node, eql.ast.InSet):
174+
if node.expression == eql.ast.Field("event", ["module"]):
175+
modules.add(node.get_literals())
176+
elif node.expression == eql.ast.Field("event", ["dataset"]):
177+
datasets.add(node.get_literals())
178+
179+
return get_schema_from_datasets(beats, modules, datasets)
180+
181+
182+
def get_schema_from_kql(tree: kql.ast.BaseNode, beats: list) -> dict:
183+
modules = set()
184+
datasets = set()
185+
186+
# extract out event.module and event.dataset from the query's AST
187+
for node in tree:
188+
if isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.module"):
189+
modules.update(child.value for child in node.value if isinstance(child, kql.ast.String))
190+
191+
if isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.dataset"):
192+
datasets.update(child.value for child in node.value if isinstance(child, kql.ast.String))
193+
194+
return get_schema_from_datasets(beats, modules, datasets)

detection_rules/ecs.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import json
1111

1212
import requests
13+
import eql
14+
import eql.types
1315
import yaml
1416

1517
from .semver import Version
@@ -164,6 +166,34 @@ def flatten_multi_fields(schema):
164166
return converted
165167

166168

169+
class KqlSchema2Eql(eql.Schema):
170+
type_mapping = {
171+
"keyword": eql.types.TypeHint.String,
172+
"ip": eql.types.TypeHint.String,
173+
"float": eql.types.TypeHint.Numeric,
174+
"double": eql.types.TypeHint.Numeric,
175+
"long": eql.types.TypeHint.Numeric,
176+
"short": eql.types.TypeHint.Numeric,
177+
}
178+
179+
def __init__(self, kql_schema):
180+
self.kql_schema = kql_schema
181+
eql.Schema.__init__(self, {}, allow_any=True, allow_generic=False, allow_missing=False)
182+
183+
def validate_event_type(self, event_type):
184+
# allow all event types to fill in X:
185+
# `X` where ....
186+
return True
187+
188+
def get_event_type_hint(self, event_type, path):
189+
dotted = ".".join(path)
190+
elasticsearch_type = self.kql_schema.get(dotted)
191+
eql_hint = self.type_mapping.get(elasticsearch_type)
192+
193+
if eql_hint is not None:
194+
return eql_hint, None
195+
196+
167197
@cached
168198
def get_kql_schema(version=None, indexes=None, beat_schema=None):
169199
"""Get schema for KQL."""

detection_rules/misc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def schema_prompt(name, value=None, required=False, **options):
7575
if name == 'rule_id':
7676
default = str(uuid.uuid4())
7777

78+
if len(enum) == 1 and required and field_type != "array":
79+
return enum[0]
80+
7881
def _check_type(_val):
7982
if field_type in ('number', 'integer') and not str(_val).isdigit():
8083
print('Number expected but got: {}'.format(_val))

detection_rules/rule.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import click
1212
import kql
13+
import eql
1314

1415
from . import ecs, beats
1516
from .attack import TACTICS, build_threat_map_entry, technique_lookup
@@ -70,9 +71,12 @@ def query(self):
7071
return self.contents.get('query')
7172

7273
@property
73-
def parsed_kql(self):
74-
if self.query and self.contents['language'] == 'kuery':
75-
return kql.parse(self.query)
74+
def parsed_query(self):
75+
if self.query:
76+
if self.contents['language'] == 'kuery':
77+
return kql.parse(self.query)
78+
elif self.contents['language'] == 'eql':
79+
return eql.parse_query(self.query)
7680

7781
@property
7882
def filters(self):
@@ -152,18 +156,58 @@ def validate(self, as_rule=False, versioned=False, query=True):
152156

153157
schema_cls.validate(contents, role=self.type)
154158

155-
if query and self.query and self.contents['language'] == 'kuery':
159+
if query and self.query is not None:
156160
ecs_versions = self.metadata.get('ecs_version')
157161
indexes = self.contents.get("index", [])
158-
self._validate_kql(ecs_versions, indexes, self.query, self.name)
162+
163+
if self.contents['language'] == 'kuery':
164+
self._validate_kql(ecs_versions, indexes, self.query, self.name)
165+
166+
if self.contents['language'] == 'eql':
167+
self._validate_eql(ecs_versions, indexes, self.query, self.name)
168+
169+
@staticmethod
170+
@cached
171+
def _validate_eql(ecs_versions, indexes, query, name):
172+
# validate against all specified schemas or the latest if none specified
173+
parsed = eql.parse_query(query)
174+
beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index]
175+
beat_schema = beats.get_schema_from_eql(parsed, beat_types) if beat_types else None
176+
177+
ecs_versions = ecs_versions or [ecs_versions]
178+
schemas = []
179+
180+
for version in ecs_versions:
181+
try:
182+
schemas.append(ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema, version=version))
183+
except KeyError:
184+
raise KeyError('Unknown ecs schema version: {} in rule {}.\n'
185+
'Do you need to update schemas?'.format(version, name)) from None
186+
187+
for schema in schemas:
188+
try:
189+
with ecs.KqlSchema2Eql(schema):
190+
eql.parse_query(query)
191+
192+
except eql.EqlTypeMismatchError:
193+
raise
194+
195+
except eql.EqlParseError as exc:
196+
message = exc.error_msg
197+
trailer = None
198+
if "Unknown field" in message and beat_types:
199+
trailer = "\nTry adding event.module and event.dataset to specify beats module"
200+
201+
raise type(exc)(exc.error_msg, exc.line, exc.column, exc.source,
202+
len(exc.caret.lstrip()), trailer=trailer) from None
159203

160204
@staticmethod
161205
@cached
162206
def _validate_kql(ecs_versions, indexes, query, name):
163207
# validate against all specified schemas or the latest if none specified
164208
parsed = kql.parse(query)
165209
beat_types = [index.split("-")[0] for index in indexes if "beat-*" in index]
166-
beat_schema = beats.get_schema_for_query(parsed, beat_types) if beat_types else None
210+
beat_schema = beats.get_schema_from_kql(parsed, beat_types) if beat_types else None
167211

168212
if not ecs_versions:
169213
kql.parse(query, schema=ecs.get_kql_schema(indexes=indexes, beat_schema=beat_schema))

detection_rules/rule_loader.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,13 @@ def load_rules(file_lookup=None, verbose=True, error=True):
9393
raise KeyError("Rule has duplicate name to {}".format(
9494
next(r for r in rules if r.name == rule.name).path))
9595

96-
if rule.parsed_kql:
97-
if rule.parsed_kql in queries:
96+
parsed_query = rule.parsed_query
97+
if parsed_query is not None:
98+
if parsed_query in queries:
9899
raise KeyError("Rule has duplicate query with {}".format(
99-
next(r for r in rules if r.parsed_kql == rule.parsed_kql).path))
100+
next(r for r in rules if r.parsed_query == parsed_query).path))
100101

101-
queries.append(rule.parsed_kql)
102+
queries.append(parsed_query)
102103

103104
if not re.match(FILE_PATTERN, os.path.basename(rule.path)):
104105
raise ValueError(f"Rule {rule.path} does not meet rule name standard of {FILE_PATTERN}")

detection_rules/schemas/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
from ..semver import Version
88

99
# import all of the schema versions
10-
from .v78 import ApiSchema78
11-
from .v79 import ApiSchema79
10+
from .v7_8 import ApiSchema78
11+
from .v7_9 import ApiSchema79
12+
from .v7_10 import ApiSchema710
1213

1314
__all__ = (
1415
"all_schemas",
@@ -21,9 +22,10 @@
2122
all_schemas = [
2223
ApiSchema78,
2324
ApiSchema79,
25+
ApiSchema710,
2426
]
2527

26-
CurrentSchema = max(all_schemas, key=lambda cls: Version(cls.STACK_VERSION))
28+
CurrentSchema = all_schemas[-1]
2729

2830

2931
def downgrade(api_contents: dict, target_version: str):

detection_rules/schemas/v7_10.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
2+
# or more contributor license agreements. Licensed under the Elastic License;
3+
# you may not use this file except in compliance with the Elastic License.
4+
5+
"""Definitions for rule metadata and schemas."""
6+
7+
import jsl
8+
from .v7_9 import ApiSchema79
9+
10+
11+
# rule types
12+
EQL = "eql"
13+
14+
15+
class ApiSchema710(ApiSchema79):
16+
"""Schema for siem rule in API format."""
17+
18+
STACK_VERSION = "7.10"
19+
RULE_TYPES = ApiSchema79.RULE_TYPES + [EQL]
20+
21+
type = jsl.StringField(enum=RULE_TYPES, required=True)
22+
23+
# there might be a bug in jsl that requires us to redefine these here
24+
query_scope = ApiSchema79.query_scope
25+
saved_id_scope = ApiSchema79.saved_id_scope
26+
ml_scope = ApiSchema79.ml_scope
27+
threshold_scope = ApiSchema79.threshold_scope
28+
29+
with jsl.Scope(EQL) as eql_scope:
30+
eql_scope.index = jsl.ArrayField(jsl.StringField(), required=False)
31+
eql_scope.query = jsl.StringField(required=True)
32+
eql_scope.language = jsl.StringField(enum=[EQL], required=True, default=EQL)
33+
eql_scope.type = jsl.StringField(enum=[EQL], required=True)
34+
35+
with jsl.Scope(jsl.DEFAULT_ROLE) as default_scope:
36+
default_scope.type = type
File renamed without changes.

detection_rules/schemas/v79.py renamed to detection_rules/schemas/v7_9.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""Definitions for rule metadata and schemas."""
66

77
import jsl
8-
from .v78 import ApiSchema78
8+
from .v7_8 import ApiSchema78
99

1010

1111
OPERATORS = ['equals']

tests/test_schemas.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""Test stack versioned schemas."""
66
import unittest
77
import uuid
8+
import eql
89

910
from detection_rules.rule import Rule
1011
from detection_rules.schemas import downgrade, CurrentSchema
@@ -106,3 +107,36 @@ def test_threshold_downgrade(self):
106107

107108
with self.assertRaisesRegex(ValueError, "Unsupported rule type"):
108109
downgrade(api_contents, "7.8")
110+
111+
def test_eql_validation(self):
112+
base_fields = {
113+
"author": ["Elastic"],
114+
"description": "test description",
115+
"index": ["filebeat-*"],
116+
"language": "eql",
117+
"license": "Elastic License",
118+
"name": "test rule",
119+
"risk_score": 21,
120+
"rule_id": str(uuid.uuid4()),
121+
"severity": "low",
122+
"type": "eql"
123+
}
124+
125+
Rule("test.toml", dict(base_fields, query="""
126+
process where process.name == "cmd.exe"
127+
"""))
128+
129+
with self.assertRaises(eql.EqlSyntaxError):
130+
Rule("test.toml", dict(base_fields, query="""
131+
process where process.name == this!is$not#v@lid
132+
"""))
133+
134+
with self.assertRaises(eql.EqlSemanticError):
135+
Rule("test.toml", dict(base_fields, query="""
136+
process where process.invalid_field == "hello world"
137+
"""))
138+
139+
with self.assertRaises(eql.EqlTypeMismatchError):
140+
Rule("test.toml", dict(base_fields, query="""
141+
process where process.pid == "some string field"
142+
"""))

0 commit comments

Comments
 (0)