Skip to content

Commit 4e565c7

Browse files
committed
download: add support for downloading to file
1 parent d910a51 commit 4e565c7

File tree

3 files changed

+275
-16
lines changed

3 files changed

+275
-16
lines changed

download-to-file.go

+194
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
// Copyright (c) 2024 Tulir Asokan
2+
//
3+
// This Source Code Form is subject to the terms of the Mozilla Public
4+
// License, v. 2.0. If a copy of the MPL was not distributed with this
5+
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6+
7+
package whatsmeow
8+
9+
import (
10+
"crypto/hmac"
11+
"crypto/sha256"
12+
"encoding/base64"
13+
"errors"
14+
"fmt"
15+
"io"
16+
"os"
17+
"strings"
18+
"time"
19+
20+
"go.mau.fi/util/fallocate"
21+
"go.mau.fi/util/retryafter"
22+
23+
"go.mau.fi/whatsmeow/proto/waMediaTransport"
24+
"go.mau.fi/whatsmeow/util/cbcutil"
25+
)
26+
27+
type File interface {
28+
io.Reader
29+
io.Writer
30+
io.Seeker
31+
io.ReaderAt
32+
io.WriterAt
33+
Truncate(size int64) error
34+
Stat() (os.FileInfo, error)
35+
}
36+
37+
// DownloadToFile downloads the attachment from the given protobuf message.
38+
//
39+
// This is otherwise identical to [Download], but writes the attachment to a file instead of returning it as a byte slice.
40+
func (cli *Client) DownloadToFile(msg DownloadableMessage, file File) error {
41+
mediaType, ok := classToMediaType[msg.ProtoReflect().Descriptor().Name()]
42+
if !ok {
43+
return fmt.Errorf("%w '%s'", ErrUnknownMediaType, string(msg.ProtoReflect().Descriptor().Name()))
44+
}
45+
urlable, ok := msg.(downloadableMessageWithURL)
46+
var url string
47+
var isWebWhatsappNetURL bool
48+
if ok {
49+
url = urlable.GetUrl()
50+
isWebWhatsappNetURL = strings.HasPrefix(url, "https://web.whatsapp.net")
51+
}
52+
if len(url) > 0 && !isWebWhatsappNetURL {
53+
return cli.downloadAndDecryptToFile(url, msg.GetMediaKey(), mediaType, getSize(msg), msg.GetFileEncSHA256(), msg.GetFileSHA256(), file)
54+
} else if len(msg.GetDirectPath()) > 0 {
55+
return cli.DownloadMediaWithPathToFile(msg.GetDirectPath(), msg.GetFileEncSHA256(), msg.GetFileSHA256(), msg.GetMediaKey(), getSize(msg), mediaType, mediaTypeToMMSType[mediaType], file)
56+
} else {
57+
if isWebWhatsappNetURL {
58+
cli.Log.Warnf("Got a media message with a web.whatsapp.net URL (%s) and no direct path", url)
59+
}
60+
return ErrNoURLPresent
61+
}
62+
}
63+
64+
func (cli *Client) DownloadFBToFile(transport *waMediaTransport.WAMediaTransport_Integral, mediaType MediaType, file File) error {
65+
return cli.DownloadMediaWithPathToFile(transport.GetDirectPath(), transport.GetFileEncSHA256(), transport.GetFileSHA256(), transport.GetMediaKey(), -1, mediaType, mediaTypeToMMSType[mediaType], file)
66+
}
67+
68+
func (cli *Client) DownloadMediaWithPathToFile(directPath string, encFileHash, fileHash, mediaKey []byte, fileLength int, mediaType MediaType, mmsType string, file File) error {
69+
mediaConn, err := cli.refreshMediaConn(false)
70+
if err != nil {
71+
return fmt.Errorf("failed to refresh media connections: %w", err)
72+
}
73+
if len(mmsType) == 0 {
74+
mmsType = mediaTypeToMMSType[mediaType]
75+
}
76+
for i, host := range mediaConn.Hosts {
77+
// TODO omit hash for unencrypted media?
78+
mediaURL := fmt.Sprintf("https://%s%s&hash=%s&mms-type=%s&__wa-mms=", host.Hostname, directPath, base64.URLEncoding.EncodeToString(encFileHash), mmsType)
79+
err = cli.downloadAndDecryptToFile(mediaURL, mediaKey, mediaType, fileLength, encFileHash, fileHash, file)
80+
if err == nil || errors.Is(err, ErrFileLengthMismatch) || errors.Is(err, ErrInvalidMediaSHA256) {
81+
return err
82+
} else if i >= len(mediaConn.Hosts)-1 {
83+
return fmt.Errorf("failed to download media from last host: %w", err)
84+
}
85+
// TODO there are probably some errors that shouldn't retry
86+
cli.Log.Warnf("Failed to download media: %s, trying with next host...", err)
87+
}
88+
return err
89+
}
90+
91+
func (cli *Client) downloadAndDecryptToFile(url string, mediaKey []byte, appInfo MediaType, fileLength int, fileEncSHA256, fileSHA256 []byte, file File) error {
92+
iv, cipherKey, macKey, _ := getMediaKeys(mediaKey, appInfo)
93+
hasher := sha256.New()
94+
if mac, err := cli.downloadPossiblyEncryptedMediaWithRetriesToFile(url, fileEncSHA256, file); err != nil {
95+
return err
96+
} else if mediaKey == nil && fileEncSHA256 == nil && mac == nil {
97+
// Unencrypted media, just return the downloaded data
98+
return nil
99+
} else if err = validateMediaFile(file, iv, macKey, mac); err != nil {
100+
return err
101+
} else if _, err = file.Seek(0, io.SeekStart); err != nil {
102+
return fmt.Errorf("failed to seek to start of file after validating mac: %w", err)
103+
} else if err = cbcutil.DecryptFile(cipherKey, iv, file); err != nil {
104+
return fmt.Errorf("failed to decrypt file: %w", err)
105+
} else if info, err := file.Stat(); err != nil {
106+
return fmt.Errorf("failed to stat file: %w", err)
107+
} else if info.Size() != int64(fileLength) {
108+
return fmt.Errorf("%w: expected %d, got %d", ErrFileLengthMismatch, fileLength, info.Size())
109+
} else if _, err = file.Seek(0, io.SeekStart); err != nil {
110+
return fmt.Errorf("failed to seek to start of file after decrypting: %w", err)
111+
} else if _, err = io.Copy(hasher, file); err != nil {
112+
return fmt.Errorf("failed to hash file: %w", err)
113+
} else if !hmac.Equal(fileSHA256, hasher.Sum(nil)) {
114+
return ErrInvalidMediaSHA256
115+
}
116+
return nil
117+
}
118+
119+
func (cli *Client) downloadPossiblyEncryptedMediaWithRetriesToFile(url string, checksum []byte, file File) (mac []byte, err error) {
120+
for retryNum := 0; retryNum < 5; retryNum++ {
121+
if checksum == nil {
122+
_, _, err = cli.downloadMediaToFile(url, file)
123+
} else {
124+
mac, err = cli.downloadEncryptedMediaToFile(url, checksum, file)
125+
}
126+
if err == nil || !shouldRetryMediaDownload(err) {
127+
return
128+
}
129+
retryDuration := time.Duration(retryNum+1) * time.Second
130+
var httpErr DownloadHTTPError
131+
if errors.As(err, &httpErr) {
132+
retryDuration = retryafter.Parse(httpErr.Response.Header.Get("Retry-After"), retryDuration)
133+
}
134+
cli.Log.Warnf("Failed to download media due to network error: %v, retrying in %s...", err, retryDuration)
135+
time.Sleep(retryDuration)
136+
}
137+
return
138+
}
139+
140+
func (cli *Client) downloadMediaToFile(url string, file io.Writer) (int64, []byte, error) {
141+
resp, err := cli.doMediaDownloadRequest(url)
142+
if err != nil {
143+
return 0, nil, err
144+
}
145+
defer resp.Body.Close()
146+
osFile, ok := file.(*os.File)
147+
if ok && resp.ContentLength > 0 {
148+
err = fallocate.Fallocate(osFile, int(resp.ContentLength))
149+
if err != nil {
150+
return 0, nil, fmt.Errorf("failed to preallocate file: %w", err)
151+
}
152+
}
153+
hasher := sha256.New()
154+
n, err := io.Copy(file, io.TeeReader(resp.Body, hasher))
155+
return n, hasher.Sum(nil), err
156+
}
157+
158+
func (cli *Client) downloadEncryptedMediaToFile(url string, checksum []byte, file File) ([]byte, error) {
159+
size, hash, err := cli.downloadMediaToFile(url, file)
160+
if err != nil {
161+
return nil, err
162+
} else if size <= mediaHMACLength {
163+
return nil, ErrTooShortFile
164+
} else if len(checksum) == 32 && !hmac.Equal(checksum, hash) {
165+
return nil, ErrInvalidMediaEncSHA256
166+
}
167+
mac := make([]byte, mediaHMACLength)
168+
_, err = file.ReadAt(mac, size-mediaHMACLength)
169+
if err != nil {
170+
return nil, fmt.Errorf("failed to read MAC from file: %w", err)
171+
}
172+
err = file.Truncate(size - mediaHMACLength)
173+
if err != nil {
174+
return nil, fmt.Errorf("failed to truncate file to remove MAC: %w", err)
175+
}
176+
return mac, nil
177+
}
178+
179+
func validateMediaFile(file io.ReadSeeker, iv, macKey, mac []byte) error {
180+
h := hmac.New(sha256.New, macKey)
181+
h.Write(iv)
182+
_, err := file.Seek(0, io.SeekStart)
183+
if err != nil {
184+
return fmt.Errorf("failed to seek to start of file: %w", err)
185+
}
186+
_, err = io.Copy(h, file)
187+
if err != nil {
188+
return fmt.Errorf("failed to hash file: %w", err)
189+
}
190+
if !hmac.Equal(h.Sum(nil)[:mediaHMACLength], mac) {
191+
return ErrInvalidMediaHMAC
192+
}
193+
return nil
194+
}

