diff --git a/spec/unit/webrtc/call.spec.ts b/spec/unit/webrtc/call.spec.ts index bcded3fb78e..8a2e255e811 100644 --- a/spec/unit/webrtc/call.spec.ts +++ b/spec/unit/webrtc/call.spec.ts @@ -82,17 +82,34 @@ class MockRTCPeerConnection { } close() {} getStats() { return []; } + addTrack(track: MockMediaStreamTrack) {return new MockRTCRtpSender(track);} +} + +class MockRTCRtpSender { + constructor(public track: MockMediaStreamTrack) {} + + replaceTrack(track: MockMediaStreamTrack) {this.track = track;} +} + +class MockMediaStreamTrack { + constructor(public readonly id: string, public readonly kind: "audio" | "video", public enabled = true) {} + + stop() {} } class MockMediaStream { constructor( public id: string, + private tracks: MockMediaStreamTrack[] = [], ) {} - getTracks() { return []; } - getAudioTracks() { return [{ enabled: true }]; } - getVideoTracks() { return [{ enabled: true }]; } + getTracks() { return this.tracks; } + getAudioTracks() { return this.tracks.filter((track) => track.kind === "audio"); } + getVideoTracks() { return this.tracks.filter((track) => track.kind === "video"); } addEventListener() {} + removeEventListener() { } + addTrack(track: MockMediaStreamTrack) {this.tracks.push(track);} + removeTrack(track: MockMediaStreamTrack) {this.tracks.splice(this.tracks.indexOf(track), 1);} } class MockMediaDeviceInfo { @@ -102,7 +119,13 @@ class MockMediaDeviceInfo { } class MockMediaHandler { - getUserMediaStream() { return new MockMediaStream("mock_stream_from_media_handler"); } + getUserMediaStream(audio: boolean, video: boolean) { + const tracks = []; + if (audio) tracks.push(new MockMediaStreamTrack("audio_track", "audio")); + if (video) tracks.push(new MockMediaStreamTrack("video_track", "video")); + + return new MockMediaStream("mock_stream_from_media_handler", tracks); + } stopUserMediaStream() {} } @@ -350,7 +373,15 @@ describe('Call', function() { }, }); - call.pushRemoteFeed(new MockMediaStream("remote_stream")); + call.pushRemoteFeed( + new MockMediaStream( + "remote_stream", + [ + new MockMediaStreamTrack("remote_audio_track", "audio"), + new MockMediaStreamTrack("remote_video_track", "video"), + ], + ), + ); const feed = call.getFeeds().find((feed) => feed.stream.id === "remote_stream"); expect(feed?.purpose).toBe(SDPStreamMetadataPurpose.Usermedia); expect(feed?.isAudioMuted()).toBeTruthy(); @@ -396,4 +427,82 @@ describe('Call', function() { expect(client.client.mediaHandler.getUserMediaStream).toHaveBeenNthCalledWith(1, true, true); expect(client.client.mediaHandler.getUserMediaStream).toHaveBeenNthCalledWith(2, true, false); }); + + it("should handle mid-call device changes", async () => { + client.client.mediaHandler.getUserMediaStream = jest.fn().mockReturnValue( + new MockMediaStream( + "stream", [ + new MockMediaStreamTrack("audio_track", "audio"), + new MockMediaStreamTrack("video_track", "video"), + ], + ), + ); + + const callPromise = call.placeVideoCall(); + await client.httpBackend.flush(); + await callPromise; + + await call.onAnswerReceived({ + getContent: () => { + return { + version: 1, + call_id: call.callId, + party_id: 'party_id', + answer: { + sdp: DUMMY_SDP, + }, + }; + }, + }); + + await call.updateLocalUsermediaStream( + new MockMediaStream( + "replacement_stream", + [ + new MockMediaStreamTrack("new_audio_track", "audio"), + new MockMediaStreamTrack("video_track", "video"), + ], + ), + ); + expect(call.localUsermediaStream.id).toBe("stream"); + expect(call.localUsermediaStream.getAudioTracks()[0].id).toBe("new_audio_track"); + expect(call.localUsermediaStream.getVideoTracks()[0].id).toBe("video_track"); + expect(call.usermediaSenders.find((sender) => { + return sender?.track?.kind === "audio"; + }).track.id).toBe("new_audio_track"); + expect(call.usermediaSenders.find((sender) => { + return sender?.track?.kind === "video"; + }).track.id).toBe("video_track"); + }); + + it("should handle upgrade to video call", async () => { + const callPromise = call.placeVoiceCall(); + await client.httpBackend.flush(); + await callPromise; + + await call.onAnswerReceived({ + getContent: () => { + return { + version: 1, + call_id: call.callId, + party_id: 'party_id', + answer: { + sdp: DUMMY_SDP, + }, + [SDPStreamMetadataKey]: {}, + }; + }, + }); + + await call.upgradeCall(false, true); + + expect(call.localUsermediaStream.getAudioTracks()[0].id).toBe("audio_track"); + expect(call.localUsermediaStream.getVideoTracks()[0].id).toBe("video_track"); + expect(call.usermediaSenders.find((sender) => { + return sender?.track?.kind === "audio"; + }).track.id).toBe("audio_track"); + expect(call.usermediaSenders.find((sender) => { + return sender?.track?.kind === "video"; + }).track.id).toBe("video_track"); + }); }); diff --git a/src/client.ts b/src/client.ts index 3ebc0740f98..e77a136985e 100644 --- a/src/client.ts +++ b/src/client.ts @@ -822,7 +822,7 @@ export class MatrixClient extends EventEmitter { protected checkTurnServersIntervalID: number; protected exportedOlmDeviceToImport: IOlmDevice; protected txnCtr = 0; - protected mediaHandler = new MediaHandler(); + protected mediaHandler = new MediaHandler(this); protected pendingEventEncryption = new Map>(); constructor(opts: IMatrixClientCreateOpts) { diff --git a/src/webrtc/call.ts b/src/webrtc/call.ts index da4dc8a4b3e..0ac6414e1fa 100644 --- a/src/webrtc/call.ts +++ b/src/webrtc/call.ts @@ -2,6 +2,7 @@ Copyright 2015, 2016 OpenMarket Ltd Copyright 2017 New Vector Ltd Copyright 2019, 2020 The Matrix.org Foundation C.I.C. +Copyright 2021 - 2022 Šimon Brandner Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -933,29 +934,13 @@ export class MatrixCall extends EventEmitter { if (!this.opponentSupportsSDPStreamMetadata()) return; try { - const upgradeAudio = audio && !this.hasLocalUserMediaAudioTrack; - const upgradeVideo = video && !this.hasLocalUserMediaVideoTrack; - logger.debug(`Upgrading call: audio?=${upgradeAudio} video?=${upgradeVideo}`); - - const stream = await this.client.getMediaHandler().getUserMediaStream(upgradeAudio, upgradeVideo, false); - if (upgradeAudio && upgradeVideo) { - if (this.hasLocalUserMediaAudioTrack) return; - if (this.hasLocalUserMediaVideoTrack) return; - - this.pushNewLocalFeed(stream, SDPStreamMetadataPurpose.Usermedia); - } else if (upgradeAudio) { - if (this.hasLocalUserMediaAudioTrack) return; - - const audioTrack = stream.getAudioTracks()[0]; - this.localUsermediaStream.addTrack(audioTrack); - this.peerConn.addTrack(audioTrack, this.localUsermediaStream); - } else if (upgradeVideo) { - if (this.hasLocalUserMediaVideoTrack) return; - - const videoTrack = stream.getVideoTracks()[0]; - this.localUsermediaStream.addTrack(videoTrack); - this.peerConn.addTrack(videoTrack, this.localUsermediaStream); - } + const getAudio = audio || this.hasLocalUserMediaAudioTrack; + const getVideo = video || this.hasLocalUserMediaVideoTrack; + + // updateLocalUsermediaStream() will take the tracks, use them as + // replacement and throw the stream away, so it isn't reusable + const stream = await this.client.getMediaHandler().getUserMediaStream(getAudio, getVideo, false); + await this.updateLocalUsermediaStream(stream, audio, video); } catch (error) { logger.error("Failed to upgrade the call", error); this.emit(CallEvent.Error, @@ -1071,6 +1056,63 @@ export class MatrixCall extends EventEmitter { } } + /** + * Replaces/adds the tracks from the passed stream to the localUsermediaStream + * @param {MediaStream} stream to use a replacement for the local usermedia stream + */ + public async updateLocalUsermediaStream( + stream: MediaStream, forceAudio = false, forceVideo = false, + ): Promise { + const callFeed = this.localUsermediaFeed; + const audioEnabled = forceAudio || (!callFeed.isAudioMuted() && !this.remoteOnHold); + const videoEnabled = forceVideo || (!callFeed.isVideoMuted() && !this.remoteOnHold); + setTracksEnabled(stream.getAudioTracks(), audioEnabled); + setTracksEnabled(stream.getVideoTracks(), videoEnabled); + + // We want to keep the same stream id, so we replace the tracks rather than the whole stream + for (const track of this.localUsermediaStream.getTracks()) { + this.localUsermediaStream.removeTrack(track); + track.stop(); + } + for (const track of stream.getTracks()) { + this.localUsermediaStream.addTrack(track); + } + + const newSenders = []; + + for (const track of stream.getTracks()) { + const oldSender = this.usermediaSenders.find((sender) => sender.track?.kind === track.kind); + let newSender: RTCRtpSender; + + if (oldSender) { + logger.info( + `Replacing track (` + + `id="${track.id}", ` + + `kind="${track.kind}", ` + + `streamId="${stream.id}", ` + + `streamPurpose="${callFeed.purpose}"` + + `) to peer connection`, + ); + await oldSender.replaceTrack(track); + newSender = oldSender; + } else { + logger.info( + `Adding track (` + + `id="${track.id}", ` + + `kind="${track.kind}", ` + + `streamId="${stream.id}", ` + + `streamPurpose="${callFeed.purpose}"` + + `) to peer connection`, + ); + newSender = this.peerConn.addTrack(track, this.localUsermediaStream); + } + + newSenders.push(newSender); + } + + this.usermediaSenders = newSenders; + } + /** * Set whether our outbound video should be muted or not. * @param {boolean} muted True to mute the outbound video. @@ -1199,8 +1241,8 @@ export class MatrixCall extends EventEmitter { [SDPStreamMetadataKey]: this.getLocalSDPStreamMetadata(), }); - const micShouldBeMuted = this.localUsermediaFeed?.isAudioMuted() || this.remoteOnHold; - const vidShouldBeMuted = this.localUsermediaFeed?.isVideoMuted() || this.remoteOnHold; + const micShouldBeMuted = this.isMicrophoneMuted() || this.remoteOnHold; + const vidShouldBeMuted = this.isLocalVideoMuted() || this.remoteOnHold; setTracksEnabled(this.localUsermediaStream.getAudioTracks(), !micShouldBeMuted); setTracksEnabled(this.localUsermediaStream.getVideoTracks(), !vidShouldBeMuted); diff --git a/src/webrtc/mediaHandler.ts b/src/webrtc/mediaHandler.ts index b1c599d4513..ba84ca899a9 100644 --- a/src/webrtc/mediaHandler.ts +++ b/src/webrtc/mediaHandler.ts @@ -2,7 +2,7 @@ Copyright 2015, 2016 OpenMarket Ltd Copyright 2017 New Vector Ltd Copyright 2019, 2020 The Matrix.org Foundation C.I.C. -Copyright 2021 Šimon Brandner +Copyright 2021 - 2022 Šimon Brandner Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,20 +18,30 @@ limitations under the License. */ import { logger } from "../logger"; +import { MatrixClient } from "../client"; +import { CallState } from "./call"; export class MediaHandler { private audioInput: string; private videoInput: string; - private userMediaStreams: MediaStream[] = []; - private screensharingStreams: MediaStream[] = []; + private localUserMediaStream?: MediaStream; + public userMediaStreams: MediaStream[] = []; + public screensharingStreams: MediaStream[] = []; + + constructor(private client: MatrixClient) { } /** * Set an audio input device to use for MatrixCalls * @param {string} deviceId the identifier for the device * undefined treated as unset */ - public setAudioInput(deviceId: string): void { + public async setAudioInput(deviceId: string): Promise { + logger.info("LOG setting audio input to", deviceId); + + if (this.audioInput === deviceId) return; + this.audioInput = deviceId; + await this.updateLocalUsermediaStreams(); } /** @@ -39,8 +49,39 @@ export class MediaHandler { * @param {string} deviceId the identifier for the device * undefined treated as unset */ - public setVideoInput(deviceId: string): void { + public async setVideoInput(deviceId: string): Promise { + logger.info("LOG setting video input to", deviceId); + + if (this.videoInput === deviceId) return; + this.videoInput = deviceId; + await this.updateLocalUsermediaStreams(); + } + + /** + * Requests new usermedia streams and replace the old ones + */ + public async updateLocalUsermediaStreams(): Promise { + if (this.userMediaStreams.length === 0) return; + + const callMediaStreamParams: Map = new Map(); + for (const call of this.client.callEventHandler.calls.values()) { + callMediaStreamParams.set(call.callId, { + audio: call.hasLocalUserMediaAudioTrack, + video: call.hasLocalUserMediaVideoTrack, + }); + } + + for (const call of this.client.callEventHandler.calls.values()) { + if (call.state === CallState.Ended || !callMediaStreamParams.has(call.callId)) continue; + + const { audio, video } = callMediaStreamParams.get(call.callId); + + // This stream won't be reusable as we will replace the tracks of the old stream + const stream = await this.getUserMediaStream(audio, video, false); + + await call.updateLocalUsermediaStream(stream); + } } public async hasAudioDevice(): Promise { @@ -65,20 +106,44 @@ export class MediaHandler { let stream: MediaStream; - // Find a stream with matching tracks - const matchingStream = this.userMediaStreams.find((stream) => { - if (shouldRequestAudio !== (stream.getAudioTracks().length > 0)) return false; - if (shouldRequestVideo !== (stream.getVideoTracks().length > 0)) return false; - return true; - }); - - if (matchingStream) { - logger.log("Cloning user media stream", matchingStream.id); - stream = matchingStream.clone(); - } else { + if ( + !this.localUserMediaStream || + (this.localUserMediaStream.getAudioTracks().length === 0 && shouldRequestAudio) || + (this.localUserMediaStream.getVideoTracks().length === 0 && shouldRequestVideo) || + (this.localUserMediaStream.getAudioTracks()[0]?.getSettings()?.deviceId !== this.audioInput) || + (this.localUserMediaStream.getVideoTracks()[0]?.getSettings()?.deviceId !== this.videoInput) + ) { const constraints = this.getUserMediaContraints(shouldRequestAudio, shouldRequestVideo); logger.log("Getting user media with constraints", constraints); stream = await navigator.mediaDevices.getUserMedia(constraints); + + for (const track of stream.getTracks()) { + const settings = track.getSettings(); + + if (track.kind === "audio") { + this.audioInput = settings.deviceId; + } else if (track.kind === "video") { + this.videoInput = settings.deviceId; + } + } + + if (reusable) { + this.localUserMediaStream = stream; + } + } else { + stream = this.localUserMediaStream.clone(); + + if (!shouldRequestAudio) { + for (const track of stream.getAudioTracks()) { + stream.removeTrack(track); + } + } + + if (!shouldRequestVideo) { + for (const track of stream.getVideoTracks()) { + stream.removeTrack(track); + } + } } if (reusable) { @@ -103,6 +168,10 @@ export class MediaHandler { logger.debug("Splicing usermedia stream out stream array", mediaStream.id); this.userMediaStreams.splice(index, 1); } + + if (this.localUserMediaStream === mediaStream) { + this.localUserMediaStream = undefined; + } } /** @@ -174,6 +243,7 @@ export class MediaHandler { this.userMediaStreams = []; this.screensharingStreams = []; + this.localUserMediaStream = undefined; } private getUserMediaContraints(audio: boolean, video: boolean): MediaStreamConstraints {