图(稀疏矩阵)的压缩存储 CSR

图(稀疏矩阵)的压缩存储 CSR

矩阵的压缩

CSR 是对稀疏矩阵进行压缩处理,把矩阵压缩成三个数组,值数组(values)、行索引数组(col_index)和列偏移数组(row_offsets)。

  • 值数组:按照矩阵的行顺序依次存储非零元素的值
  • 行索引数组:行索引数组存储了每个非零元素所在的行号 (col_index[i]等于values[i]元素的列数)
  • 列偏移数组:存储了每一行的非零元素在值数组中的起始位置 (第i个元素记录了前i-1行包含的非零元素的数量)

 

下面举个例子

对于矩阵:

1
2
3
4
[[4, 0, 0, 2],
[0, 1, 0, 0],
[0, 0, 5, 7],
[6, 3, 0, 8]]
  • values: [4, 2, 1, 5, 7, 6, 3, 8]
  • col_index: [0, 3, 1, 2, 3, 0, 1, 3]
  • row_offsets: [0, 2, 3, 5, 8]

矩阵的解压

通过row_offsets找到对应的范围,通过col_index确定列,通过values取出值。

以一个例子说明,还是上面的矩阵,假设取[1][1]的值

  • 通过row_offset找到范围 为 [row_offsets[1],row_offset[1+1])即是[2,3)
  • 遍历的col_index的[2,3),找到col_index[i]==1,即i==2
  • 直接返回values[i]即values[2],也就是1

如果第二步的时候找不到,那么就直接返回0。

对于代码来说就是

1
2
3
4
5
6
7
8

def get(self, row: int, col: int) -> int:
start = self.row_offsets[row]
end = self.row_offsets[row + 1]
for i in range(start, end):
if self.col_index[i] == col:
return self.values[i]
return 0

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from typing import List


class CSR:
def __init__(self, matrix: List[List[int]]):
# 矩阵的行顺序依次存储非零元素的值
self.values = []
# 行索引数组存储了每个非零元素所在的行号
self.col_index = []
# 列偏移数组存储了每一行的非零元素在值数组中的起始位置
self.row_offsets = [0]

self._column_len = len(matrix[0])

for i in range(0, len(matrix)):
row_offset = self.row_offsets[-1]
for j in range(0, len(matrix[i])):
if matrix[i][j] != 0:
row_offset += 1
self.values.append(matrix[i][j])
self.col_index.append(j)
self.row_offsets.append(row_offset)

def get(self, row: int, col: int) -> int:

start = self.row_offsets[row]
end = self.row_offsets[row + 1]

for i in range(start, end):
if self.col_index[i] == col:
return self.values[i]
return 0

def column_len(self) -> int:
return self._column_len

def row_len(self) -> int:
return len(self.row_offsets) - 1


m = [[4, 0, 0, 2],
[0, 1, 0, 0],
[0, 0, 5, 7],
[6, 3, 0, 8]]
csr = CSR(m)
print('values:')
print(csr.values)
print('col_index:')
print(csr.col_index)
print('row_offsets:')
print(csr.row_offsets)

print("decompression:")
d_m = []
for i in range(0, csr.row_len()):
a = []
for j in range(0, csr.column_len()):
a.append(csr.get(i, j))
d_m.append(a)
print(d_m)