Skip to content

Commit 5c0fed8

Browse files
committed
allow for named custom array types
1 parent f13d30e commit 5c0fed8

File tree

2 files changed

+52
-7
lines changed

2 files changed

+52
-7
lines changed

schema_salad/avro/schema.py

+48-4
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,47 @@ def items(self) -> Schema:
429429
return cast(Schema, self.get_prop("items"))
430430

431431

432+
class NamedArraySchema(NamedSchema):
    """Avro array schema that is constructed with a name of its own.

    The name, namespace and ``names`` collection are handed to
    :class:`NamedSchema`; this class adds the ``items`` property (and,
    when given, ``doc``) on top of that.
    """

    def __init__(
        self,
        items: JsonDataType,
        names: Names,
        name: str,
        namespace: Optional[str] = None,
        doc: Optional[Union[str, list[str]]] = None,
        other_props: Optional[PropsType] = None,
    ) -> None:
        """Build a named array schema.

        :param items: schema of the array elements; either a string that
            ``names`` already knows, or JSON data that is parsed with
            ``make_avsc_object``.
        :param names: collection used to look up ``items`` by name and
            passed on when parsing it.
        :param name: name of this schema, forwarded to ``NamedSchema``.
        :param namespace: optional namespace, forwarded to ``NamedSchema``.
        :param doc: optional documentation, stored as the ``doc`` property.
        :param other_props: extra properties, forwarded to ``NamedSchema``.
        :raises SchemaParseException: if ``names`` is None, or if ``items``
            cannot be turned into a schema.
        """
        # Parent constructor first; it receives the "array" type tag.
        NamedSchema.__init__(self, "array", name, namespace, names, other_props)

        # A Names collection is mandatory for resolving the items schema.
        if names is None:
            raise SchemaParseException("Must provide Names.")

        if not (isinstance(items, str) and names.has_name(items, None)):
            # Not a known name: parse the items description as a schema.
            try:
                resolved_items = make_avsc_object(items, names)
            except Exception as err:
                raise SchemaParseException(
                    f"Items schema ({items}) not a valid Avro schema: {err}. "
                    f"Known names: {list(names.names.keys())})."
                ) from err
        else:
            # A plain string that ``names`` already knows: reuse that schema.
            resolved_items = cast(Schema, names.get_name(items, None))

        self.set_prop("items", resolved_items)

        # ``doc`` is only stored when the caller actually supplied one.
        if doc is not None:
            self.set_prop("doc", doc)

    # read-only properties
    @property
    def items(self) -> Schema:
        """Return the schema stored under the ``items`` property."""
        return cast(Schema, self.get_prop("items"))
471+
472+
432473
class MapSchema(Schema):
433474
"""Avro map schema class."""
434475

@@ -740,6 +781,11 @@ def make_avsc_object(json_data: JsonDataType, names: Optional[Names] = None) ->
740781
if atype in VALID_TYPES:
741782
if atype == "array":
742783
items = json_data.get("items")
784+
if "name" in json_data and json_data["name"]:
785+
name = json_data["name"]
786+
namespace = json_data.get("namespace", names.default_namespace)
787+
doc = json_data.get("doc")
788+
return NamedArraySchema(items, names, name, namespace, doc, other_props)
743789
return ArraySchema(items, names, other_props)
744790
elif atype == "map":
745791
values = json_data.get("values")
@@ -748,8 +794,7 @@ def make_avsc_object(json_data: JsonDataType, names: Optional[Names] = None) ->
748794
namespace = json_data.get("namespace", names.default_namespace)
749795
doc = json_data.get("doc")
750796
return NamedMapSchema(values, names, name, namespace, doc, other_props)
751-
else:
752-
return MapSchema(values, names, other_props)
797+
return MapSchema(values, names, other_props)
753798
elif atype == "union":
754799
schemas = json_data.get("names")
755800
if not isinstance(schemas, list):
@@ -761,8 +806,7 @@ def make_avsc_object(json_data: JsonDataType, names: Optional[Names] = None) ->
761806
namespace = json_data.get("namespace", names.default_namespace)
762807
doc = json_data.get("doc")
763808
return NamedUnionSchema(schemas, names, name, namespace, doc)
764-
else:
765-
return UnionSchema(schemas, names)
809+
return UnionSchema(schemas, names)
766810
if atype is None:
767811
raise SchemaParseException(f'No "type" property: {json_data}')
768812
raise SchemaParseException(f"Undefined type: {atype}")

schema_salad/validate.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def friendly(v: Any) -> Any:
9393
"""Format an Avro schema into a pretty-printed representation."""
9494
if isinstance(v, avro.schema.NamedSchema):
9595
return avro_shortname(v.name)
96-
if isinstance(v, avro.schema.ArraySchema):
96+
if isinstance(v, (avro.schema.ArraySchema, avro.schema.NamedArraySchema)):
9797
return f"array of <{friendly(v.items)}>"
9898
if isinstance(v, (avro.schema.MapSchema, avro.schema.NamedMapSchema)):
9999
return f"map of <{friendly(v.values)}>"
@@ -208,7 +208,7 @@ def validate_ex(
208208
)
209209
)
210210
return False
211-
if isinstance(expected_schema, avro.schema.ArraySchema):
211+
if isinstance(expected_schema, (avro.schema.ArraySchema, avro.schema.NamedArraySchema)):
212212
if isinstance(datum, MutableSequence):
213213
for i, d in enumerate(datum):
214214
try:
@@ -259,7 +259,7 @@ def validate_ex(
259259
errors: list[SchemaSaladException] = []
260260
checked = []
261261
for s in expected_schema.schemas:
262-
if isinstance(datum, MutableSequence) and not isinstance(s, avro.schema.ArraySchema):
262+
if isinstance(datum, MutableSequence) and not isinstance(s, (avro.schema.ArraySchema, avro.schema.NamedArraySchema)):
263263
continue
264264
if isinstance(datum, MutableMapping) and not isinstance(
265265
s, (avro.schema.RecordSchema, avro.schema.MapSchema, avro.schema.NamedMapSchema)
@@ -269,6 +269,7 @@ def validate_ex(
269269
s,
270270
(
271271
avro.schema.ArraySchema,
272+
avro.schema.NamedArraySchema,
272273
avro.schema.RecordSchema,
273274
avro.schema.MapSchema,
274275
avro.schema.NamedMapSchema,

0 commit comments

Comments
 (0)