diff --git a/src/mcp_youtube_transcript/server.py b/src/mcp_youtube_transcript/server.py index b98b278..d79f41f 100644 --- a/src/mcp_youtube_transcript/server.py +++ b/src/mcp_youtube_transcript/server.py @@ -39,11 +39,14 @@ def get_transcript( ) -> str: """Retrieves the transcript of a YouTube video.""" parsed_url = urlparse(url) - query_params = parse_qs(parsed_url.query) - video_id = query_params.get("v", [None])[0] - if video_id is None: - raise ValueError(f"couldn't find a video ID from the provided URL: {url}.") + if parsed_url.hostname == "youtu.be": + video_id = parsed_url.path.lstrip("/") + else: + q = parse_qs(parsed_url.query).get("v") + if q is None: + raise ValueError(f"couldn't find a video ID from the provided URL: {url}.") + video_id = q[0] if lang == "en": languages = ["en"] diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 2b551ea..fb16a99 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -101,3 +101,19 @@ async def test_get_transcript_invalid_url(mcp_client_session: ClientSession) -> async def test_get_transcript_not_found(mcp_client_session: ClientSession) -> None: res = await mcp_client_session.call_tool("get_transcript", arguments={"url": "https//www.youtube.com/watch?v=a"}) assert res.isError + + +@pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping this test on CI") +@pytest.mark.anyio +async def test_get_transcript_with_short_url(mcp_client_session: ClientSession) -> None: + video_id = "LPZh9BOjkQs" + + expect = "\n".join((item.text for item in YouTubeTranscriptApi().fetch(video_id))) + + res = await mcp_client_session.call_tool( + "get_transcript", + arguments={"url": f"https://youtu.be/{video_id}"}, + ) + assert isinstance(res.content[0], TextContent) + assert res.content[0].text == expect + assert not res.isError