Skip to content

Commit 179a3bd

Browse files
brokensound77github-actions[bot]
authored andcommitted
Add support for restricted fields (#2053)
* Add support for restricted fields (fields valid only in min/max stack versions) * add test to ensure rule backports wont exceed min compat (cherry picked from commit cc01d3f)
1 parent eb6deea commit 179a3bd

File tree

5 files changed

+129
-6
lines changed

5 files changed

+129
-6
lines changed

detection_rules/mixins.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,19 @@
66
"""Generic mixin classes."""
77

88
from pathlib import Path
9-
from typing import TypeVar, Type, Optional, Any
9+
from typing import Any, Optional, TypeVar, Type
1010

1111
import json
1212
import marshmallow_dataclass
1313
import marshmallow_dataclass.union_field
1414
import marshmallow_jsonschema
1515
import marshmallow_union
16-
from marshmallow import Schema, ValidationError, fields
16+
from marshmallow import Schema, ValidationError, fields, validates_schema
1717

18+
from .misc import load_current_package_version
1819
from .schemas import definitions
20+
from .schemas.stack_compat import get_incompatible_fields
21+
from .semver import Version
1922
from .utils import cached, dict_hash
2023

2124
T = TypeVar('T')
@@ -171,6 +174,26 @@ def save_to_file(self, lock_file: Optional[Path] = None):
171174
path.write_text(json.dumps(contents, indent=2, sort_keys=True))
172175

173176

177+
class StackCompatMixin:
178+
"""Mixin to restrict schema compatibility to defined stack versions."""
179+
180+
@validates_schema
181+
def validate_field_compatibility(self, data: dict, **kwargs):
182+
"""Verify stack-specific fields are properly applied to schema."""
183+
package_version = Version(load_current_package_version())
184+
schema_fields = getattr(self, 'fields', {})
185+
incompatible = get_incompatible_fields(list(schema_fields.values()), package_version)
186+
if not incompatible:
187+
return
188+
189+
package_version = load_current_package_version()
190+
for field, bounds in incompatible.items():
191+
min_compat, max_compat = bounds
192+
if data.get(field) is not None:
193+
raise ValidationError(f'Invalid field: "{field}" for stack version: {package_version}, '
194+
f'min compatibility: {min_compat}, max compatibility: {max_compat}')
195+
196+
174197
class PatchedJSONSchema(marshmallow_jsonschema.JSONSchema):
175198

176199
# Patch marshmallow-jsonschema to support marshmallow-dataclass[union]

detection_rules/rule.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,19 @@
1111
from dataclasses import dataclass, field
1212
from functools import cached_property
1313
from pathlib import Path
14-
from typing import Literal, Union, Optional, List, Any, Dict
14+
from typing import Literal, Union, Optional, List, Any, Dict, Tuple
1515
from uuid import uuid4
1616

1717
import eql
1818
from marshmallow import ValidationError, validates_schema
1919

2020
import kql
2121
from . import utils
22-
from .mixins import MarshmallowDataclassMixin
22+
from .mixins import MarshmallowDataclassMixin, StackCompatMixin
2323
from .rule_formatter import toml_write, nested_normalize
2424
from .schemas import SCHEMA_DIR, definitions, downgrade, get_stack_schemas, get_min_supported_stack_version
25+
from .schemas.stack_compat import get_restricted_fields
26+
from .semver import Version
2527
from .utils import cached
2628

2729
_META_SCHEMA_REQ_DEFAULTS = {}
@@ -146,7 +148,7 @@ class FlatThreatMapping(MarshmallowDataclassMixin):
146148

147149

148150
@dataclass(frozen=True)
149-
class BaseRuleData(MarshmallowDataclassMixin):
151+
class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin):
150152
actions: Optional[list]
151153
author: List[str]
152154
building_block_type: Optional[str]
@@ -168,10 +170,13 @@ class BaseRuleData(MarshmallowDataclassMixin):
168170
# explicitly NOT allowed!
169171
# output_index: Optional[str]
170172
references: Optional[List[str]]
173+
related_integrations: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
174+
required_fields: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3")))
171175
risk_score: definitions.RiskScore
172176
risk_score_mapping: Optional[List[RiskScoreMapping]]
173177
rule_id: definitions.UUIDString
174178
rule_name_override: Optional[str]
179+
setup: Optional[str] = field(metadata=dict(metadata=dict(min_compat="8.3")))
175180
severity_mapping: Optional[List[SeverityMapping]]
176181
severity: definitions.Severity
177182
tags: Optional[List[str]]
@@ -186,7 +191,7 @@ class BaseRuleData(MarshmallowDataclassMixin):
186191
@classmethod
187192
def save_schema(cls):
188193
"""Save the schema as a jsonschema."""
189-
fields: List[dataclasses.Field] = dataclasses.fields(cls)
194+
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(cls)
190195
type_field = next(f for f in fields if f.name == "type")
191196
rule_type = typing.get_args(type_field.type)[0] if cls != BaseRuleData else "base"
192197
schema = cls.jsonschema()
@@ -200,6 +205,12 @@ def save_schema(cls):
200205
def validate_query(self, meta: RuleMeta) -> None:
201206
pass
202207

