11
11
min_p = 0.0 ,
12
12
# strictly disabled for now
13
13
top_k = 0 ,
14
- # top_p=0 .0,
14
+ top_p = 1 .0 ,
15
15
# frequency_penalties=0.0,
16
16
# presence_penalties=0.0,
17
17
# repetition_penalties=0.0,
@@ -26,11 +26,9 @@ class TPUSupportedSamplingMetadata:
26
26
temperature : torch .Tensor = None
27
27
28
28
min_p : torch .Tensor = None
29
- # Still too slow on forward_native!
30
29
top_k : torch .Tensor = None
31
30
top_p : torch .Tensor = None
32
31
33
- # Greedy sampling flag for compiling single xla graph.
34
32
all_greedy : bool = True
35
33
36
34
# Unsupported: you would need to return an extra tensor of static size BxV.
@@ -103,17 +101,17 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
103
101
DEFAULT_SAMPLING_PARAMS ["min_p" ])
104
102
fill_slice (input_batch .top_k_cpu_tensor ,
105
103
DEFAULT_SAMPLING_PARAMS ["top_k" ])
106
- # TODO Temporarily disabled until sampling options are enabled
107
- # fill_slice(input_batch.top_p_cpu_tensor,
108
- # DEFAULT_SAMPLING_PARAMS["top_p"])
104
+ fill_slice (input_batch .top_p_cpu_tensor ,
105
+ DEFAULT_SAMPLING_PARAMS ["top_p" ])
109
106
110
107
# Slice persistent device tensors to a fixed pre-compiled padded shape.
111
108
return cls (
112
109
temperature = input_batch .temperature_cpu_tensor [:padded_num_reqs ].
113
110
to (xla_device ),
114
111
all_greedy = input_batch .all_greedy ,
115
112
# TODO: enable more sampling options and avoid returning None values.
116
- top_p = None , # input_batch.top_p[:padded_num_reqs],
113
+ top_p = input_batch .top_p_cpu_tensor [:padded_num_reqs ].to (
114
+ xla_device ),
117
115
top_k = input_batch .top_k_cpu_tensor [:padded_num_reqs ].to (
118
116
xla_device ),
119
117
min_p = input_batch .min_p_cpu_tensor [:padded_num_reqs ].to (
0 commit comments