18
18
from marshmallow import ValidationError , validates_schema
19
19
20
20
import kql
21
+ from . import beats
22
+ from . import ecs
21
23
from . import utils
24
+ from .misc import load_current_package_version
22
25
from .mixins import MarshmallowDataclassMixin , StackCompatMixin
23
26
from .rule_formatter import toml_write , nested_normalize
24
27
from .schemas import SCHEMA_DIR , definitions , downgrade , get_stack_schemas , get_min_supported_stack_version
25
28
from .schemas .stack_compat import get_restricted_fields
26
29
from .semver import Version
27
30
from .utils import cached
28
31
32
+ BUILD_FIELD_VERSIONS = {"required_fields" : (Version ('8.3' ), None )}
29
33
_META_SCHEMA_REQ_DEFAULTS = {}
30
34
MIN_FLEET_PACKAGE_VERSION = '7.13.0'
31
35
@@ -149,6 +153,12 @@ class FlatThreatMapping(MarshmallowDataclassMixin):
149
153
150
154
@dataclass (frozen = True )
151
155
class BaseRuleData (MarshmallowDataclassMixin , StackCompatMixin ):
156
+ @dataclass
157
+ class RequiredFields :
158
+ name : definitions .NonEmptyStr
159
+ type : definitions .NonEmptyStr
160
+ ecs : bool
161
+
152
162
actions : Optional [list ]
153
163
author : List [str ]
154
164
building_block_type : Optional [str ]
@@ -171,7 +181,7 @@ class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin):
171
181
# output_index: Optional[str]
172
182
references : Optional [List [str ]]
173
183
related_integrations : Optional [List [str ]] = field (metadata = dict (metadata = dict (min_compat = "8.3" )))
174
- required_fields : Optional [List [str ]] = field (metadata = dict (metadata = dict (min_compat = "8.3" )))
184
+ required_fields : Optional [List [RequiredFields ]] = field (metadata = dict (metadata = dict (min_compat = "8.3" )))
175
185
risk_score : definitions .RiskScore
176
186
risk_score_mapping : Optional [List [RiskScoreMapping ]]
177
187
rule_id : definitions .UUIDString
@@ -220,9 +230,45 @@ class QueryValidator:
220
230
def ast (self ) -> Any :
221
231
raise NotImplementedError
222
232
233
+ @property
234
+ def unique_fields (self ) -> Any :
235
+ raise NotImplementedError
236
+
223
237
def validate (self , data : 'QueryRuleData' , meta : RuleMeta ) -> None :
224
238
raise NotImplementedError ()
225
239
240
+ @cached
241
+ def get_required_fields (self , index : str ) -> List [dict ]:
242
+ """Retrieves fields needed for the query along with type information from the schema."""
243
+ current_version = Version (Version (load_current_package_version ()) + (0 ,))
244
+ ecs_version = get_stack_schemas ()[str (current_version )]['ecs' ]
245
+ beats_version = get_stack_schemas ()[str (current_version )]['beats' ]
246
+ ecs_schema = ecs .get_schema (ecs_version )
247
+
248
+ beat_types , beat_schema , schema = self .get_beats_schema (index or [], beats_version , ecs_version )
249
+
250
+ required = []
251
+ unique_fields = self .unique_fields or []
252
+
253
+ for fld in unique_fields :
254
+ field_type = ecs_schema .get (fld , {}).get ('type' )
255
+ is_ecs = field_type is not None
256
+
257
+ if beat_schema and not is_ecs :
258
+ field_type = beat_schema .get (fld , {}).get ('type' )
259
+
260
+ required .append (dict (name = fld , type = field_type or 'unknown' , ecs = is_ecs ))
261
+
262
+ return sorted (required , key = lambda f : f ['name' ])
263
+
264
+ @cached
265
+ def get_beats_schema (self , index : list , beats_version : str , ecs_version : str ) -> (list , dict , dict ):
266
+ """Get an assembled beats schema."""
267
+ beat_types = beats .parse_beats_from_index (index )
268
+ beat_schema = beats .get_schema_from_kql (self .ast , beat_types , version = beats_version ) if beat_types else None
269
+ schema = ecs .get_kql_schema (version = ecs_version , indexes = index , beat_schema = beat_schema )
270
+ return beat_types , beat_schema , schema
271
+
226
272
227
273
@dataclass (frozen = True )
228
274
class QueryRuleData (BaseRuleData ):
@@ -251,6 +297,18 @@ def ast(self):
251
297
if validator is not None :
252
298
return validator .ast
253
299
300
+ @cached_property
301
+ def unique_fields (self ):
302
+ validator = self .validator
303
+ if validator is not None :
304
+ return validator .unique_fields
305
+
306
+ @cached
307
+ def get_required_fields (self , index : str ) -> List [dict ]:
308
+ validator = self .validator
309
+ if validator is not None :
310
+ return validator .get_required_fields (index or [])
311
+
254
312
255
313
@dataclass (frozen = True )
256
314
class MachineLearningRuleData (BaseRuleData ):
@@ -438,8 +496,7 @@ def autobumped_version(self) -> Optional[int]:
438
496
439
497
return version + 1 if self .is_dirty else version
440
498
441
- @staticmethod
442
- def _post_dict_transform (obj : dict ) -> dict :
499
+ def _post_dict_transform (self , obj : dict ) -> dict :
443
500
"""Transform the converted API in place before sending to Kibana."""
444
501
445
502
# cleanup the whitespace in the rule
@@ -515,6 +572,59 @@ def name(self) -> str:
515
572
def type (self ) -> str :
516
573
return self .data .type
517
574
575
+ def _post_dict_transform (self , obj : dict ) -> dict :
576
+ """Transform the converted API in place before sending to Kibana."""
577
+ super ()._post_dict_transform (obj )
578
+
579
+ self .add_related_integrations (obj )
580
+ self .add_required_fields (obj )
581
+ self .add_setup (obj )
582
+
583
+ # validate new fields against the schema
584
+ rule_type = obj ['type' ]
585
+ subclass = self .get_data_subclass (rule_type )
586
+ subclass .from_dict (obj )
587
+
588
+ return obj
589
+
590
+ def add_related_integrations (self , obj : dict ) -> None :
591
+ """Add restricted field related_integrations to the obj."""
592
+ # field_name = "related_integrations"
593
+ ...
594
+
595
+ def add_required_fields (self , obj : dict ) -> None :
596
+ """Add restricted field required_fields to the obj, derived from the query AST."""
597
+ if isinstance (self .data , QueryRuleData ) and self .data .language != 'lucene' :
598
+ index = obj .get ('index' ) or []
599
+ required_fields = self .data .get_required_fields (index )
600
+ else :
601
+ required_fields = []
602
+
603
+ field_name = "required_fields"
604
+ if self .check_restricted_field_version (field_name = field_name ):
605
+ obj .setdefault (field_name , required_fields )
606
+
607
+ def add_setup (self , obj : dict ) -> None :
608
+ """Add restricted field setup to the obj."""
609
+ # field_name = "setup"
610
+ ...
611
+
612
+ def check_explicit_restricted_field_version (self , field_name : str ) -> bool :
613
+ """Explicitly check restricted fields against global min and max versions."""
614
+ min_stack , max_stack = BUILD_FIELD_VERSIONS [field_name ]
615
+ return self .compare_field_versions (min_stack , max_stack )
616
+
617
+ def check_restricted_field_version (self , field_name : str ) -> bool :
618
+ """Check restricted fields against schema min and max versions."""
619
+ min_stack , max_stack = self .data .get_restricted_fields .get (field_name )
620
+ return self .compare_field_versions (min_stack , max_stack )
621
+
622
+ def compare_field_versions (self , min_stack : Version , max_stack : Version ) -> bool :
623
+ """Check current rule version is witihin min and max stack versions."""
624
+ current_version = Version (load_current_package_version ())
625
+ max_stack = max_stack or current_version
626
+ return Version (min_stack ) <= current_version >= Version (max_stack )
627
+
518
628
@validates_schema
519
629
def validate_query (self , value : dict , ** kwargs ):
520
630
"""Validate queries by calling into the validator for the relevant method."""
@@ -540,11 +650,11 @@ def flattened_dict(self) -> dict:
540
650
def to_api_format (self , include_version = True ) -> dict :
541
651
"""Convert the TOML rule to the API format."""
542
652
converted = self .data .to_dict ()
653
+ converted = self ._post_dict_transform (converted )
654
+
543
655
if include_version :
544
656
converted ["version" ] = self .autobumped_version
545
657
546
- converted = self ._post_dict_transform (converted )
547
-
548
658
return converted
549
659
550
660
def check_restricted_fields_compatibility (self ) -> Dict [str , dict ]:
0 commit comments