Files
needLe/dotnet/MaigoLabs.NeedLe.Searcher/Trie/TrieDeserializer.cs
T
2026-01-01 03:40:41 +08:00

74 lines
2.8 KiB
C#

using MaigoLabs.NeedLe.Common;
namespace MaigoLabs.NeedLe.Searcher.Trie;
public class DeserializedTrie
{
public required TrieNode Root { get; set; }
public required Dictionary<int, int[]> TokenCodePoints { get; set; }
}
public static class TrieDeserializer
{
public static DeserializedTrie Deserialize(int[] data)
{
var nodes = new List<TrieNode?>();
TrieNode GetNode(int id)
{
if (id > nodes.Count) nodes.AddRange(Enumerable.Repeat<TrieNode?>(null, id - nodes.Count));
return nodes[id - 1] ??= new TrieNode { Parent = null, Children = [], TokenIds = [], SubTreeTokenIds = [] };
}
var currentId = 0;
for (var i = 0; i < data.Length; )
{
var node = GetNode(++currentId);
var parentId = data[i++];
node.Parent = parentId != 0 ? GetNode(parentId) : null;
var endOfChildren = i;
while (endOfChildren < data.Length && data[endOfChildren] > 0) endOfChildren++;
var numberOfChildren = (endOfChildren - i) / 2;
for (var j = i; j < i + numberOfChildren; j++)
{
var codePoint = data[j];
var child = GetNode(data[j + numberOfChildren]);
node.Children.Add(codePoint, child);
}
i = endOfChildren;
if (data[i] == 0) i++; // No token IDs
else while (i < data.Length && data[i] < 0) node.TokenIds.Add(-data[i++] - 1);
}
var root = nodes[0]!;
// DFS to construct code point paths for each token
var tokenCodePoints = new Dictionary<int, int[]>();
var currentCodePoints = new List<int>();
void DfsCodePoints(TrieNode node)
{
foreach (var tokenId in node.TokenIds) tokenCodePoints.Add(tokenId, [.. currentCodePoints]);
foreach (var (codePoint, child) in node.Children)
{
if (child.Parent != node) continue; // Skip grafted paths as these are not the canonical representation of the tokens
currentCodePoints.Add(codePoint);
DfsCodePoints(child);
currentCodePoints.RemoveAt(currentCodePoints.Count - 1);
}
}
DfsCodePoints(root);
// DFS to construct subTreeTokenIds for each node
var visitedNodes = new HashSet<TrieNode>();
List<int> DfsSubTreeTokenIds(TrieNode node)
{
if (visitedNodes.Contains(node)) return node.SubTreeTokenIds;
visitedNodes.Add(node);
node.SubTreeTokenIds = new HashSet<int>(node.TokenIds.Concat(node.Children.Values.SelectMany(DfsSubTreeTokenIds))).ToList();
return node.SubTreeTokenIds;
};
DfsSubTreeTokenIds(root);
return new DeserializedTrie { Root = root, TokenCodePoints = tokenCodePoints };
}
}