
Commit 6570a20

token count includes ids
1 parent 0ca814e commit 6570a20

5 files changed: +26 -9 lines changed

Diff for: expose.cpp

+8 -3

@@ -194,7 +194,7 @@ extern "C"
         return gpttype_generate(inputs, output);
     }

-    const char* new_token(int idx) {
+    const char * new_token(int idx) {
         if (generated_tokens.size() <= idx || idx < 0) return nullptr;

         return generated_tokens[idx].c_str();
@@ -232,9 +232,14 @@ extern "C"
         return gpttype_generate_abort();
     }

-    int token_count(const char * input)
+    static std::vector<int> toks; //just share a static object for token counting
+    token_count_outputs token_count(const char * input)
     {
         std::string inputstr = input;
-        return gpttype_token_count(inputstr);
+        token_count_outputs output;
+        toks = gpttype_get_token_arr(inputstr);
+        output.count = toks.size();
+        output.ids = toks.data(); //this may be slightly unsafe
+        return output;
     }
 }
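A note on the static `toks` vector: `output.ids` is a raw pointer into that shared object, so it stays valid after `token_count` returns only because the vector is static, and its contents are overwritten by the next call (it is also not thread-safe). The following is a minimal self-contained sketch, not koboldcpp code; the struct, function names, and the one-id-per-byte tokenizer are stand-ins chosen purely to illustrate the lifetime issue flagged by the "slightly unsafe" comment above.

// Minimal sketch (not koboldcpp code): why the ids pointer must be backed by
// storage that outlives the call. A local vector would be destroyed on return,
// leaving the caller with a dangling pointer; a static vector survives, but its
// contents are replaced by the next call.
#include <cstdio>
#include <string>
#include <vector>

struct demo_outputs {              // stand-in for token_count_outputs
    int count = 0;
    int * ids = nullptr;           // borrowed pointer, never owned by the caller
};

static std::vector<int> shared_toks;   // same idea as the static `toks` above

demo_outputs demo_token_count(const std::string & input) {
    shared_toks.clear();
    for (char c : input) {                        // fake tokenizer: one id per byte
        shared_toks.push_back((int)(unsigned char)c);
    }
    demo_outputs out;
    out.count = (int)shared_toks.size();
    out.ids = shared_toks.data();                 // valid only until the next call
    return out;
}

int main() {
    demo_outputs a = demo_token_count("hello");
    std::vector<int> copied(a.ids, a.ids + a.count);   // copy before calling again

    demo_outputs b = demo_token_count("hi");
    // a.ids now refers to the ids of "hi" (or dangles if the vector reallocated);
    // this is why the caller should copy the ids out immediately.
    printf("first=%d copied=%zu second=%d\n", a.count, copied.size(), b.count);
    return 0;
}

Callers are therefore expected to copy the ids out before making any further calls into the library, which is exactly what the koboldcpp.py change below does.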

Diff for: expose.h

+5

@@ -83,6 +83,11 @@ struct generation_outputs
     int status = -1;
     char text[32768]; //32kb should be enough for any response
 };
+struct token_count_outputs
+{
+    int count = 0;
+    int * ids; //we'll just use shared memory for this one, bit of a hack
+};

 extern std::string executable_path;
 extern std::string lora_filename;
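For illustration only, a hypothetical native caller of the exported `token_count` function might consume the new struct as sketched below. None of this code is part of the commit; it assumes the program is linked against the built koboldcpp library and that a model has already been loaded through the library's usual entry points.

// Hypothetical caller (not part of this commit) showing the count/ids contract.
#include <cstdio>
#include <vector>

struct token_count_outputs {   // mirrors the struct added to expose.h above
    int count = 0;
    int * ids;
};

extern "C" token_count_outputs token_count(const char * input);

int main() {
    token_count_outputs out = token_count("Hello world");
    // ids points into the library's shared static vector, so copy it out
    // before making any further calls into the library.
    std::vector<int> ids(out.ids, out.ids + out.count);
    printf("counted %d tokens\n", out.count);
    return 0;
}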

Diff for: gpttype_adapter.cpp

+2 -2

@@ -1390,7 +1390,7 @@ bool gpttype_generate_abort()
     return true;
 }

-int gpttype_token_count(const std::string & input)
+std::vector<int> gpttype_get_token_arr(const std::string & input)
 {
     if(debugmode==1)
     {
@@ -1403,7 +1403,7 @@ int gpttype_token_count(const std::string & input)
     {
         printf("\nTokens Counted: %d\n",tokcount);
     }
-    return tokcount;
+    return toks;
 }

 const std::string & gpttype_get_pending_output()
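The two hunks above only show the head and tail of the renamed function; the tokenization in between is not part of this diff. As a rough sketch of the overall shape (with `toy_tokenize` as a hypothetical stand-in for the model-specific tokenizer), the change is simply to return the id vector by value instead of its length:

// Sketch of the function's overall shape; toy_tokenize and demo_get_token_arr
// are hypothetical stand-ins, not the real koboldcpp implementation.
#include <cstdio>
#include <string>
#include <vector>

static int debugmode = 1;

static std::vector<int> toy_tokenize(const std::string & input) {
    std::vector<int> ids;
    for (char c : input) ids.push_back((int)(unsigned char)c); // fake ids
    return ids;
}

std::vector<int> demo_get_token_arr(const std::string & input) {
    std::vector<int> toks = toy_tokenize(input);
    int tokcount = (int)toks.size();
    if (debugmode == 1) {
        printf("\nTokens Counted: %d\n", tokcount);
    }
    return toks; // returned by value (moved), so no lifetime issues at this layer
}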

Diff for: koboldcpp.py

+10 -3

@@ -77,6 +77,10 @@ class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
                 ("text", ctypes.c_char * 32768)]

+class token_count_outputs(ctypes.Structure):
+    _fields_ = [("count", ctypes.c_int),
+                ("ids", ctypes.POINTER(ctypes.c_int))]
+
 handle = None

 def getdirpath():
@@ -218,7 +222,7 @@ def init_library():
     handle.get_total_gens.restype = ctypes.c_int
     handle.get_last_stop_reason.restype = ctypes.c_int
     handle.abort_generate.restype = ctypes.c_bool
-    handle.token_count.restype = ctypes.c_int
+    handle.token_count.restype = token_count_outputs
     handle.get_pending_output.restype = ctypes.c_char_p

 def load_model(model_filename):
@@ -729,8 +733,11 @@ def do_POST(self):
         try:
             genparams = json.loads(body)
             countprompt = genparams.get('prompt', "")
-            count = handle.token_count(countprompt.encode("UTF-8"))
-            response_body = (json.dumps({"value": count}).encode())
+            rawcountdata = handle.token_count(countprompt.encode("UTF-8"))
+            countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0
+            # the above protects the server in case the count limit got corrupted
+            countdata = [rawcountdata.ids[i] for i in range(countlimit)]
+            response_body = (json.dumps({"value": len(countdata),"ids": countdata}).encode())

         except Exception as e:
             utfprint("Count Tokens - Body Error: " + str(e))
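With this change the token-count endpoint reports the individual ids alongside the count: the response body now has the shape {"value": <number of tokens>, "ids": [<token id>, ...]} rather than just {"value": <number of tokens>} (the ids themselves depend on the loaded model's tokenizer).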

Diff for: model_adapter.h

+1 -1

@@ -68,7 +68,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output);
 bool gpttype_generate_abort();
 const std::string & gpttype_get_pending_output();
-int gpttype_token_count(const std::string & input);
+std::vector<int> gpttype_get_token_arr(const std::string & input);

 void timer_start();
 double timer_check();