download.go

+19-7
Original file line numberDiff line numberDiff line change
@@ -288,13 +288,13 @@ func (cli *Client) downloadPossiblyEncryptedMediaWithRetries(url string, checksu
288288
if errors.As(err, &httpErr) {
289289
retryDuration = retryafter.Parse(httpErr.Response.Header.Get("Retry-After"), retryDuration)
290290
}
291-
cli.Log.Warnf("Failed to download media due to network error: %w, retrying in %s...", err, retryDuration)
291+
cli.Log.Warnf("Failed to download media due to network error: %v, retrying in %s...", err, retryDuration)
292292
time.Sleep(retryDuration)
293293
}
294294
return
295295
}
296296

297-
func (cli *Client) downloadMedia(url string) ([]byte, error) {
297+
func (cli *Client) doMediaDownloadRequest(url string) (*http.Response, error) {
298298
req, err := http.NewRequest(http.MethodGet, url, nil)
299299
if err != nil {
300300
return nil, fmt.Errorf("failed to prepare request: %w", err)
@@ -309,22 +309,34 @@ func (cli *Client) downloadMedia(url string) ([]byte, error) {
309309
if err != nil {
310310
return nil, err
311311
}
312-
defer resp.Body.Close()
313312
if resp.StatusCode != http.StatusOK {
313+
_ = resp.Body.Close()
314314
return nil, DownloadHTTPError{Response: resp}
315315
}
316-
return io.ReadAll(resp.Body)
316+
return resp, nil
317+
}
318+
319+
func (cli *Client) downloadMedia(url string) ([]byte, error) {
320+
resp, err := cli.doMediaDownloadRequest(url)
321+
if err != nil {
322+
return nil, err
323+
}
324+
data, err := io.ReadAll(resp.Body)
325+
_ = resp.Body.Close()
326+
return data, err
317327
}
318328

329+
const mediaHMACLength = 10
330+
319331
func (cli *Client) downloadEncryptedMedia(url string, checksum []byte) (file, mac []byte, err error) {
320332
data, err := cli.downloadMedia(url)
321333
if err != nil {
322334
return
323-
} else if len(data) <= 10 {
335+
} else if len(data) <= mediaHMACLength {
324336
err = ErrTooShortFile
325337
return
326338
}
327-
file, mac = data[:len(data)-10], data[len(data)-10:]
339+
file, mac = data[:len(data)-mediaHMACLength], data[len(data)-mediaHMACLength:]
328340
if len(checksum) == 32 && sha256.Sum256(data) != *(*[32]byte)(checksum) {
329341
err = ErrInvalidMediaEncSHA256
330342
}
@@ -335,7 +347,7 @@ func validateMedia(iv, file, macKey, mac []byte) error {
335347
h := hmac.New(sha256.New, macKey)
336348
h.Write(iv)
337349
h.Write(file)
338-
if !hmac.Equal(h.Sum(nil)[:10], mac) {
350+
if !hmac.Equal(h.Sum(nil)[:mediaHMACLength], mac) {
339351
return ErrInvalidMediaHMAC
340352
}
341353
return nil

util/cbcutil/cbc.go

+62-9
Original file line numberDiff line numberDiff line change
@@ -24,33 +24,86 @@ import (
2424
"errors"
2525
"fmt"
2626
"io"
27+
"os"
2728
)
2829

2930
/*
3031
Decrypt is a function that decrypts a given cipher text with a provided key and initialization vector(iv).
3132
*/
3233
func Decrypt(key, iv, ciphertext []byte) ([]byte, error) {
3334
block, err := aes.NewCipher(key)
34-
3535
if err != nil {
3636
return nil, err
37-
}
38-
39-
if len(ciphertext) < aes.BlockSize {
37+
} else if len(ciphertext) < aes.BlockSize {
4038
return nil, fmt.Errorf("ciphertext is shorter then block size: %d / %d", len(ciphertext), aes.BlockSize)
4139
}
4240

43-
if iv == nil {
44-
iv = ciphertext[:aes.BlockSize]
45-
ciphertext = ciphertext[aes.BlockSize:]
46-
}
47-
4841
cbc := cipher.NewCBCDecrypter(block, iv)
4942
cbc.CryptBlocks(ciphertext, ciphertext)
5043

5144
return unpad(ciphertext)
5245
}
5346

47+
type File interface {
48+
io.Reader
49+
io.WriterAt
50+
Truncate(size int64) error
51+
Stat() (os.FileInfo, error)
52+
}
53+
54+
func DecryptFile(key, iv []byte, file File) error {
55+
block, err := aes.NewCipher(key)
56+
if err != nil {
57+
return err
58+
}
59+
cbc := cipher.NewCBCDecrypter(block, iv)
60+
stat, err := file.Stat()
61+
if err != nil {
62+
return fmt.Errorf("failed to stat file: %w", err)
63+
}
64+
fileSize := stat.Size()
65+
if fileSize%aes.BlockSize != 0 {
66+
return fmt.Errorf("file size is not a multiple of the block size: %d / %d", fileSize, aes.BlockSize)
67+
}
68+
69+
var bufSize int64 = 32 * 1024
70+
if fileSize < bufSize {
71+
bufSize = fileSize
72+
}
73+
buf := make([]byte, bufSize)
74+
var writePtr int64
75+
var lastByte byte
76+
for writePtr < fileSize {
77+
if writePtr+bufSize > fileSize {
78+
buf = buf[:fileSize-writePtr]
79+
}
80+
var n int
81+
n, err = io.ReadFull(file, buf)
82+
if err != nil {
83+
return fmt.Errorf("failed to read file: %w", err)
84+
} else if n != len(buf) {
85+
return fmt.Errorf("failed to read full buffer: %d / %d", n, len(buf))
86+
}
87+
cbc.CryptBlocks(buf, buf)
88+
n, err = file.WriteAt(buf, writePtr)
89+
if err != nil {
90+
return fmt.Errorf("failed to write file: %w", err)
91+
} else if n != len(buf) {
92+
return fmt.Errorf("failed to write full buffer: %d / %d", n, len(buf))
93+
}
94+
writePtr += int64(len(buf))
95+
lastByte = buf[len(buf)-1]
96+
}
97+
if int64(lastByte) > fileSize {
98+
return fmt.Errorf("padding is greater then the length: %d / %d", lastByte, fileSize)
99+
}
100+
err = file.Truncate(fileSize - int64(lastByte))
101+
if err != nil {
102+
return fmt.Errorf("failed to truncate file to remove padding: %w", err)
103+
}
104+
return nil
105+
}
106+
54107
/*
55108
Encrypt is a function that encrypts plaintext with a given key and an optional initialization vector(iv).
56109
*/

0 commit comments

Comments
 (0)