Add migration script for Azure Cosmos DB, old container to new container #2442

Open · wants to merge 4 commits into base: main
162 changes: 162 additions & 0 deletions scripts/cosmosdb_migration.py
@@ -0,0 +1,162 @@
"""
A migration script that copies chat history data from the old CosmosDB container to the new container format.
The old schema:
    id: str
    entra_oid: str
    title: str
    timestamp: int
    answers: list of [question: str, response: dict] pairs

The new schema has two item types in the same container:
For session items:
    id: str
    session_id: str
    entra_oid: str
    title: str
    timestamp: int
    type: str (always "session")
    version: str (always "cosmosdb-v2")

For message_pair items:
    id: str
    session_id: str
    entra_oid: str
    type: str (always "message_pair")
    version: str (always "cosmosdb-v2")
    question: str
    response: dict
"""

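# Illustrative mapping example (values are made up, not used by this script):
# an old item like
#     {"id": "abc", "entra_oid": "oid1", "title": "Hi", "timestamp": 1,
#      "answers": [["What is in my benefits?", {"message": {"content": "..."}}]]}
# becomes one session item plus one message_pair item in the new container:
#     {"id": "abc", "session_id": "abc", "entra_oid": "oid1", "title": "Hi",
#      "timestamp": 1, "type": "session", "version": "cosmosdb-v2"}
#     {"id": "abc-0", "session_id": "abc", "entra_oid": "oid1", "type": "message_pair",
#      "question": "What is in my benefits?", "response": {"message": {"content": "..."}},
#      "order": 0, "timestamp": None, "version": "cosmosdb-v2"}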
import os

from azure.cosmos.aio import CosmosClient
from azure.identity.aio import AzureDeveloperCliCredential

from load_azd_env import load_azd_env


class CosmosDBMigrator:
    """
    Migrator class for CosmosDB data migration.
    """

    def __init__(self, cosmos_account, database_name, credential=None):
        """
        Initialize the migrator with CosmosDB account and database.

        Args:
            cosmos_account: CosmosDB account name
            database_name: Database name
            credential: Azure credential, defaults to AzureDeveloperCliCredential
        """
        self.cosmos_account = cosmos_account
        self.database_name = database_name
        self.credential = credential or AzureDeveloperCliCredential()
        self.client = None
        self.database = None
        self.old_container = None
        self.new_container = None

    async def connect(self):
        """
        Connect to CosmosDB and initialize containers.
        """
        self.client = CosmosClient(
            url=f"https://{self.cosmos_account}.documents.azure.com:443/", credential=self.credential
        )
        self.database = self.client.get_database_client(self.database_name)
        self.old_container = self.database.get_container_client("chat-history")
        self.new_container = self.database.get_container_client("chat-history-v2")
        try:
            await self.old_container.read()
        except Exception:
            raise ValueError(f"Old container {self.old_container.id} does not exist")
        try:
            await self.new_container.read()
        except Exception:
            raise ValueError(f"New container {self.new_container.id} does not exist")
Collaborator: Does the migration script need to create the container?

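A minimal sketch of what creating the missing container inside connect() could look like, assuming chat-history-v2 uses a hierarchical partition key on /entra_oid and /session_id (implied by the execute_item_batch call below); this is illustrative, not part of the PR:

from azure.cosmos import PartitionKey

# Hypothetical alternative to raising ValueError: create the v2 container if it does not exist yet.
self.new_container = await self.database.create_container_if_not_exists(
    id="chat-history-v2",
    partition_key=PartitionKey(path=["/entra_oid", "/session_id"], kind="MultiHash"),
)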

    async def migrate(self):
        """
        Migrate data from old schema to new schema.
        """
        if not self.client:
            await self.connect()
        if not self.old_container or not self.new_container:
            raise ValueError("Containers do not exist")

        query_results = self.old_container.query_items(query="SELECT * FROM c")
Collaborator: We SELECT * but we only copy specific fields. IMO either:

  1. Fetch all columns and copy all columns
  2. Only select the columns you want to copy (see the sketch after this list)

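A minimal sketch of option 2, assuming the migration only ever reads the fields it copies today, so the projection keeps the query and the copied data in sync:

query_results = self.old_container.query_items(
    query="SELECT c.id, c.entra_oid, c.title, c.timestamp, c.answers FROM c"
)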

        item_migration_count = 0
        async for page in query_results.by_page():
            async for old_item in page:
                batch_operations = []
                # Build session item
                session_item = {
                    "id": old_item["id"],
                    "version": "cosmosdb-v2",
                    "session_id": old_item["id"],
                    "entra_oid": old_item["entra_oid"],
                    "title": old_item.get("title"),
                    "timestamp": old_item.get("timestamp"),
                    "type": "session",
                }
                batch_operations.append(("upsert", (session_item,)))

                # Build message_pair
                answers = old_item.get("answers", [])
                for idx, answer in enumerate(answers):
                    question = answer[0]
                    response = answer[1]
                    message_pair = {
                        "id": f"{old_item['id']}-{idx}",
                        "version": "cosmosdb-v2",
                        "session_id": old_item["id"],
                        "entra_oid": old_item["entra_oid"],
                        "type": "message_pair",
                        "question": question,
                        "response": response,
                        "order": idx,
                        "timestamp": None,
                    }
                    batch_operations.append(("upsert", (message_pair,)))

                # Execute the batch using partition key [entra_oid, session_id]
Collaborator: Comment is wrong? We are using [entra_oid, id].

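Worth noting for that thread: the new session item sets session_id to the old item's id, so the two spellings refer to the same values; an equivalent, more explicit form (illustrative only) would be:

partition_key=[session_item["entra_oid"], session_item["session_id"]]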
                await self.new_container.execute_item_batch(
                    batch_operations=batch_operations, partition_key=[old_item["entra_oid"], old_item["id"]]
                )
                item_migration_count += 1
        print(f"Total items migrated: {item_migration_count}")

    async def close(self):
        """
        Close the CosmosDB client.
        """
        if self.client:
            await self.client.close()


async def migrate_cosmosdb_data():
    """
    Legacy function for backward compatibility.
    Migrate data from CosmosDB to a new format.
    """
    USE_CHAT_HISTORY_COSMOS = os.getenv("USE_CHAT_HISTORY_COSMOS", "").lower() == "true"
    if not USE_CHAT_HISTORY_COSMOS:
        raise ValueError("USE_CHAT_HISTORY_COSMOS must be set to true")
    AZURE_COSMOSDB_ACCOUNT = os.environ["AZURE_COSMOSDB_ACCOUNT"]
    AZURE_CHAT_HISTORY_DATABASE = os.environ["AZURE_CHAT_HISTORY_DATABASE"]

    migrator = CosmosDBMigrator(AZURE_COSMOSDB_ACCOUNT, AZURE_CHAT_HISTORY_DATABASE)
    try:
        await migrator.migrate()
    finally:
        await migrator.close()


if __name__ == "__main__":
    load_azd_env()

    import asyncio

    asyncio.run(migrate_cosmosdb_data())
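For reference, assuming load_azd_env.py sits alongside this script and the azd environment defines USE_CHAT_HISTORY_COSMOS, AZURE_COSMOSDB_ACCOUNT, and AZURE_CHAT_HISTORY_DATABASE, the migration can be run from the repository root with:

python scripts/cosmosdb_migration.py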
192 changes: 192 additions & 0 deletions tests/test_cosmosdb_migration.py
@@ -0,0 +1,192 @@
import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from scripts.cosmosdb_migration import CosmosDBMigrator, migrate_cosmosdb_data

# Sample old format item
TEST_OLD_ITEM = {
    "id": "123",
    "entra_oid": "OID_X",
    "title": "This is a test message",
    "timestamp": 123456789,
    "answers": [
        [
            "What does a Product Manager do?",
            {
                "delta": {"role": "assistant"},
                "session_state": "143c0240-b2ee-4090-8e90-2a1c58124894",
                "message": {
                    "content": "A Product Manager is responsible for product strategy and execution.",
                    "role": "assistant",
                },
            },
        ],
        [
            "What about a Software Engineer?",
            {
                "delta": {"role": "assistant"},
                "session_state": "243c0240-b2ee-4090-8e90-2a1c58124894",
                "message": {
                    "content": "A Software Engineer writes code to create applications.",
                    "role": "assistant",
                },
            },
        ],
    ],
}


class MockAsyncPageIterator:
    """Helper class to mock an async page from CosmosDB"""

    def __init__(self, items):
        self.items = items

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.items:
            raise StopAsyncIteration
        return self.items.pop(0)


class MockCosmosDBResultsIterator:
    """Helper class to mock a paginated query result from CosmosDB"""

    def __init__(self, data=[]):
        self.data = data
        self.continuation_token = None

    def by_page(self, continuation_token=None):
        """Return a paged iterator"""
        self.continuation_token = "next_token" if not continuation_token else continuation_token + "_next"
        # Return an async iterator that contains pages
        return MockPagesAsyncIterator(self.data)


class MockPagesAsyncIterator:
    """Helper class to mock an iterator of pages"""

    def __init__(self, data):
        self.data = data
        self.continuation_token = "next_token"

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.data:
            raise StopAsyncIteration
        # Return a page, which is an async iterator of items
        return MockAsyncPageIterator([self.data.pop(0)])


@pytest.mark.asyncio
async def test_migrate_method():
    """Test the migrate method of CosmosDBMigrator"""
    # Create mock objects
    mock_container = MagicMock()
    mock_database = MagicMock()
    mock_client = MagicMock()

    # Set up the query_items mock to return our test item
    mock_container.query_items.return_value = MockCosmosDBResultsIterator([TEST_OLD_ITEM])

    # Set up execute_item_batch as a spy to capture calls
    execute_batch_mock = AsyncMock()
    mock_container.execute_item_batch = execute_batch_mock

    # Set up the database mock to return our container mocks
    mock_database.get_container_client.side_effect = lambda container_name: mock_container

    # Set up the client mock
    mock_client.get_database_client.return_value = mock_database

    # Create the migrator with our mocks
    migrator = CosmosDBMigrator("dummy_account", "dummy_db")
    migrator.client = mock_client
    migrator.database = mock_database
    migrator.old_container = mock_container
    migrator.new_container = mock_container

    # Call the migrate method
    await migrator.migrate()

    # Verify query_items was called with the right parameters
    mock_container.query_items.assert_called_once_with(query="SELECT * FROM c")

    # Verify execute_item_batch was called
    execute_batch_mock.assert_called_once()

    # Extract the arguments from the call
    call_args = execute_batch_mock.call_args[1]
    batch_operations = call_args["batch_operations"]
    partition_key = call_args["partition_key"]

    # Verify the partition key
    assert partition_key == ["OID_X", "123"]

    # We should have 3 operations: 1 for session and 2 for message pairs
    assert len(batch_operations) == 3

    # Verify session item
    session_operation = batch_operations[0]
    assert session_operation[0] == "upsert"
    session_item = session_operation[1][0]
    assert session_item["id"] == "123"
    assert session_item["session_id"] == "123"
    assert session_item["entra_oid"] == "OID_X"
    assert session_item["title"] == "This is a test message"
    assert session_item["timestamp"] == 123456789
    assert session_item["type"] == "session"
    assert session_item["version"] == "cosmosdb-v2"

    # Verify first message pair
    message1_operation = batch_operations[1]
    assert message1_operation[0] == "upsert"
    message1_item = message1_operation[1][0]
    assert message1_item["id"] == "123-0"
    assert message1_item["session_id"] == "123"
    assert message1_item["entra_oid"] == "OID_X"
    assert message1_item["question"] == "What does a Product Manager do?"
    assert message1_item["type"] == "message_pair"
    assert message1_item["order"] == 0

    # Verify second message pair
    message2_operation = batch_operations[2]
    assert message2_operation[0] == "upsert"
    message2_item = message2_operation[1][0]
    assert message2_item["id"] == "123-1"
    assert message2_item["session_id"] == "123"
    assert message2_item["entra_oid"] == "OID_X"
    assert message2_item["question"] == "What about a Software Engineer?"
    assert message2_item["type"] == "message_pair"
    assert message2_item["order"] == 1


@pytest.mark.asyncio
async def test_migrate_cosmosdb_data(monkeypatch):
    """Test the main migrate_cosmosdb_data function"""
    with patch.dict(os.environ, clear=True):
        monkeypatch.setenv("USE_CHAT_HISTORY_COSMOS", "true")
        monkeypatch.setenv("AZURE_COSMOSDB_ACCOUNT", "dummy_account")
        monkeypatch.setenv("AZURE_CHAT_HISTORY_DATABASE", "dummy_db")

        # Create a mock for the CosmosDBMigrator
        with patch("scripts.cosmosdb_migration.CosmosDBMigrator") as mock_migrator_class:
            # Set up the mock for the migrator instance
            mock_migrator = AsyncMock()
            mock_migrator_class.return_value = mock_migrator

            # Call the function
            await migrate_cosmosdb_data()

            # Verify the migrator was created with the right parameters
            mock_migrator_class.assert_called_once_with("dummy_account", "dummy_db")

            # Verify migrate and close were called
            mock_migrator.migrate.assert_called_once()
            mock_migrator.close.assert_called_once()