This repository was archived by the owner on May 27, 2025. It is now read-only.

Multi-filetype Support via Markitdown #269

Merged: 20 commits on Apr 8, 2025
Changes from all commits
135 changes: 77 additions & 58 deletions backend/graphrag_app/api/data.py
@@ -2,6 +2,7 @@
# Licensed under the MIT License.

import asyncio
import hashlib
import os
import re
import traceback
@@ -14,20 +15,25 @@
Depends,
HTTPException,
UploadFile,
status,
)
from markitdown import MarkItDown, StreamInfo

from graphrag_app.logger.load_logger import load_pipeline_logger
from graphrag_app.typing.models import (
BaseResponse,
StorageNameList,
)
from graphrag_app.utils.common import (
check_cache,
create_cache,
delete_cosmos_container_item_if_exist,
delete_storage_container_if_exist,
get_blob_container_client,
get_cosmos_container_store_client,
sanitize_name,
subscription_key_check,
update_cache,
)

data_route = APIRouter(
@@ -42,7 +48,7 @@
"",
summary="Get list of data containers.",
response_model=StorageNameList,
responses={200: {"model": StorageNameList}},
responses={status.HTTP_200_OK: {"model": StorageNameList}},
)
async def get_all_data_containers():
"""
@@ -67,56 +73,66 @@ async def get_all_data_containers():
return StorageNameList(storage_name=items)


async def upload_file_async(
async def upload_file(
upload_file: UploadFile, container_client: ContainerClient, overwrite: bool = True
) -> None:
):
"""
Asynchronously upload a file to the specified blob container.
Silently ignore errors that occur when overwrite=False.
Convert and upload a file to a specified blob container.

Returns one of the following:
* Tuple[str, str] - (filename, file_hash) for a successful conversion and upload
* Tuple[str, None] - (filename, None) for a failed conversion or upload
* None - for a skipped (already converted) file
"""
blob_client = container_client.get_blob_client(upload_file.filename)
filename = upload_file.filename
extension = os.path.splitext(filename)[1]
converted_filename = filename + ".txt"
converted_blob_client = container_client.get_blob_client(converted_filename)

with upload_file.file as file_stream:
try:
await blob_client.upload_blob(file_stream, overwrite=overwrite)
file_hash = hashlib.sha256(file_stream.read()).hexdigest()
if not await check_cache(file_hash, container_client):
# extract text from file using MarkItDown
md = MarkItDown()
stream_info = StreamInfo(
extension=extension,
)
file_stream._file.seek(0)
file_stream = file_stream._file
result = md.convert_stream(
stream=file_stream,
stream_info=stream_info,
)

# remove illegal unicode characters and upload to blob storage
cleaned_result = _clean_output(result.text_content)
await converted_blob_client.upload_blob(
cleaned_result, overwrite=overwrite
)

# return tuple of (filename, file_hash) to indicate success
return (filename, file_hash)
except Exception:
pass

# if any exception occurs, return a tuple of (filename, None) to indicate conversion/upload failure
return (upload_file.filename, None)

class Cleaner:
def __init__(self, file):
self.file = file
self.name = file.name
self.changes = 0

def clean(self, val, replacement=""):
# fmt: off
_illegal_xml_chars_RE = re.compile(
def _clean_output(val: str, replacement: str = ""):
"""Removes unicode characters that are invalid XML characters (not valid for graphml files at least)."""
# fmt: off
_illegal_xml_chars_RE = re.compile(
"[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
)
# fmt: on
self.changes += len(_illegal_xml_chars_RE.findall(val))
return _illegal_xml_chars_RE.sub(replacement, val)

def read(self, n):
return self.clean(self.file.read(n).decode()).encode(
encoding="utf-8", errors="strict"
)

def name(self):
return self.file.name

def __enter__(self):
return self

def __exit__(self, *args):
self.file.close()
# fmt: on
return _illegal_xml_chars_RE.sub(replacement, val)


@data_route.post(
"",
summary="Upload data to a data storage container",
response_model=BaseResponse,
responses={200: {"model": BaseResponse}},
responses={status.HTTP_201_CREATED: {"model": BaseResponse}},
)
async def upload_files(
files: List[UploadFile],
@@ -125,36 +141,33 @@ async def upload_files(
overwrite: bool = True,
):
"""
Create a Azure Storage container and upload files to it.

Args:
files (List[UploadFile]): A list of files to be uploaded.
storage_name (str): The name of the Azure Blob Storage container to which files will be uploaded.
overwrite (bool): Whether to overwrite existing files with the same name. Defaults to True. If False, files that already exist will be skipped.

Returns:
BaseResponse: An instance of the BaseResponse model with a status message indicating the result of the upload.

Raises:
HTTPException: If the container name is invalid or if any error occurs during the upload process.
Create an Azure Storage container (if needed) and upload files to it. Multiple file types are supported, including pdf, powerpoint, word, excel, html, csv, json, and xml.
The complete set of supported file types can be found in the MarkItDown (https://github.com/microsoft/markitdown) library.
"""
try:
# clean files - remove illegal XML characters
files = [UploadFile(Cleaner(f.file), filename=f.filename) for f in files]

# upload files in batches of 1000 to avoid exceeding Azure Storage API limits
# create the initial cache if it doesn't exist
blob_container_client = await get_blob_container_client(
sanitized_container_name
)
batch_size = 1000
await create_cache(blob_container_client)

# process file uploads in batches to avoid exceeding Azure Storage API limits
processing_errors = []
batch_size = 100
num_batches = ceil(len(files) / batch_size)
for i in range(num_batches):
batch_files = files[i * batch_size : (i + 1) * batch_size]
tasks = [
upload_file_async(file, blob_container_client, overwrite)
upload_file(file, blob_container_client, overwrite)
for file in batch_files
]
await asyncio.gather(*tasks)
upload_results = await asyncio.gather(*tasks)
successful_uploads = [r for r in upload_results if r and r[1] is not None]
# update the file cache with successful uploads
await update_cache(successful_uploads, blob_container_client)
# collect failed uploads
failed_uploads = [r[0] for r in upload_results if r and r[1] is None]
processing_errors.extend(failed_uploads)

# update container-store entry in cosmosDB once upload process is successful
cosmos_container_store_client = get_cosmos_container_store_client()
@@ -163,17 +176,23 @@ async def upload_files(
"human_readable_name": container_name,
"type": "data",
})
return BaseResponse(status="File upload successful.")

if len(processing_errors) > 0:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Error uploading files: {processing_errors}.",
)
return BaseResponse(status="Success.")
except Exception as e:
logger = load_pipeline_logger()
logger.error(
message="Error uploading files.",
cause=e,
stack=traceback.format_exc(),
details={"files": [f.filename for f in files]},
details={"files": processing_errors},
)
raise HTTPException(
status_code=500,
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error uploading files to container '{container_name}'.",
)

@@ -182,7 +201,7 @@
"/{container_name}",
summary="Delete a data storage container",
response_model=BaseResponse,
responses={200: {"model": BaseResponse}},
responses={status.HTTP_200_OK: {"model": BaseResponse}},
)
async def delete_files(
container_name: str, sanitized_container_name: str = Depends(sanitize_name)
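The substantive change in `data.py` is the hash-then-convert flow inside `upload_file`: hash the raw upload with SHA-256, consult the cache so already-converted files are skipped, convert the stream to text with MarkItDown, scrub characters that are illegal in XML/GraphML, and upload the converted `.txt` blob. The sketch below isolates that flow outside FastAPI and Azure; the local `set` standing in for the blob-backed `check_cache`/`update_cache` helpers and the plain file path standing in for `UploadFile` are assumptions for illustration, not part of this PR.

```python
import hashlib
import io
import os
import re

from markitdown import MarkItDown, StreamInfo

# Stand-in for the blob-backed cache behind check_cache/update_cache (assumption).
_converted_hashes: set[str] = set()

# Same character class used by _clean_output above.
_ILLEGAL_XML_CHARS_RE = re.compile(
    "[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
)


def convert_if_new(path: str):
    """Return (filename, sha256) on success, (filename, None) on failure, None if skipped."""
    filename = os.path.basename(path)
    try:
        with open(path, "rb") as f:
            data = f.read()

        file_hash = hashlib.sha256(data).hexdigest()
        if file_hash in _converted_hashes:
            return None  # already converted on a previous run

        # Convert the original document (pdf, docx, xlsx, html, ...) to text.
        result = MarkItDown().convert_stream(
            stream=io.BytesIO(data),
            stream_info=StreamInfo(extension=os.path.splitext(path)[1]),
        )

        # Drop characters that are illegal in XML/GraphML and write the .txt twin.
        cleaned = _ILLEGAL_XML_CHARS_RE.sub("", result.text_content)
        with open(path + ".txt", "w", encoding="utf-8") as out:
            out.write(cleaned)

        _converted_hashes.add(file_hash)
        return (filename, file_hash)
    except Exception:
        # Mirror the endpoint's behavior: report the filename with no hash on failure.
        return (filename, None)
```

In the endpoint itself these per-file calls are fanned out with `asyncio.gather` in batches of 100 to stay within Azure Storage API limits, the successful `(filename, file_hash)` pairs are written back to the cache, and any `(filename, None)` results are surfaced as an HTTP 422 listing the failed files.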
2 changes: 2 additions & 0 deletions backend/graphrag_app/api/graph.py
@@ -8,6 +8,7 @@
APIRouter,
Depends,
HTTPException,
status,
)
from fastapi.responses import StreamingResponse

@@ -31,6 +32,7 @@
"/graphml/{container_name}",
summary="Retrieve a GraphML file of the knowledge graph",
response_description="GraphML file successfully downloaded",
status_code=status.HTTP_200_OK,
)
async def get_graphml_file(
container_name, sanitized_container_name: str = Depends(sanitize_name)
20 changes: 12 additions & 8 deletions backend/graphrag_app/api/index.py
@@ -12,6 +12,7 @@
Depends,
HTTPException,
UploadFile,
status,
)
from kubernetes import (
client as kubernetes_client,
@@ -49,7 +50,7 @@
"",
summary="Build an index",
response_model=BaseResponse,
responses={200: {"model": BaseResponse}},
responses={status.HTTP_202_ACCEPTED: {"model": BaseResponse}},
)
async def schedule_index_job(
storage_container_name: str,
@@ -71,7 +72,7 @@ async def schedule_index_job(
sanitized_storage_container_name
).exists():
raise HTTPException(
status_code=500,
status_code=status.HTTP_412_PRECONDITION_FAILED,
detail=f"Storage container '{storage_container_name}' does not exist",
)

@@ -101,7 +102,7 @@ async def schedule_index_job(
PipelineJobState(existing_job.status) == PipelineJobState.RUNNING
):
raise HTTPException(
status_code=202, # request has been accepted for processing but is not complete.
status_code=status.HTTP_425_TOO_EARLY, # an existing indexing job has not finished building; try again later.
detail=f"Index '{index_container_name}' already exists and has not finished building.",
)
# if indexing job is in a failed state, delete the associated K8s job and pod to allow for a new job to be scheduled
@@ -142,7 +143,7 @@ async def schedule_index_job(
"",
summary="Get all index names",
response_model=IndexNameList,
responses={200: {"model": IndexNameList}},
responses={status.HTTP_200_OK: {"model": IndexNameList}},
)
async def get_all_index_names(
container_store_client=Depends(get_cosmos_container_store_client),
@@ -218,7 +219,7 @@ def _delete_k8s_job(job_name: str, namespace: str) -> None:
"/{container_name}",
summary="Delete a specified index",
response_model=BaseResponse,
responses={200: {"model": BaseResponse}},
responses={status.HTTP_200_OK: {"model": BaseResponse}},
)
async def delete_index(
container_name: str,
@@ -257,7 +258,8 @@ async def delete_index(
details={"container": container_name},
)
raise HTTPException(
status_code=500, detail=f"Error deleting '{container_name}'."
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error deleting '{container_name}'.",
)

return BaseResponse(status="Success")
@@ -267,6 +269,7 @@ async def delete_index(
"/status/{container_name}",
summary="Track the status of an indexing job",
response_model=IndexStatusResponse,
status_code=status.HTTP_200_OK,
)
async def get_index_status(
container_name: str, sanitized_container_name: str = Depends(sanitize_name)
@@ -275,7 +278,7 @@ async def get_index_status(
if pipelinejob.item_exist(sanitized_container_name):
pipeline_job = pipelinejob.load_item(sanitized_container_name)
return IndexStatusResponse(
status_code=200,
status_code=status.HTTP_200_OK,
index_name=pipeline_job.human_readable_index_name,
storage_name=pipeline_job.human_readable_storage_name,
status=pipeline_job.status.value,
@@ -284,5 +287,6 @@
)
else:
raise HTTPException(
status_code=404, detail=f"'{container_name}' does not exist."
status_code=status.HTTP_404_NOT_FOUND,
detail=f"'{container_name}' does not exist.",
)
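The index routes now report precise status codes instead of a blanket 500: 412 when the referenced storage container does not exist, 425 when a previous build of the same index is still running, and 404 when status is requested for an index that does not exist. A small client-side sketch of consuming the status endpoint under these semantics; the base URL and the `/index/status/{container_name}` path are assumptions (route prefixes are not visible in this diff), while the 200-vs-404 contract and the `index_name`/`storage_name`/`status` fields come from the code above.

```python
import requests

BASE_URL = "https://graphrag.example.com"  # placeholder; not part of this diff
INDEX = "my-index"                         # hypothetical index container name

# The route path is an assumption; the 200-vs-404 behavior is what this PR pins down.
resp = requests.get(f"{BASE_URL}/index/status/{INDEX}", timeout=30)

if resp.status_code == 404:
    print(f"Index '{INDEX}' does not exist.")
else:
    resp.raise_for_status()
    body = resp.json()
    # Fields shown in the IndexStatusResponse construction above.
    print(body["index_name"], body["storage_name"], body["status"])
```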
2 changes: 2 additions & 0 deletions backend/graphrag_app/api/prompt_tuning.py
@@ -11,6 +11,7 @@
APIRouter,
Depends,
HTTPException,
status,
)
from graphrag.config.create_graphrag_config import create_graphrag_config

@@ -27,6 +28,7 @@
"/prompts",
summary="Generate custom graphrag prompts based on user-provided data.",
description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
status_code=status.HTTP_200_OK,
)
async def generate_prompts(
container_name: str,
9 changes: 5 additions & 4 deletions backend/graphrag_app/api/query.py
@@ -10,6 +10,7 @@
APIRouter,
Depends,
HTTPException,
status,
)
from graphrag.api.query import global_search, local_search
from graphrag.config.create_graphrag_config import create_graphrag_config
@@ -42,7 +43,7 @@
summary="Perform a global search across the knowledge graph index",
description="The global query method generates answers by searching over all AI-generated community reports in a map-reduce fashion. This is a resource-intensive method, but often gives good responses for questions that require an understanding of the dataset as a whole.",
response_model=GraphResponse,
responses={200: {"model": GraphResponse}},
responses={status.HTTP_200_OK: {"model": GraphResponse}},
)
async def global_query(request: GraphRequest):
# this is a slightly modified version of the graphrag.query.cli.run_global_search method
@@ -51,7 +52,7 @@ async def global_query(request: GraphRequest):

if not _is_index_complete(sanitized_index_name):
raise HTTPException(
status_code=500,
status_code=status.HTTP_425_TOO_EARLY,
detail=f"{index_name} not ready for querying.",
)

@@ -122,15 +123,15 @@ async def global_query(request: GraphRequest):
summary="Perform a local search across the knowledge graph index.",
description="The local query method generates answers by combining relevant data from the AI-extracted knowledge-graph with text chunks of the raw documents. This method is suitable for questions that require an understanding of specific entities mentioned in the documents (e.g. What are the healing properties of chamomile?).",
response_model=GraphResponse,
responses={200: {"model": GraphResponse}},
responses={status.HTTP_200_OK: {"model": GraphResponse}},
)
async def local_query(request: GraphRequest):
index_name = request.index_name
sanitized_index_name = sanitize_name(index_name)

if not _is_index_complete(sanitized_index_name):
raise HTTPException(
status_code=500,
status_code=status.HTTP_425_TOO_EARLY,
detail=f"{index_name} not ready for querying.",
)

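Both query endpoints now answer HTTP 425 (Too Early) rather than 500 when the target index has not finished building, so a caller can tell "retry later" apart from a genuine server fault. A hedged retry sketch follows; the `/query/global` path and the `query` field on the request body are assumptions (only `index_name` and the 425 behavior are visible in this diff).

```python
import time

import requests

BASE_URL = "https://graphrag.example.com"  # placeholder deployment URL


def global_query(index_name: str, query: str, retries: int = 5) -> dict:
    """POST a global query, backing off while the index is still building (HTTP 425)."""
    payload = {"index_name": index_name, "query": query}  # 'query' field is an assumption
    for attempt in range(retries):
        resp = requests.post(f"{BASE_URL}/query/global", json=payload, timeout=300)
        if resp.status_code == 425:
            # Index exists but has not finished building; wait and try again.
            time.sleep(30 * (attempt + 1))
            continue
        resp.raise_for_status()
        return resp.json()
    raise TimeoutError(f"Index '{index_name}' was still building after {retries} attempts.")
```

A simple linear backoff like this keeps clients from hammering the service while an index build completes; the interval is arbitrary.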