1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| class SparseVector { public List<int[]> compressed;
SparseVector(int[] nums) { compressed = new ArrayList<>(); for (int i = 0; i < nums.length; i++) { compressed.add(new int[] { i, nums[i] }); } }
public int dotProduct(SparseVector vec) { int ans = 0; int ptrOne = 0, ptrTwo = 0; while (ptrOne < this.compressed.size() && ptrTwo < vec.compressed.size()) { if (this.compressed.get(ptrOne)[0] == vec.compressed.get(ptrTwo)[0]) { ans += this.compressed.get(ptrOne)[1] * vec.compressed.get(ptrTwo)[1]; ptrOne += 1; ptrTwo += 1; } else if (this.compressed.get(ptrOne)[0] < vec.compressed.get(ptrTwo)[0]) { ptrOne += 1; } else { ptrTwo += 1; } } return ans; } }
|
精髓在于把non-zero的数字存到一个pair中. 该pair中第0个元素是这个数字在原array中的index, 该pair第1个元素是该数字的值. 其他就没什么好说的了.
n表示原数组长度, L1表示生成的list的长度(非0的数字的个数), L2表示传进来的vec对应的list的长度.
时间复杂度: 创建list是O(n). dot product是O(L1 + L2)
空间复杂度: O(L1)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| class SparseVector { public Map<Integer, Integer> compressed;
SparseVector(int[] nums) { compressed = new HashMap<>(); for (int i = 0; i < nums.length; i++) { if (nums[i] != 0) { compressed.put(i, nums[i]); } } }
public int dotProduct(SparseVector vec) { int ans = 0; for (Map.Entry<Integer, Integer> entry : compressed.entrySet()) { ans += entry.getValue() * vec.compressed.getOrDefault(entry.getKey(), 0); } return ans; } }
|
用HashMap也行. 当来一个vec的时候, 遍历map每一个entry, 看来的这个vec中map中对应的index有没有值. 有的话就乘上去加到ans上. 没有的话就返回0乘上去加到ans上去.