208+
@cached_property
209+
def get_restricted_fields(self) -> Optional[Dict[str, tuple]]:
210+
"""Get stack version restricted fields."""
211+
fields: List[dataclasses.Field, ...] = list(dataclasses.fields(self))
212+
return get_restricted_fields(fields)
213+
203214

204215
@dataclass
205216
class QueryValidator:
@@ -536,6 +547,24 @@ def to_api_format(self, include_version=True) -> dict:
536547

537548
return converted
538549

550+
def check_restricted_fields_compatibility(self) -> Dict[str, dict]:
551+
"""Check for compatibility between restricted fields and the min_stack_version of the rule."""
552+
default_min_stack = get_min_supported_stack_version(drop_patch=True)
553+
if self.metadata.min_stack_version is not None:
554+
min_stack = Version(self.metadata.min_stack_version)
555+
else:
556+
min_stack = default_min_stack
557+
restricted = self.data.get_restricted_fields
558+
559+
invalid = {}
560+
for _field, values in restricted.items():
561+
if self.data.get(_field) is not None:
562+
min_allowed, _ = values
563+
if min_stack < min_allowed:
564+
invalid[_field] = {'min_stack_version': min_stack, 'min_allowed_version': min_allowed}
565+
566+
return invalid
567+
539568

540569
@dataclass
541570
class TOMLRule:

detection_rules/schemas/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def get_stack_versions(drop_patch=False) -> List[str]:
268268
return versions
269269

270270

271+
@cached
271272
def get_min_supported_stack_version(drop_patch=False) -> Version:
272273
"""Get the minimum defined and supported stack version."""
273274
stack_map = load_stack_schema_map()
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
2+
# or more contributor license agreements. Licensed under the Elastic License
3+
# 2.0; you may not use this file except in compliance with the Elastic License
4+
# 2.0.
5+
6+
from dataclasses import Field
7+
from typing import Dict, List, Optional, Tuple
8+
9+
from ..misc import cached
10+
from ..semver import Version
11+
12+
13+
@cached
14+
def get_restricted_field(schema_field: Field) -> Tuple[Optional[Version], Optional[Version]]:
15+
"""Get an optional min and max compatible versions of a field (from a schema or dataclass)."""
16+
# nested get is to support schema fields being passed directly from dataclass or fields in schema class, since
17+
# marshmallow_dataclass passes the embedded metadata directly
18+
min_compat = schema_field.metadata.get('metadata', schema_field.metadata).get('min_compat')
19+
max_compat = schema_field.metadata.get('metadata', schema_field.metadata).get('max_compat')
20+
min_compat = Version(min_compat) if min_compat else None
21+
max_compat = Version(max_compat) if max_compat else None
22+
return min_compat, max_compat
23+
24+
25+
@cached
26+
def get_restricted_fields(schema_fields: List[Field]) -> Dict[str, Tuple[Optional[Version], Optional[Version]]]:
27+
"""Get a list of optional min and max compatible versions of fields (from a schema or dataclass)."""
28+
restricted = {}
29+
for _field in schema_fields:
30+
min_compat, max_compat = get_restricted_field(_field)
31+
if min_compat or max_compat:
32+
restricted[_field.name] = (min_compat, max_compat)
33+
34+
return restricted
35+
36+
37+
@cached
38+
def get_incompatible_fields(schema_fields: List[Field], package_version: Version) -> Optional[Dict[str, tuple]]:
39+
"""Get a list of fields that are incompatible with the package version."""
40+
if not schema_fields:
41+
return
42+
43+
incompatible = {}
44+
restricted_fields = get_restricted_fields(schema_fields)
45+
for field_name, values in restricted_fields.items():
46+
min_compat, max_compat = values
47+
48+
if min_compat and package_version < min_compat or max_compat and package_version > max_compat:
49+
incompatible[field_name] = (min_compat, max_compat)
50+
51+
return incompatible

tests/test_all_rules.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,3 +677,22 @@ def test_integration_guide(self):
677677
self.fail(f'{self.rule_str(rule)} expected {integration} config missing\n\n'
678678
f'Expected: {note_str}\n\n'
679679
f'Actual: {rule.contents.data.note}')
680+
681+
682+
class TestIncompatibleFields(BaseRuleTest):
683+
"""Test stack restricted fields do not backport beyond allowable limits."""
684+
685+
def test_rule_backports_for_restricted_fields(self):
686+
"""Test that stack restricted fields will not backport to older rule versions."""
687+
invalid_rules = []
688+
689+
for rule in self.all_rules:
690+
invalid = rule.contents.check_restricted_fields_compatibility()
691+
if invalid:
692+
invalid_rules.append(f'{self.rule_str(rule)} {invalid}')
693+
694+
if invalid_rules:
695+
invalid_str = '\n'.join(invalid_rules)
696+
err_msg = 'The following rules have min_stack_versions lower than allowed for restricted fields:\n'
697+
err_msg += invalid_str
698+
self.fail(err_msg)

0 commit comments

Comments
 (0)