Skip to content

Commit 52add60

Browse files
Mikaayensongithub-actions[bot]
authored andcommitted
Add new required_fields as a build-time restricted field (#2059)
* Add new `require_field` restricted field * validate new fields against BaseRuleData schema and global constant Co-authored-by: Terrance DeJesus <[email protected]> Co-authored-by: brokensound77 <[email protected]> (cherry picked from commit c76a397)
1 parent 8eb6362 commit 52add60

File tree

2 files changed

+121
-19
lines changed

2 files changed

+121
-19
lines changed

detection_rules/rule.py

Lines changed: 115 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,18 @@
1818
from marshmallow import ValidationError, validates_schema
1919

2020
import kql
21+
from . import beats
22+
from . import ecs
2123
from . import utils
24+
from .misc import load_current_package_version
2225
from .mixins import MarshmallowDataclassMixin, StackCompatMixin
2326
from .rule_formatter import toml_write, nested_normalize
2427
from .schemas import SCHEMA_DIR, definitions, downgrade, get_stack_schemas, get_min_supported_stack_version
2528
from .schemas.stack_compat import get_restricted_fields
2629
from .semver import Version
2730
from .utils import cached
2831

32+
BUILD_FIELD_VERSIONS = {"required_fields": (Version('8.3'), None)}
2933
_META_SCHEMA_REQ_DEFAULTS = {}
3034
MIN_FLEET_PACKAGE_VERSION = '7.13.0'
3135

@@ -149,6 +153,12 @@ class FlatThreatMapping(MarshmallowDataclassMixin):
149153

150154
@dataclass(frozen=True)
151155
class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin):
156+
@dataclass
157+
class RequiredFields:
158+
name: definitions.NonEmptyStr
159+
type: definitions.NonEmptyStr
160+
ecs: bool
161+
152162
actions: Optional[list]
153163
author: List[str]
154164
building_block_type: Optional[str]
@@ -171,7 +181,7 @@ class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin):
171181
# output_index: Optional[str]
172182
references: Optional[List[str]]
173183
related_integrations: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
174-
required_fields: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
184+
required_fields: Optional[List[RequiredFields]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
175185
risk_score: definitions.RiskScore
176186
risk_score_mapping: Optional[List[RiskScoreMapping]]
177187
rule_id: definitions.UUIDString
@@ -220,9 +230,45 @@ class QueryValidator:
220230
def ast(self) -> Any:
221231
raise NotImplementedError
222232

233+
@property
234+
def unique_fields(self) -> Any:
235+
raise NotImplementedError
236+
223237
def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
224238
raise NotImplementedError()
225239

240+
@cached
241+
def get_required_fields(self, index: str) -> List[dict]:
242+
"""Retrieves fields needed for the query along with type information from the schema."""
243+
current_version = Version(Version(load_current_package_version()) + (0,))
244+
ecs_version = get_stack_schemas()[str(current_version)]['ecs']
245+
beats_version = get_stack_schemas()[str(current_version)]['beats']
246+
ecs_schema = ecs.get_schema(ecs_version)
247+
248+
beat_types, beat_schema, schema = self.get_beats_schema(index or [], beats_version, ecs_version)
249+
250+
required = []
251+
unique_fields = self.unique_fields or []
252+
253+
for fld in unique_fields:
254+
field_type = ecs_schema.get(fld, {}).get('type')
255+
is_ecs = field_type is not None
256+
257+
if beat_schema and not is_ecs:
258+
field_type = beat_schema.get(fld, {}).get('type')
259+
260+
required.append(dict(name=fld, type=field_type or 'unknown', ecs=is_ecs))
261+
262+
return sorted(required, key=lambda f: f['name'])
263+
264+
@cached
265+
def get_beats_schema(self, index: list, beats_version: str, ecs_version: str) -> (list, dict, dict):
266+
"""Get an assembled beats schema."""
267+
beat_types = beats.parse_beats_from_index(index)
268+
beat_schema = beats.get_schema_from_kql(self.ast, beat_types, version=beats_version) if beat_types else None
269+
schema = ecs.get_kql_schema(version=ecs_version, indexes=index, beat_schema=beat_schema)
270+
return beat_types, beat_schema, schema
271+
226272

227273
@dataclass(frozen=True)
228274
class QueryRuleData(BaseRuleData):
@@ -251,6 +297,18 @@ def ast(self):
251297
if validator is not None:
252298
return validator.ast
253299

300+
@cached_property
301+
def unique_fields(self):
302+
validator = self.validator
303+
if validator is not None:
304+
return validator.unique_fields
305+
306+
@cached
307+
def get_required_fields(self, index: str) -> List[dict]:
308+
validator = self.validator
309+
if validator is not None:
310+
return validator.get_required_fields(index or [])
311+
254312

255313
@dataclass(frozen=True)
256314
class MachineLearningRuleData(BaseRuleData):
@@ -438,8 +496,7 @@ def autobumped_version(self) -> Optional[int]:
438496

439497
return version + 1 if self.is_dirty else version
440498

441-
@staticmethod
442-
def _post_dict_transform(obj: dict) -> dict:
499+
def _post_dict_transform(self, obj: dict) -> dict:
443500
"""Transform the converted API in place before sending to Kibana."""
444501

445502
# cleanup the whitespace in the rule
@@ -515,6 +572,59 @@ def name(self) -> str:
515572
def type(self) -> str:
516573
return self.data.type
517574

575+
def _post_dict_transform(self, obj: dict) -> dict:
576+
"""Transform the converted API in place before sending to Kibana."""
577+
super()._post_dict_transform(obj)
578+
579+
self.add_related_integrations(obj)
580+
self.add_required_fields(obj)
581+
self.add_setup(obj)
582+
583+
# validate new fields against the schema
584+
rule_type = obj['type']
585+
subclass = self.get_data_subclass(rule_type)
586+
subclass.from_dict(obj)
587+
588+
return obj
589+
590+
def add_related_integrations(self, obj: dict) -> None:
591+
"""Add restricted field related_integrations to the obj."""
592+
# field_name = "related_integrations"
593+
...
594+
595+
def add_required_fields(self, obj: dict) -> None:
596+
"""Add restricted field required_fields to the obj, derived from the query AST."""
597+
if isinstance(self.data, QueryRuleData) and self.data.language != 'lucene':
598+
index = obj.get('index') or []
599+
required_fields = self.data.get_required_fields(index)
600+
else:
601+
required_fields = []
602+
603+
field_name = "required_fields"
604+
if self.check_restricted_field_version(field_name=field_name):
605+
obj.setdefault(field_name, required_fields)
606+
607+
def add_setup(self, obj: dict) -> None:
608+
"""Add restricted field setup to the obj."""
609+
# field_name = "setup"
610+
...
611+
612+
def check_explicit_restricted_field_version(self, field_name: str) -> bool:
613+
"""Explicitly check restricted fields against global min and max versions."""
614+
min_stack, max_stack = BUILD_FIELD_VERSIONS[field_name]
615+
return self.compare_field_versions(min_stack, max_stack)
616+
617+
def check_restricted_field_version(self, field_name: str) -> bool:
618+
"""Check restricted fields against schema min and max versions."""
619+
min_stack, max_stack = self.data.get_restricted_fields.get(field_name)
620+
return self.compare_field_versions(min_stack, max_stack)
621+
622+
def compare_field_versions(self, min_stack: Version, max_stack: Version) -> bool:
623+
"""Check current rule version is witihin min and max stack versions."""
624+
current_version = Version(load_current_package_version())
625+
max_stack = max_stack or current_version
626+
return Version(min_stack) <= current_version >= Version(max_stack)
627+
518628
@validates_schema
519629
def validate_query(self, value: dict, **kwargs):
520630
"""Validate queries by calling into the validator for the relevant method."""
@@ -540,11 +650,11 @@ def flattened_dict(self) -> dict:
540650
def to_api_format(self, include_version=True) -> dict:
541651
"""Convert the TOML rule to the API format."""
542652
converted = self.data.to_dict()
653+
converted = self._post_dict_transform(converted)
654+
543655
if include_version:
544656
converted["version"] = self.autobumped_version
545657

546-
converted = self._post_dict_transform(converted)
547-
548658
return converted
549659

550660
def check_restricted_fields_compatibility(self) -> Dict[str, dict]:

detection_rules/rule_validators.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import eql
1111

1212
import kql
13-
from . import ecs, beats
13+
from . import ecs
1414
from .rule import QueryValidator, QueryRuleData, RuleMeta
1515

1616

@@ -21,17 +21,15 @@ class KQLValidator(QueryValidator):
2121
def ast(self) -> kql.ast.Expression:
2222
return kql.parse(self.query)
2323

24-
@property
24+
@cached_property
2525
def unique_fields(self) -> List[str]:
2626
return list(set(str(f) for f in self.ast if isinstance(f, kql.ast.Field)))
2727

2828
def to_eql(self) -> eql.ast.Expression:
2929
return kql.to_eql(self.query)
3030

3131
def validate(self, data: QueryRuleData, meta: RuleMeta) -> None:
32-
"""Static method to validate the query, called from the parent which contains [metadata] information."""
33-
ast = self.ast
34-
32+
"""Validate the query, called from the parent which contains [metadata] information."""
3533
if meta.query_schema_validation is False or meta.maturity == "deprecated":
3634
# syntax only, which is done via self.ast
3735
return
@@ -41,9 +39,7 @@ def validate(self, data: QueryRuleData, meta: RuleMeta) -> None:
4139
ecs_version = mapping['ecs']
4240
err_trailer = f'stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}'
4341

44-
beat_types = beats.parse_beats_from_index(data.index)
45-
beat_schema = beats.get_schema_from_kql(ast, beat_types, version=beats_version) if beat_types else None
46-
schema = ecs.get_kql_schema(version=ecs_version, indexes=data.index or [], beat_schema=beat_schema)
42+
beat_types, beat_schema, schema = self.get_beats_schema(data.index or [], beats_version, ecs_version)
4743

4844
try:
4945
kql.parse(self.query, schema=schema)
@@ -73,14 +69,12 @@ def text_fields(self, eql_schema: ecs.KqlSchema2Eql) -> List[str]:
7369

7470
return [f for f in self.unique_fields if elasticsearch_type_family(eql_schema.kql_schema.get(f)) == 'text']
7571

76-
@property
72+
@cached_property
7773
def unique_fields(self) -> List[str]:
7874
return list(set(str(f) for f in self.ast if isinstance(f, eql.ast.Field)))
7975

8076
def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
8177
"""Validate an EQL query while checking TOMLRule."""
82-
ast = self.ast
83-
8478
if meta.query_schema_validation is False or meta.maturity == "deprecated":
8579
# syntax only, which is done via self.ast
8680
return
@@ -90,9 +84,7 @@ def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
9084
ecs_version = mapping['ecs']
9185
err_trailer = f'stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}'
9286

93-
beat_types = beats.parse_beats_from_index(data.index)
94-
beat_schema = beats.get_schema_from_kql(ast, beat_types, version=beats_version) if beat_types else None
95-
schema = ecs.get_kql_schema(version=ecs_version, indexes=data.index or [], beat_schema=beat_schema)
87+
beat_types, beat_schema, schema = self.get_beats_schema(data.index or [], beats_version, ecs_version)
9688
eql_schema = ecs.KqlSchema2Eql(schema)
9789

9890
try:

0 commit comments

Comments
 (0)