20
20
#define COMPRESS_STATS 0
21
21
22
22
#include < stddef.h>
23
+ #include < stdint.h>
23
24
#include < stdio.h>
24
25
25
- #include < array>
26
- #include < cstdio>
27
26
#include < cstring>
28
27
#include < string>
29
28
#include < unordered_map>
35
34
#include " compression/io.h"
36
35
#include " compression/shared.h"
37
36
// IWYU pragma: end_exports
38
- #include " compression/distortion.h"
39
- #include " hwy/aligned_allocator.h"
40
- #include " hwy/base.h" // BF16
41
- #include " hwy/contrib/thread_pool/thread_pool.h"
37
+ #include " util/allocator.h"
42
38
#if COMPRESS_STATS
39
+ #include " compression/distortion.h"
43
40
#include " hwy/stats.h"
44
41
#endif
45
42
46
43
namespace gcpp {
47
44
48
- // Compressed representation of floating-point elements. The array length may
49
- // differ from the number of elements. Associated operations such as Dot are
50
- // implemented in SIMD code and are thus non-member functions.
51
- template <typename Packed, size_t kCapacity >
52
- class CompressedArray {
53
- public:
54
- using value_type = Packed;
55
-
56
- // Note that whenever you access data(), you have to consider a scale() that
57
- // may be different from 1.0f.
58
- Packed* data () { return data_.data (); }
59
- const Packed* data () const { return data_.data (); }
60
- // The const accessor data_scale1() asserts (!) that the scale is 1.0f, so
61
- // calling it means "I am sure the scale is 1 and therefore ignore the scale".
62
- // A scale of 0 indicates that the scale has likely never been set, so is
63
- // "implicitly 1".
64
- const Packed* data_scale1 () const {
65
- HWY_ASSERT (scale () == 1 .f || scale () == 0 .f );
66
- return data_.data ();
67
- }
68
-
69
- // Decoded elements should be multiplied by this to restore their original
70
- // range. This is required because SfpStream can only encode a limited range
71
- // of magnitudes.
72
- float scale () const { return scale_[0 ]; }
73
- void set_scale (float scale) { scale_[0 ] = scale; }
74
-
75
- constexpr size_t NumElements () const { return kCapacity ; }
76
-
77
- // Returns total number of packed elements for `BlobReader::Enqueue` and
78
- // `Compress`. This differs from `NumElements` for `Packed=NuqStream`.
79
- PackedSpan<Packed> GetSpan () { return MakeSpan (data (), data_.size ()); }
80
- PackedSpan<const Packed> GetSpan () const {
81
- return MakeSpan (data (), data_.size ());
82
- }
83
-
84
- private:
85
- std::array<Packed, CompressedArrayElements<Packed>(kCapacity )> data_;
86
- // Blobs are at least kBlobAlign bytes anyway.
87
- float scale_[kBlobAlign / sizeof (float )];
88
- };
89
-
90
- // Yet another array class. This one is intended to be compatible with
91
- // CompressedArray, but have both run-time sizing and compile-time constant
92
- // size.
93
- // It also provides easy conversion from/to a table of contents for a BlobStore
94
- // file, and a templated (compile-time) accessor for a 2-d array of fixed inner
95
- // dimension and type.
96
- // The base class is intended for accessing the metadata, without needing to
97
- // know any of the template arguments.
98
- // It holds only a borrowed pointer to the data, but all metadata.
45
+ // Base class for rank-1 or 2 tensors (vector or matrix).
46
+ // Supports both dynamic and compile-time sizing.
47
+ // Holds metadata and a non-owning pointer to the data, owned by the derived
48
+ // MatStorageT class.
49
+ // This class also provides easy conversion from/to a table of contents for a
50
+ // BlobStore file, and a templated (compile-time) accessor for a 2-d array of
51
+ // fixed inner dimension and type.
99
52
// It is designed to be put in a vector, and has default copy and operator=, so
100
53
// it is easy to read/write a blob_store file.
101
- // The derived class or an external class owns the data.
102
54
class MatPtr {
103
55
public:
104
56
// Full constructor for dynamic sizing.
@@ -111,12 +63,12 @@ class MatPtr {
111
63
rows_(rows),
112
64
cols_(cols),
113
65
ptr_(nullptr ) {}
114
- // Default constructor doesn't set anything .
66
+ // Default is to leave all fields default-initialized .
115
67
MatPtr () = default;
116
68
virtual ~MatPtr ();
117
69
118
70
// Number of hwy::uint128_t in a TOC entry.
119
- // Note that the old-style BlobStore files Only have a list of keys and size.
71
+ // Note that the old-style BlobStore files only have a list of keys and size.
120
72
// The new-style BlobStore files have an entry called "toc" that contains a
121
73
// vector of 4-tuples of
122
74
// (name, type, (num_elements, element_size), (rows, cols)).
@@ -144,6 +96,7 @@ class MatPtr {
144
96
}
145
97
146
98
// Compatibility interface for CompressedArray.
99
+ // TODO: remove.
147
100
template <typename T>
148
101
T* data () {
149
102
return HWY_RCAST_ALIGNED (T*, ptr_);
@@ -177,7 +130,6 @@ class MatPtr {
177
130
178
131
// Returns the number of bytes in the array.
179
132
size_t SizeBytes () const { return num_elements_ * element_size_; }
180
- size_t CompressedSize () const { return SizeBytes (); }
181
133
182
134
// Returns the number of rows in the 2-d array (outer dimension).
183
135
size_t Rows () const { return rows_; }
@@ -211,8 +163,8 @@ class MatPtr {
211
163
}
212
164
213
165
// Calls func on the upcasted type. Since MatPtr by design is not templated,
214
- // here we provide a way to get to the derived type, provided that the type
215
- // matches one of a known short-list .
166
+ // here we provide a way to get to the derived type, provided that `Type()`
167
+ // is one of the strings returned by `TypeName()` .
216
168
template <class FuncT , typename ... TArgs>
217
169
decltype (auto ) CallUpcasted(FuncT& func, TArgs&&... args);
218
170
@@ -243,8 +195,6 @@ class MatPtr {
243
195
template <typename MatT>
244
196
class MatPtrT : public MatPtr {
245
197
public:
246
- using value_type = MatT;
247
-
248
198
// Full constructor for dynamic sizing.
249
199
MatPtrT (const std::string& name, size_t rows, size_t cols)
250
200
: MatPtr(name, TypeEnum<MatT>(), sizeof (MatT), rows, cols) {}
@@ -276,20 +226,13 @@ class MatPtrT : public MatPtr {
276
226
}
277
227
return name;
278
228
}
229
+
279
230
// Sets the number of elements in the array. For use when the number of
280
231
// elements is != rows * cols ONLY.
281
232
void SetNumElements (size_t num_elements) {
282
233
num_elements_ = CompressedArrayElements<MatT>(num_elements);
283
234
}
284
235
285
- // Fast 2-d accessor for a 2-d array of fixed inner dimension and type.
286
- template <typename T = MatT, size_t kInner >
287
- const T& AtT (size_t row, size_t col) const {
288
- size_t index = row * kInner + col;
289
- HWY_DASSERT (index < num_elements_);
290
- return HWY_RCAST_ALIGNED (const T*, ptr_)[index ];
291
- }
292
-
293
236
// 2-d Accessor for a specific type but with a dynamic inner dimension.
294
237
template <typename T = MatT>
295
238
const T& At (size_t row, size_t col) const {
@@ -299,17 +242,15 @@ class MatPtrT : public MatPtr {
299
242
}
300
243
301
244
// 1-d Accessor for a specific type.
302
- template < typename T = MatT>
303
- const T & At (size_t index) const {
245
+ // TODO: replace this with a Foreach(), or at least a ForEachRow().
246
+ const MatT & At (size_t index) const {
304
247
HWY_DASSERT (index < num_elements_);
305
- return HWY_RCAST_ALIGNED (const T*, ptr_)[index ];
306
- }
307
- template <typename T = MatT>
308
- T& At (size_t index) {
309
- return HWY_RCAST_ALIGNED (T*, ptr_)[index ];
248
+ return HWY_RCAST_ALIGNED (const MatT*, ptr_)[index ];
310
249
}
250
+ MatT& At (size_t index) { return HWY_RCAST_ALIGNED (MatT*, ptr_)[index ]; }
311
251
312
252
// Compatibility interface for CompressedArray.
253
+ // TODO: remove
313
254
template <typename T = MatT>
314
255
T* data () {
315
256
return HWY_RCAST_ALIGNED (T*, ptr_);
@@ -353,15 +294,14 @@ class MatStorageT : public MatPtrT<MatT> {
353
294
public:
354
295
// Full constructor for dynamic sizing.
355
296
MatStorageT (const std::string& name, size_t rows, size_t cols)
356
- : MatPtrT<MatT>(name, rows, cols),
357
- data_ (hwy::AllocateAligned<MatT>(
358
- hwy::DivCeil (this ->SizeBytes (), sizeof(MatT)))) {
359
- this ->ptr_ = data_.get ();
297
+ : MatPtrT<MatT>(name, rows, cols) {
298
+ Allocate ();
360
299
}
361
300
// Can copy the metadata, from a MatPtr, and allocate later.
362
301
MatStorageT (const MatPtr& other) : MatPtrT<MatT>(other) {}
302
+ ~MatStorageT () = default ;
363
303
364
- // No copying of MatStorageT as it contains big data .
304
+ // Move-only because this contains a unique_ptr .
365
305
MatStorageT (const MatStorageT& other) = delete ;
366
306
MatStorageT& operator =(const MatStorageT& other) = delete ;
367
307
MatStorageT (MatStorageT&& other) = default ;
@@ -377,7 +317,7 @@ class MatStorageT : public MatPtrT<MatT> {
377
317
} else {
378
318
this ->num_elements_ = num_elements;
379
319
}
380
- data_ = hwy::AllocateAligned <MatT>(num_elements);
320
+ data_ = Allocator::Alloc <MatT>(num_elements);
381
321
this ->ptr_ = data_.get ();
382
322
}
383
323
@@ -388,8 +328,6 @@ class MatStorageT : public MatPtrT<MatT> {
388
328
}
389
329
390
330
private:
391
- // Aligned data array.
392
- // std::unique_ptr<MatT[]> data_;
393
331
hwy::AlignedFreeUniquePtr<MatT[]> data_;
394
332
};
395
333
@@ -507,7 +445,7 @@ class CompressStats {
507
445
};
508
446
#else
509
447
struct CompressStats {
510
- void Notify (const DistortionStats& ) {}
448
+ void Notify (... ) {}
511
449
void NotifyIn (int ) {}
512
450
void Assimilate (const CompressStats&) {}
513
451
void PrintAll () {}
@@ -526,18 +464,17 @@ struct CompressWorkingSet {
526
464
527
465
// Functor called for each tensor, which loads them and their scaling factors
528
466
// from BlobStore.
529
- class CacheLoader {
467
+ class ReadFromBlobStore {
530
468
public:
531
- explicit CacheLoader (const Path& blob_filename) {
469
+ explicit ReadFromBlobStore (const Path& blob_filename) {
532
470
err_ = reader_.Open (blob_filename);
533
- if (err_ != 0 ) {
534
- fprintf (stderr,
535
- " Cached compressed weights does not exist yet (code %d), "
536
- " loading from file: %s.\n " ,
537
- err_, blob_filename.path .c_str ());
471
+ if (HWY_UNLIKELY (err_ != 0 )) {
472
+ fprintf (stderr, " Error %d opening BlobStore %s.\n " , err_,
473
+ blob_filename.path .c_str ());
474
+ return ; // avoid overwriting err_ to ensure ReadAll will fail.
538
475
}
539
476
err_ = file_toc_.LoadToc (reader_);
540
- if (err_ != 0 ) {
477
+ if (HWY_UNLIKELY ( err_ != 0 ) ) {
541
478
fprintf (stderr, " Found a TOC, but failed to load it (code %d)\n " , err_);
542
479
}
543
480
}
0 commit comments