Add new required_fields as a build-time restricted field #2059

Merged

120 changes: 115 additions & 5 deletions detection_rules/rule.py
@@ -18,14 +18,18 @@
from marshmallow import ValidationError, validates_schema

import kql
from . import beats
from . import ecs
from . import utils
from .misc import load_current_package_version
from .mixins import MarshmallowDataclassMixin, StackCompatMixin
from .rule_formatter import toml_write, nested_normalize
from .schemas import SCHEMA_DIR, definitions, downgrade, get_stack_schemas, get_min_supported_stack_version
from .schemas.stack_compat import get_restricted_fields
from .semver import Version
from .utils import cached

BUILD_FIELD_VERSIONS = {"required_fields": (Version('8.3'), None)}
_META_SCHEMA_REQ_DEFAULTS = {}
MIN_FLEET_PACKAGE_VERSION = '7.13.0'

@@ -149,6 +153,12 @@ class FlatThreatMapping(MarshmallowDataclassMixin):

@dataclass(frozen=True)
class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin):
@dataclass
class RequiredFields:
name: definitions.NonEmptyStr
type: definitions.NonEmptyStr
ecs: bool

actions: Optional[list]
author: List[str]
building_block_type: Optional[str]
@@ -171,7 +181,7 @@ class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin):
# output_index: Optional[str]
references: Optional[List[str]]
related_integrations: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
required_fields: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
required_fields: Optional[List[RequiredFields]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
risk_score: definitions.RiskScore
risk_score_mapping: Optional[List[RiskScoreMapping]]
rule_id: definitions.UUIDString
@@ -220,9 +230,45 @@ class QueryValidator:
def ast(self) -> Any:
raise NotImplementedError

@property
def unique_fields(self) -> Any:
raise NotImplementedError

def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
raise NotImplementedError()

@cached
def get_required_fields(self, index: str) -> List[dict]:
"""Retrieves fields needed for the query along with type information from the schema."""
current_version = Version(Version(load_current_package_version()) + (0,))
ecs_version = get_stack_schemas()[str(current_version)]['ecs']
beats_version = get_stack_schemas()[str(current_version)]['beats']
ecs_schema = ecs.get_schema(ecs_version)

beat_types, beat_schema, schema = self.get_beats_schema(index or [], beats_version, ecs_version)

required = []
unique_fields = self.unique_fields or []

for fld in unique_fields:
field_type = ecs_schema.get(fld, {}).get('type')
is_ecs = field_type is not None

if beat_schema and not is_ecs:
field_type = beat_schema.get(fld, {}).get('type')

required.append(dict(name=fld, type=field_type or 'unknown', ecs=is_ecs))

return sorted(required, key=lambda f: f['name'])

@cached
def get_beats_schema(self, index: list, beats_version: str, ecs_version: str) -> (list, dict, dict):
"""Get an assembled beats schema."""
beat_types = beats.parse_beats_from_index(index)
beat_schema = beats.get_schema_from_kql(self.ast, beat_types, version=beats_version) if beat_types else None
schema = ecs.get_kql_schema(version=ecs_version, indexes=index, beat_schema=beat_schema)
return beat_types, beat_schema, schema


@dataclass(frozen=True)
class QueryRuleData(BaseRuleData):
@@ -251,6 +297,18 @@ def ast(self):
if validator is not None:
return validator.ast

@cached_property
def unique_fields(self):
validator = self.validator
if validator is not None:
return validator.unique_fields

@cached
def get_required_fields(self, index: str) -> List[dict]:
validator = self.validator
if validator is not None:
return validator.get_required_fields(index or [])


@dataclass(frozen=True)
class MachineLearningRuleData(BaseRuleData):
@@ -438,8 +496,7 @@ def autobumped_version(self) -> Optional[int]:

return version + 1 if self.is_dirty else version

@staticmethod
def _post_dict_transform(obj: dict) -> dict:
def _post_dict_transform(self, obj: dict) -> dict:
"""Transform the converted API in place before sending to Kibana."""

# cleanup the whitespace in the rule
@@ -515,6 +572,59 @@ def name(self) -> str:
def type(self) -> str:
return self.data.type

def _post_dict_transform(self, obj: dict) -> dict:
"""Transform the converted API in place before sending to Kibana."""
super()._post_dict_transform(obj)

self.add_related_integrations(obj)
self.add_required_fields(obj)
self.add_setup(obj)

# validate new fields against the schema
rule_type = obj['type']
subclass = self.get_data_subclass(rule_type)
subclass.from_dict(obj)

return obj

def add_related_integrations(self, obj: dict) -> None:
"""Add restricted field related_integrations to the obj."""
# field_name = "related_integrations"
...

def add_required_fields(self, obj: dict) -> None:
"""Add restricted field required_fields to the obj, derived from the query AST."""
if isinstance(self.data, QueryRuleData) and self.data.language != 'lucene':
index = obj.get('index') or []
required_fields = self.data.get_required_fields(index)
else:
required_fields = []

field_name = "required_fields"
if self.check_restricted_field_version(field_name=field_name):
obj.setdefault(field_name, required_fields)

def add_setup(self, obj: dict) -> None:
"""Add restricted field setup to the obj."""
# field_name = "setup"
...

def check_explicit_restricted_field_version(self, field_name: str) -> bool:
"""Explicitly check restricted fields against global min and max versions."""
min_stack, max_stack = BUILD_FIELD_VERSIONS[field_name]
return self.compare_field_versions(min_stack, max_stack)

def check_restricted_field_version(self, field_name: str) -> bool:
"""Check restricted fields against schema min and max versions."""
min_stack, max_stack = self.data.get_restricted_fields.get(field_name)
return self.compare_field_versions(min_stack, max_stack)

def compare_field_versions(self, min_stack: Version, max_stack: Version) -> bool:
"""Check current rule version is witihin min and max stack versions."""
current_version = Version(load_current_package_version())
max_stack = max_stack or current_version
return Version(min_stack) <= current_version <= Version(max_stack)

@validates_schema
def validate_query(self, value: dict, **kwargs):
"""Validate queries by calling into the validator for the relevant method."""
Expand All @@ -540,11 +650,11 @@ def flattened_dict(self) -> dict:
def to_api_format(self, include_version=True) -> dict:
"""Convert the TOML rule to the API format."""
converted = self.data.to_dict()
converted = self._post_dict_transform(converted)

if include_version:
converted["version"] = self.autobumped_version

converted = self._post_dict_transform(converted)

return converted

def check_restricted_fields_compatibility(self) -> Dict[str, dict]:
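
For context, a rough sketch of what the new build-time transform produces. The index pattern, query, and field names below are illustrative only and are not taken from this PR; they assume a KQL rule querying ECS fields against a winlogbeat index.

    # Hypothetical illustration of get_required_fields() output for a rule with
    # index=["winlogbeat-*"] and query 'event.category:process and process.name:"cmd.exe"'.
    # Each entry mirrors dict(name=..., type=..., ecs=...) from QueryValidator.get_required_fields,
    # sorted by field name; fields missing from both ECS and beats schemas fall back to type 'unknown'.
    required_fields = [
        {"name": "event.category", "type": "keyword", "ecs": True},
        {"name": "process.name", "type": "keyword", "ecs": True},
    ]

    # add_required_fields() only attaches this list when
    # check_restricted_field_version("required_fields") passes, i.e. when the
    # current package version satisfies the min_compat="8.3" declared on the
    # BaseRuleData field, so packages built for earlier stacks are unaffected.
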
20 changes: 6 additions & 14 deletions detection_rules/rule_validators.py
@@ -10,7 +10,7 @@
import eql

import kql
from . import ecs, beats
from . import ecs
from .rule import QueryValidator, QueryRuleData, RuleMeta


@@ -21,17 +21,15 @@ class KQLValidator(QueryValidator):
def ast(self) -> kql.ast.Expression:
return kql.parse(self.query)

@property
@cached_property
def unique_fields(self) -> List[str]:
return list(set(str(f) for f in self.ast if isinstance(f, kql.ast.Field)))

def to_eql(self) -> eql.ast.Expression:
return kql.to_eql(self.query)

def validate(self, data: QueryRuleData, meta: RuleMeta) -> None:
"""Static method to validate the query, called from the parent which contains [metadata] information."""
ast = self.ast

"""Validate the query, called from the parent which contains [metadata] information."""
if meta.query_schema_validation is False or meta.maturity == "deprecated":
# syntax only, which is done via self.ast
return
@@ -41,9 +39,7 @@ def validate(self, data: QueryRuleData, meta: RuleMeta) -> None:
ecs_version = mapping['ecs']
err_trailer = f'stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}'

beat_types = beats.parse_beats_from_index(data.index)
beat_schema = beats.get_schema_from_kql(ast, beat_types, version=beats_version) if beat_types else None
schema = ecs.get_kql_schema(version=ecs_version, indexes=data.index or [], beat_schema=beat_schema)
beat_types, beat_schema, schema = self.get_beats_schema(data.index or [], beats_version, ecs_version)

try:
kql.parse(self.query, schema=schema)
@@ -73,14 +69,12 @@ def text_fields(self, eql_schema: ecs.KqlSchema2Eql) -> List[str]:

return [f for f in self.unique_fields if elasticsearch_type_family(eql_schema.kql_schema.get(f)) == 'text']

@property
@cached_property
def unique_fields(self) -> List[str]:
return list(set(str(f) for f in self.ast if isinstance(f, eql.ast.Field)))

def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
"""Validate an EQL query while checking TOMLRule."""
ast = self.ast

if meta.query_schema_validation is False or meta.maturity == "deprecated":
# syntax only, which is done via self.ast
return
@@ -90,9 +84,7 @@ def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
ecs_version = mapping['ecs']
err_trailer = f'stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}'

beat_types = beats.parse_beats_from_index(data.index)
beat_schema = beats.get_schema_from_kql(ast, beat_types, version=beats_version) if beat_types else None
schema = ecs.get_kql_schema(version=ecs_version, indexes=data.index or [], beat_schema=beat_schema)
beat_types, beat_schema, schema = self.get_beats_schema(data.index or [], beats_version, ecs_version)
eql_schema = ecs.KqlSchema2Eql(schema)

try:
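
For reference, a minimal sketch of how the cached unique_fields property now feeds both schema validation and required_fields generation. The query string is a made-up example and assumes the kql package bundled with this repository; it is not part of this PR.

    import kql

    query = 'event.category:process and process.name:"cmd.exe"'  # hypothetical query
    ast = kql.parse(query)

    # Mirrors KQLValidator.unique_fields: collect every distinct field referenced in the AST.
    unique_fields = list(set(str(f) for f in ast if isinstance(f, kql.ast.Field)))
    # e.g. ['event.category', 'process.name'] (order not guaranteed, since it comes from a set)
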