1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
/**
 * A prefix tree (trie) over strings that keeps multiplicities: the same word may be inserted
 * several times, and each node records how many stored words end at it and how many stored
 * words pass through it.
 */
class Trie {
    /** One trie node; the characters on the path from the root to this node spell a prefix. */
    private static final class TrieNode {
        /** Child nodes keyed by the next character. */
        final Map<Character, TrieNode> children = new HashMap<>();
        /** Number of stored words that end exactly at this node. */
        int wordCount;
        /** Number of stored words that have this node's prefix (including words ending here). */
        int prefixCount;
    }

    /** Root node; represents the empty prefix, so its prefixCount is the total number of words. */
    private final TrieNode root = new TrieNode();

    /** Creates an empty trie. */
    public Trie() {
    }

    /**
     * Adds one occurrence of {@code word}.
     *
     * @param word the word to insert (may be empty)
     */
    public void insert(String word) {
        TrieNode node = root;
        // The root stands for the empty prefix, which every word has.
        node.prefixCount++;
        for (char c : word.toCharArray()) {
            // computeIfAbsent only allocates a node when the child is actually missing.
            node = node.children.computeIfAbsent(c, k -> new TrieNode());
            node.prefixCount++;
        }
        node.wordCount++;
    }

    /**
     * Returns how many times {@code word} is currently stored.
     *
     * @param word the word to look up
     * @return the number of stored occurrences of {@code word}, or 0 if none
     */
    public int countWordsEqualTo(String word) {
        TrieNode node = findPrefix(word);
        return node != null ? node.wordCount : 0;
    }

    /**
     * Returns how many stored words (counting duplicates) start with {@code prefix}.
     * An empty prefix matches every stored word.
     *
     * @param prefix the prefix to look up
     * @return the number of stored words having {@code prefix} as a prefix
     */
    public int countWordsStartingWith(String prefix) {
        TrieNode node = findPrefix(prefix);
        return node != null ? node.prefixCount : 0;
    }

    /**
     * Removes one occurrence of {@code word}. If {@code word} is not currently stored this is a
     * no-op, so counts never become negative. Branches whose prefix count drops to zero are
     * detached so erased words do not leave dead nodes behind.
     *
     * @param word the word to remove
     */
    public void erase(String word) {
        if (countWordsEqualTo(word) == 0) {
            return;
        }
        TrieNode node = root;
        node.prefixCount--;
        for (int i = 0; i < word.length(); i++) {
            char c = word.charAt(i);
            TrieNode child = node.children.get(c);
            child.prefixCount--;
            if (child.prefixCount == 0) {
                // No stored word passes through this child any more, so the whole subtree
                // (including the node where this word ended) is unused and can be dropped.
                node.children.remove(c);
                return;
            }
            node = child;
        }
        node.wordCount--;
    }

    /**
     * Walks the path spelled by {@code prefix}.
     *
     * @param prefix the character path to follow from the root
     * @return the node reached, or {@code null} if the path does not exist
     */
    private TrieNode findPrefix(String prefix) {
        TrieNode node = root;
        for (int i = 0; i < prefix.length() && node != null; i++) {
            node = node.children.get(prefix.charAt(i));
        }
        return node;
    }
}

与普通 Trie 唯一的不同是：每个节点要维护两个计数字段——wordCount，表示有多少个单词恰好在该节点结尾；prefixCount，表示有多少个单词以该节点对应的字符串为前缀。