@@ -646,18 +646,17 @@ def get_type(val):
646
646
sys .exit ()
647
647
648
648
649
+ class WriterState (Enum ):
650
+ EMPTY = auto ()
651
+ HEADER = auto ()
652
+ KV_DATA = auto ()
653
+ TI_DATA = auto ()
654
+
655
+
649
656
class GGUFWriter :
650
657
fout : BufferedWriter
651
- arch : str
652
- offset_tensor = 0
653
- data_alignment = GGUF_DEFAULT_ALIGNMENT
654
- kv_data = b""
655
- kv_data_count = 0
656
- ti_data = b""
657
- ti_data_count = 0
658
- use_temp_file : bool
659
- temp_file : tempfile .SpooledTemporaryFile [bytes ] | None = None
660
- tensors : list [tuple [np .ndarray [Any , Any ], int ]]
658
+ temp_file : tempfile .SpooledTemporaryFile [bytes ] | None
659
+ tensors : list [np .ndarray [Any , Any ]]
661
660
662
661
@property
663
662
def pack_prefix (self ):
@@ -683,27 +682,47 @@ def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file = True
683
682
GGUFValueType .FLOAT64 : f"{ self .pack_prefix } d" ,
684
683
GGUFValueType .BOOL : "?" ,
685
684
}
686
- self .add_architecture ()
685
+ self .offset_tensor = 0
686
+ self .data_alignment = GGUF_DEFAULT_ALIGNMENT
687
+ self .kv_data = b""
688
+ self .kv_data_count = 0
689
+ self .ti_data = b""
690
+ self .ti_data_count = 0
687
691
self .use_temp_file = use_temp_file
692
+ self .temp_file = None
688
693
self .tensors = []
689
694
endianess_str = "Big Endian" if self .endianess == GGUFEndian .BIG else "Little Endian"
690
695
print (f"This gguf file is for { endianess_str } only" )
696
+ self .state = WriterState .EMPTY
697
+
698
+ self .add_architecture ()
691
699
692
700
def write_header_to_file (self ):
701
+ if self .state is not WriterState .EMPTY :
702
+ raise ValueError (f'Expected output file to be empty, got { self .state } ' )
703
+
693
704
self .fout .write (struct .pack ("<I" , GGUF_MAGIC ))
694
705
self .fout .write (struct .pack (f"{ self .pack_prefix } I" , GGUF_VERSION ))
695
706
self .fout .write (struct .pack (f"{ self .pack_prefix } Q" , self .ti_data_count ))
696
707
self .fout .write (struct .pack (f"{ self .pack_prefix } Q" , self .kv_data_count ))
697
708
self .flush ()
698
- # print("tensors " + str( self.ti_data_count) + " kv " + str(self.kv_data_count))
709
+ self .state = WriterState . HEADER
699
710
700
711
def write_kv_data_to_file (self ):
712
+ if self .state is not WriterState .HEADER :
713
+ raise ValueError (f'Expected output file to contain the header, got { self .state } ' )
714
+
701
715
self .fout .write (self .kv_data )
702
716
self .flush ()
717
+ self .state = WriterState .KV_DATA
703
718
704
719
def write_ti_data_to_file (self ):
720
+ if self .state is not WriterState .KV_DATA :
721
+ raise ValueError (f'Expected output file to contain KV data, got { self .state } ' )
722
+
705
723
self .fout .write (self .ti_data )
706
724
self .flush ()
725
+ self .state = WriterState .TI_DATA
707
726
708
727
def add_key (self , key : str ):
709
728
self .add_val (key , GGUFValueType .STRING , add_vtype = False )
@@ -796,6 +815,9 @@ def ggml_pad(x: int, n: int) -> int:
796
815
return ((x + n - 1 ) // n ) * n
797
816
798
817
def add_tensor_info (self , name : str , tensor_shape : Sequence [int ], tensor_dtype : np .dtype [np .float16 ] | np .dtype [np .float32 ], tensor_nbytes : int , raw_dtype : GGMLQuantizationType | None = None ):
818
+ if self .state is not WriterState .EMPTY :
819
+ raise ValueError (f'Expected output file to be empty, got { self .state } ' )
820
+
799
821
assert raw_dtype is not None or tensor_dtype in (np .float32 , np .float16 ), "Only F32 and F16 tensors are supported for now"
800
822
801
823
encoded_name = name .encode ("utf8" )
@@ -825,23 +847,22 @@ def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequenc
825
847
shape : Sequence [int ] = raw_shape if raw_shape is not None else tensor .shape
826
848
self .add_tensor_info (name , shape , tensor .dtype , tensor .nbytes , raw_dtype = raw_dtype )
827
849
828
- pad = GGUFWriter .ggml_pad (tensor .nbytes , self .data_alignment ) - tensor .nbytes
829
-
830
- if self .temp_file is None :
831
- self .tensors .append ((tensor , pad ))
850
+ if self .temp_file is None :
851
+ self .tensors .append (tensor )
832
852
return
833
853
834
854
tensor .tofile (self .temp_file )
855
+ self .write_padding (self .temp_file , tensor .nbytes )
835
856
836
- if pad != 0 :
837
- self .temp_file .write (bytes ([0 ] * pad ))
838
-
839
- def write_padding (self , fp : BinaryIO , n : int , align : int | None = None ):
857
+ def write_padding (self , fp : IO [bytes ], n : int , align : int | None = None ):
840
858
pad = GGUFWriter .ggml_pad (n , align if align is not None else self .data_alignment ) - n
841
859
if pad != 0 :
842
860
fp .write (bytes ([0 ] * pad ))
843
861
844
862
def write_tensor_data (self , tensor : np .ndarray [Any , Any ]):
863
+ if self .state is not WriterState .TI_DATA :
864
+ raise ValueError (f'Expected output file to contain tensor info, got { self .state } ' )
865
+
845
866
if self .endianess == GGUFEndian .BIG :
846
867
tensor .byteswap (inplace = True )
847
868
self .write_padding (self .fout , self .fout .tell ())
@@ -854,10 +875,13 @@ def write_tensors_to_file(self):
854
875
self .write_padding (self .fout , self .fout .tell ())
855
876
856
877
if self .temp_file is None :
857
- for (currtensor , currpad ) in self .tensors :
858
- currtensor .tofile (self .fout )
859
- if currpad != 0 :
860
- self .fout .write (bytes ([0 ] * currpad ))
878
+ while True :
879
+ try :
880
+ tensor = self .tensors .pop (0 )
881
+ except IndexError :
882
+ break
883
+ tensor .tofile (self .fout )
884
+ self .write_padding (self .fout , tensor .nbytes )
861
885
return
862
886
863
887
self .temp_file .seek (0 )
@@ -1002,11 +1026,8 @@ def add_pad_token_id(self, id: int):
1002
1026
1003
1027
1004
1028
class SpecialVocab :
1005
- load_merges : bool = False
1006
- merges : list [str ] = []
1007
- special_token_types : tuple [str , ...] = ('bos' , 'eos' , 'unk' , 'sep' , 'pad' )
1008
- special_token_ids : dict [str , int ] = {}
1009
- n_vocab : int | None = None
1029
+ merges : list [str ]
1030
+ special_token_ids : dict [str , int ]
1010
1031
1011
1032
def __init__ (
1012
1033
self , path : str | os .PathLike [str ], load_merges : bool = False ,
@@ -1016,8 +1037,11 @@ def __init__(
1016
1037
self .special_token_ids = {}
1017
1038
self .n_vocab = n_vocab
1018
1039
self .load_merges = load_merges
1040
+ self .merges = []
1019
1041
if special_token_types is not None :
1020
1042
self .special_token_types = special_token_types
1043
+ else :
1044
+ self .special_token_types = ('bos' , 'eos' , 'unk' , 'sep' , 'pad' )
1021
1045
self ._load (Path (path ))
1022
1046
1023
1047
def _load (self , path : Path ) -> None :
0 commit comments