Skip to content

Commit 713590a

Browse files
committed
Fixed SparseVector constructor for SciPy sparse matrices - fixes #127
1 parent 3f9e9a2 commit 713590a

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.4.1 (unreleased)
2+
3+
- Fixed `SparseVector` constructor for SciPy sparse matrices
4+
15
## 0.4.0 (2025-03-15)
26

37
- Added top-level `pgvector` package

pgvector/sparsevec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def _from_sparse(self, value):
8585

8686
if hasattr(value, 'coords'):
8787
# scipy 1.13+
88-
self._indices = value.coords[0].tolist()
88+
self._indices = value.coords[-1].tolist()
8989
else:
9090
self._indices = value.col.tolist()
9191
self._values = value.data.tolist()

tests/test_sparse_vector.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
from pgvector import SparseVector
33
import pytest
4-
from scipy.sparse import coo_array
4+
from scipy.sparse import coo_array, csr_array, csr_matrix
55
from struct import pack
66

77

@@ -49,6 +49,18 @@ def test_dok_array(self):
4949
assert vec.to_list() == [1, 0, 2, 0, 3, 0]
5050
assert vec.indices() == [0, 2, 4]
5151

52+
def test_csr_array(self):
53+
arr = csr_array(np.array([1, 0, 2, 0, 3, 0]))
54+
vec = SparseVector(arr)
55+
assert vec.to_list() == [1, 0, 2, 0, 3, 0]
56+
assert vec.indices() == [0, 2, 4]
57+
58+
def test_csr_matrix(self):
59+
mat = csr_matrix(np.array([1, 0, 2, 0, 3, 0]))
60+
vec = SparseVector(mat)
61+
assert vec.to_list() == [1, 0, 2, 0, 3, 0]
62+
assert vec.indices() == [0, 2, 4]
63+
5264
def test_repr(self):
5365
assert repr(SparseVector([1, 0, 2, 0, 3, 0])) == 'SparseVector({0: 1.0, 2: 2.0, 4: 3.0}, 6)'
5466
assert str(SparseVector([1, 0, 2, 0, 3, 0])) == 'SparseVector({0: 1.0, 2: 2.0, 4: 3.0}, 6)'

0 commit comments

Comments
 (0)