Commit b57af4c

Remove scipy by implementing log_softmax (#1561)
1 parent baa6560 commit b57af4c

File tree: 6 files changed (+39, -11 lines)

setup.py (+1 -2)

@@ -165,8 +165,7 @@ def _parse_requirements_file(file_path):
 _haystack_integration_deps = _parse_requirements_file(_haystack_requirements_file_path)
 _clip_deps = [
     "open_clip_torch==2.20.0",
-    "scipy<1.10,>=1.8",
-    "transformers<4.35",
+    "transformers<4.37",
 ]

src/deepsparse/clip/zeroshot_pipeline.py (+2 -2)

@@ -20,7 +20,7 @@

 from deepsparse.clip import CLIPTextInput, CLIPVisualInput
 from deepsparse.legacy.pipeline import BasePipeline, Pipeline
-from scipy.special import softmax
+from deepsparse.utils import numpy_softmax


 __all__ = ["CLIPZeroShotInput", "CLIPZeroShotOutput", "CLIPZeroShotPipeline"]

@@ -103,7 +103,7 @@ def __call__(self, *args, **kwargs):
         text_output /= linalg.norm(text_output, axis=-1, keepdims=True)

         output_product = 100.0 * visual_output @ text_output.T
-        text_probs = softmax(output_product, axis=-1)
+        text_probs = numpy_softmax(output_product, axis=-1)

         return self.output_schema(text_scores=np.vsplit(text_probs, len(text_probs)))
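For context, numpy_softmax (already available from deepsparse.utils) is a drop-in replacement for scipy.special.softmax here. A minimal sketch of the scoring step it now backs, using toy embeddings rather than real CLIP outputs:

    import numpy as np

    # toy L2-normalized image embeddings (2 images) and text embeddings (3 labels)
    visual_output = np.array([[0.6, 0.8], [1.0, 0.0]])
    text_output = np.array([[0.8, 0.6], [0.0, 1.0], [0.6, -0.8]])

    # cosine similarities scaled by 100, mirroring __call__ above
    output_product = 100.0 * visual_output @ text_output.T

    # what numpy_softmax computes: max-shifted exponentials, normalized per row
    e = np.exp(output_product - np.max(output_product, axis=-1, keepdims=True))
    text_probs = e / np.sum(e, axis=-1, keepdims=True)  # each row sums to 1.0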

src/deepsparse/server/openai_server.py (+2 -3)

@@ -46,10 +46,9 @@
 )
 from deepsparse.server.server import Server
 from deepsparse.tasks import SupportedTasks
-from deepsparse.utils import InferenceState
+from deepsparse.utils import InferenceState, numpy_softmax
 from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
-from scipy.special import softmax


 _LOGGER = logging.getLogger(__name__)

@@ -481,7 +480,7 @@ def create_logprobs(
         tokens = pipeline.tokenizer.batch_decode(token_ids)

         for i in range(len(tokens)):
-            log_prob = float(numpy.log(max(softmax(scores[i]))))
+            log_prob = float(numpy.log(max(numpy_softmax(scores[i]))))
             logprobs.tokens.append(tokens[i])
             logprobs.token_logprobs.append(log_prob)
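Note that numpy.log(max(numpy_softmax(scores[i]))) is mathematically equal to max(numpy_log_softmax(scores[i])), since log is monotone; a quick sketch of the equivalence with toy logits (illustrative values only):

    import numpy
    from deepsparse.utils import numpy_softmax
    from deepsparse.utils import numpy_log_softmax  # added by this commit

    scores_i = numpy.array([3.2, -1.1, 0.4])  # toy per-token logits

    via_softmax = float(numpy.log(max(numpy_softmax(scores_i))))
    via_log_softmax = float(max(numpy_log_softmax(scores_i)))
    assert numpy.isclose(via_softmax, via_log_softmax)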

src/deepsparse/transformers/metrics.py (+2 -2)

@@ -20,7 +20,7 @@

 import numpy

-from scipy.special import log_softmax
+from deepsparse.utils import numpy_log_softmax


 __all__ = [

@@ -266,7 +266,7 @@ def _cross_entropy(
         float: The computed cross-entropy loss.
     """

-    logp = log_softmax(predictions, axis=-1)
+    logp = numpy_log_softmax(predictions, axis=-1)
     neg_log_likelihoods = -1.0 * numpy.take_along_axis(
         logp, numpy.expand_dims(targets, axis=-1), axis=-1
     )
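As a sanity check of the swap inside _cross_entropy, a small sketch with toy logits and targets (the values and the mean reduction are illustrative, not from the repo):

    import numpy
    from deepsparse.utils import numpy_log_softmax  # added by this commit

    # toy logits for 2 positions over a 3-token vocabulary
    predictions = numpy.array([[2.0, 0.5, -1.0], [0.1, 3.0, 0.2]])
    targets = numpy.array([0, 1])

    # same steps as the function body above
    logp = numpy_log_softmax(predictions, axis=-1)
    neg_log_likelihoods = -1.0 * numpy.take_along_axis(
        logp, numpy.expand_dims(targets, axis=-1), axis=-1
    )
    loss = float(numpy.mean(neg_log_likelihoods))  # average cross-entropy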

src/deepsparse/utils/data.py (+30 -0)

@@ -169,6 +169,36 @@ def numpy_softmax(x: numpy.ndarray, axis: int = 0):
     return softmax_x


+def numpy_log_softmax(x: numpy.ndarray, axis: int = 0):
+    """
+    Ref: https://github.com/scipy/scipy/blob/v1.12.0/scipy/special/_logsumexp.py
+
+    In principle: log_softmax(x) = log(softmax(x))
+    but using a more accurate implementation.
+
+    :param x: array containing values to be softmaxed
+    :param axis: axis across which to perform softmax
+    :return: x with values across axis softmaxed
+    """
+    x_max = numpy.max(x, axis=axis, keepdims=True)
+
+    if x_max.ndim > 0:
+        x_max[~numpy.isfinite(x_max)] = 0
+    elif not numpy.isfinite(x_max):
+        x_max = 0
+
+    tmp = x - x_max
+    exp_tmp = numpy.exp(tmp)
+
+    # suppress warnings about log of zero
+    with numpy.errstate(divide="ignore"):
+        s = numpy.sum(exp_tmp, axis=axis, keepdims=True)
+        out = numpy.log(s)
+
+    out = tmp - out
+    return out
+
+
 def split_engine_inputs(
     items: List[numpy.ndarray], batch_size: int
 ) -> Tuple[List[List[numpy.ndarray]], int]:
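The max-shift plus log-sum-exp above is the same trick scipy's implementation uses, and it is why numpy.log(numpy_softmax(x)) would not be an adequate substitute: for large-magnitude logits the naive composition underflows to -inf. A quick illustrative sketch:

    import numpy
    from deepsparse.utils import numpy_log_softmax, numpy_softmax

    x = numpy.array([1000.0, 0.0])  # extreme but valid logits

    with numpy.errstate(divide="ignore"):
        naive = numpy.log(numpy_softmax(x))  # softmax underflows: [0., -inf]

    stable = numpy_log_softmax(x)  # stays finite: [0., -1000.]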

tests/server/test_openai.py (+2 -2)

@@ -24,8 +24,8 @@
     ModelPermission,
     OpenAIServer,
 )
+from deepsparse.utils import numpy_softmax
 from fastapi.testclient import TestClient
-from scipy.special import softmax


 TEST_MODEL_ID = "hf:mgoin/TinyStories-1M-ds"

@@ -246,7 +246,7 @@ def test_logprobs(client, model_card):

     for local_gen, server_gen in zip(output.generations, response.json()["choices"]):
         local_top1_logprobs = [
-            numpy.log(max(softmax(logits))) for logits in local_gen.score
+            numpy.log(max(numpy_softmax(logits))) for logits in local_gen.score
         ]
         server_top1_logprobs = server_gen["logprobs"]["token_logprobs"]
         assert numpy.allclose(local_top1_logprobs, server_top1_logprobs)
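Since the test now mirrors the server path with numpy_softmax, a quick parity check against scipy could be run in a dev environment where scipy happens to be installed (a sketch only; scipy is no longer a dependency after this commit):

    import numpy
    from deepsparse.utils import numpy_softmax
    from scipy.special import softmax  # dev-only, for comparison

    logits = numpy.random.randn(16)
    assert numpy.allclose(numpy_softmax(logits), softmax(logits))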
