|
| 1 | +# --------------------------------------------------------- |
| 2 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 3 | +# --------------------------------------------------------- |
| 4 | + |
| 5 | +# pylint: disable=unused-argument |
| 6 | + |
| 7 | +from marshmallow import fields, post_load |
| 8 | + |
| 9 | +from azure.ai.ml._schema.assets.data import DataSchema |
| 10 | +from azure.ai.ml._schema.core.fields import ArmVersionedStr, LocalPathField, NestedField, StringTransformedEnum, UnionField |
| 11 | +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta |
| 12 | +from azure.ai.ml._schema.job.input_output_entry import generate_datastore_property |
| 13 | +from azure.ai.ml._utils._experimental import experimental |
| 14 | +from azure.ai.ml.constants._common import AssetTypes, AzureMLResourceType, InputOutputModes |
| 15 | + |
| 16 | + |
| 17 | +# FROM: azure.ai.ml._schema.job.input_output_entry |
def generate_path_property(azureml_type, **kwargs):
    """Build the strict union field used for ``path``-like schema attributes.

    The union is assembled from five candidate fields, in this order:

    1. an ``ArmVersionedStr`` for ``azureml_type``;
    2. a string field carrying an ``http(s):`` pattern in its metadata;
    3. a string field carrying a ``wasb(s):`` pattern in its metadata;
    4. a ``LocalPathField`` with a ``file:`` pattern;
    5. a ``LocalPathField`` whose pattern excludes the ``azureml:``,
       ``http(s):``, ``wasb(s):`` and ``file:`` prefixes.

    :param azureml_type: Resource type handed to ``ArmVersionedStr``.
    :param kwargs: Extra keyword arguments forwarded unchanged to ``UnionField``.
    :return: The ``UnionField`` built with ``is_strict=True``.
    """
    arm_reference = ArmVersionedStr(azureml_type=azureml_type)
    http_url = fields.Str(metadata={"pattern": r"^(http(s)?):.*"})
    wasb_url = fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"})
    file_uri = LocalPathField(pattern=r"^file:.*")
    # Negative lookahead: only values that carry none of the known scheme prefixes.
    unprefixed_path = LocalPathField(pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*")

    candidates = [arm_reference, http_url, wasb_url, file_uri, unprefixed_path]
    return UnionField(candidates, is_strict=True, **kwargs)
| 32 | + |
| 33 | + |
class DataIndexTypes:
    """String constants naming the kinds of index that DataIndex can write to."""

    ACS = "acs"
    """Index type backed by Azure Cognitive Search."""

    FAISS = "faiss"
    """Index type backed by Faiss."""
| 41 | + |
| 42 | + |
class CitationRegexSchema(metaclass=PatchedSchemaMeta):
    """Schema for a regex match/replacement pair that is loaded into a ``CitationRegex``.

    Both fields are required strings.
    """

    match_pattern = fields.Str(
        required=True,
        metadata={
            # The example must be a regular expression with capture groups, since this is the
            # pattern side of the pair.
            "description": r"Regex to match citation in the citation_url + input file path. "
            r"e.g. '(.*)/articles/(.*)(\.[^.]+)$'"
        },
    )
    replacement_pattern = fields.Str(
        required=True,
        # The example is a replacement template that refers back to the captured groups.
        metadata={"description": r"Replacement string for citation. e.g. '\1/\2'"},
    )

    @post_load
    def make(self, data, **kwargs):
        """Build a ``CitationRegex`` from the loaded field values.

        :param data: Loaded field values, passed to ``CitationRegex`` as keyword arguments.
        :param kwargs: Additional keyword arguments supplied by the caller (unused).
        :return: The constructed ``CitationRegex``.
        """
        # Imported at call time rather than at module import.
        from azure.ai.generative.index._dataindex.entities.data_index import CitationRegex

        return CitationRegex(**data)
| 58 | + |
| 59 | + |
class InputDataSchema(metaclass=PatchedSchemaMeta):
    """Schema for an input whose ``type`` is ``AssetTypes.URI_FILE`` or ``AssetTypes.URI_FOLDER``.

    Loading produces an ``azure.ai.ml.entities.Data`` object.
    """

    # Optional; restricted to the read-only mount, read-write mount and download modes.
    mode = StringTransformedEnum(
        required=False,
        allowed_values=[InputOutputModes.RO_MOUNT, InputOutputModes.RW_MOUNT, InputOutputModes.DOWNLOAD],
    )
    type = StringTransformedEnum(allowed_values=[AssetTypes.URI_FILE, AssetTypes.URI_FOLDER])
    path = generate_path_property(azureml_type=AzureMLResourceType.DATA)
    datastore = generate_datastore_property()

    @post_load
    def make(self, data, **kwargs):
        """Build a ``Data`` entity from the loaded field values.

        :param data: Loaded field values, passed to ``Data`` as keyword arguments.
        :param kwargs: Additional keyword arguments supplied by the caller (unused).
        :return: The constructed ``Data`` entity.
        """
        # Imported at call time rather than at module import.
        from azure.ai.ml.entities import Data

        return Data(**data)
| 83 | + |
| 84 | + |
class InputMLTableSchema(metaclass=PatchedSchemaMeta):
    """Schema for an input whose ``type`` is ``AssetTypes.MLTABLE``.

    Loading produces an ``azure.ai.ml.entities.Data`` object.
    """

    # Optional; restricted to the eval-mount and eval-download modes.
    mode = StringTransformedEnum(
        required=False,
        allowed_values=[InputOutputModes.EVAL_MOUNT, InputOutputModes.EVAL_DOWNLOAD],
    )
    type = StringTransformedEnum(allowed_values=[AssetTypes.MLTABLE])
    path = generate_path_property(azureml_type=AzureMLResourceType.DATA)
    datastore = generate_datastore_property()

    @post_load
    def make(self, data, **kwargs):
        """Build a ``Data`` entity from the loaded field values.

        :param data: Loaded field values, passed to ``Data`` as keyword arguments.
        :param kwargs: Additional keyword arguments supplied by the caller (unused).
        :return: The constructed ``Data`` entity.
        """
        # Imported at call time rather than at module import.
        from azure.ai.ml.entities import Data

        return Data(**data)
| 102 | + |
| 103 | + |
class IndexSourceSchema(metaclass=PatchedSchemaMeta):
    """Schema describing where the data to index comes from and how it is chunked and cited.

    Only ``input_data`` is required; loading produces an ``IndexSource``.
    """

    # Accepts either of the two input schemas declared for this module; must be present and non-null.
    input_data = UnionField(
        [NestedField(InputDataSchema), NestedField(InputMLTableSchema)],
        metadata={"description": "Input Data to index files from. MLTable type inputs will use `mode: eval_mount`."},
        allow_none=False,
        required=True,
    )
    input_glob = fields.Str(
        metadata={
            "description": "Glob pattern to filter files from input_data. "
            "If not specified, all files will be indexed."
        },
        required=False,
    )
    chunk_size = fields.Int(
        metadata={"description": "Maximum number of tokens to put in each chunk."},
        allow_none=False,
        required=False,
    )
    chunk_overlap = fields.Int(
        metadata={"description": "Number of tokens to overlap between chunks."},
        allow_none=False,
        required=False,
    )
    citation_url = fields.Str(
        metadata={"description": "Base URL to join with file paths to create full source file URL for chunk metadata."},
        required=False,
    )
    citation_url_replacement_regex = NestedField(
        CitationRegexSchema,
        metadata={
            "description": "Regex match and replacement patterns for citation url. "
            "Useful if the paths in `input_data` don't match the desired citation format."
        },
        required=False,
    )

    @post_load
    def make(self, data, **kwargs):
        """Build an ``IndexSource`` from the loaded field values.

        :param data: Loaded field values, passed to ``IndexSource`` as keyword arguments.
        :param kwargs: Additional keyword arguments supplied by the caller (unused).
        :return: The constructed ``IndexSource``.
        """
        # Imported at call time rather than at module import.
        from azure.ai.generative.index._dataindex.entities.data_index import IndexSource

        return IndexSource(**data)
| 145 | + |
| 146 | + |
class EmbeddingSchema(metaclass=PatchedSchemaMeta):
    """Schema describing the embedding model settings; loading produces an ``Embedding``.

    ``model`` is required; ``connection`` and ``cache_path`` are optional.
    """

    model = fields.Str(
        required=True,
        allow_none=False,
        metadata={
            "description": "The model to use to embed data. E.g. 'hugging_face://model/sentence-transformers/"
            "all-mpnet-base-v2' or 'azure_open_ai://deployment/{{deployment_name}}/model/{{model_name}}'"
        },
    )
    connection = fields.Str(
        required=False,
        metadata={
            "description": "Connection reference to use for embedding model information, "
            "only needed for hosted embeddings models (such as Azure OpenAI)."
        },
    )
    # Same union of path forms as the input schemas, but resolved against the DATASTORE resource type.
    cache_path = generate_path_property(
        azureml_type=AzureMLResourceType.DATASTORE,
        required=False,
        metadata={
            "description": "Folder containing previously generated embeddings. "
            "Should be parent folder of the 'embeddings' output path used for this component. "
            "Will compare input data to existing embeddings and only embed changed/new data, "
            "reusing existing chunks."
        },
    )

    @post_load
    def make(self, data, **kwargs):
        """Build an ``Embedding`` from the loaded field values.

        :param data: Loaded field values, passed to ``Embedding`` as keyword arguments.
        :param kwargs: Additional keyword arguments supplied by the caller (unused).
        :return: The constructed ``Embedding``.
        """
        # Imported at call time rather than at module import.
        from azure.ai.generative.index._dataindex.entities.data_index import Embedding

        return Embedding(**data)
| 179 | + |
| 180 | + |
class IndexStoreSchema(metaclass=PatchedSchemaMeta):
    """Schema describing the index to write to; loading produces an ``IndexStore``.

    ``type`` is limited to the values declared on ``DataIndexTypes``; ``name``, ``connection`` and
    ``config`` are optional.
    """

    type = StringTransformedEnum(
        allowed_values=[
            DataIndexTypes.ACS,
            DataIndexTypes.FAISS,
        ],
        metadata={"description": "The type of index to write to. Currently supported types are 'acs' and 'faiss'."},
    )
    name = fields.Str(
        required=False,
        metadata={"description": "Name of the index to write to. If not specified, a name will be generated."},
    )
    connection = fields.Str(
        required=False,
        metadata={
            "description": "Connection reference to use for index information, "
            "only needed for hosted indexes (such as Azure Cognitive Search)."
        },
    )
    config = fields.Dict(
        required=False,
        metadata={
            # Adjacent literals are concatenated verbatim, so the separating space has to live in the text.
            "description": "Configuration for the index. "
            "Primary use is to configure Azure Cognitive Search specific settings. "
            "Such as custom `field_mapping` for known field types."
        },
    )

    @post_load
    def make(self, data, **kwargs):
        """Build an ``IndexStore`` from the loaded field values.

        :param data: Loaded field values, passed to ``IndexStore`` as keyword arguments.
        :param kwargs: Additional keyword arguments supplied by the caller (unused).
        :return: The constructed ``IndexStore``.
        """
        # Imported at call time rather than at module import.
        from azure.ai.generative.index._dataindex.entities.data_index import IndexStore

        return IndexStore(**data)
| 213 | + |
| 214 | + |
@experimental
class DataIndexSchema(DataSchema):
    """Schema for a data index definition, extending ``DataSchema``; loading produces a ``DataIndex``.

    Adds three mandatory nested sections (``source``, ``embedding``, ``index``) and one boolean flag.
    """

    # Each nested section must be present and may not be null.
    source = NestedField(IndexSourceSchema, allow_none=False, required=True)
    embedding = NestedField(EmbeddingSchema, allow_none=False, required=True)
    index = NestedField(IndexStoreSchema, allow_none=False, required=True)
    # Declared with the field's default options.
    incremental_update = fields.Bool()

    @post_load
    def make(self, data, **kwargs):
        """Build a ``DataIndex`` from the loaded field values.

        :param data: Loaded field values, passed to ``DataIndex`` as keyword arguments.
        :param kwargs: Additional keyword arguments supplied by the caller (unused).
        :return: The constructed ``DataIndex``.
        """
        # Imported at call time rather than at module import.
        from azure.ai.generative.index._dataindex.entities.data_index import DataIndex

        return DataIndex(**data)
0 commit comments