feat: pre-filter and custom comparer for documents

This commit is contained in:
Menci
2026-01-06 22:34:25 +08:00
parent 1a08454351
commit 63dcd3fb75
4 changed files with 134 additions and 8 deletions
@@ -24,6 +24,21 @@ public class SearchResult
public required int MatchRatioLevel { get; set; }
}
/// <summary>
/// Optional hooks that callers can pass to <c>InvertedIndexSearcher.Search</c> to
/// restrict which documents are considered and to control how ties are ordered.
/// Both hooks are unset (<c>null</c>) by default.
/// </summary>
public class InvertedIndexSearcherOptions
{
    /// <summary>
    /// Predicate evaluated with a document ID. A document for which it returns
    /// <c>false</c> is left out of the final results; when the predicate is
    /// <c>null</c>, no document is filtered out.
    /// <code>(int documentId) => bool</code>
    /// </summary>
    public Func<int, bool>? FilterDocument { get; set; }

    /// <summary>
    /// Tie-breaker consulted only when every built-in comparison considers two
    /// results equal. It receives the two document IDs and returns a negative
    /// value, zero, or a positive value, following the usual comparer contract.
    /// When it is <c>null</c>, the searcher breaks ties by comparing the document
    /// texts instead.
    /// <code>(int documentIdA, int documentIdB) => int</code>
    /// </summary>
    public Func<int, int, int>? NextComparer { get; set; }
}
public static class InvertedIndexSearcher
{
public abstract class ComparableStateBase<T> : IComparable<T>
@@ -36,7 +51,6 @@ public static class InvertedIndexSearcher
protected virtual SearchResultToken? GetLastToken() => null; // Not on intermediate results
protected virtual int? GetMatchRatioLevel() => null; // Not on intermediate/candidate results
protected abstract double GetMatchRatio();
protected virtual int FallbackCompareTo(T other) => 0; // Called when all other comparisons are equal
public int CompareTo(T other)
{
@@ -73,7 +87,7 @@ public static class InvertedIndexSearcher
double aMatchRatio = GetMatchRatio(), bMatchRatio = other.GetMatchRatio();
if (aMatchRatio != bMatchRatio) return bMatchRatio < aMatchRatio ? -1 : bMatchRatio > aMatchRatio ? 1 : 0;
return FallbackCompareTo(other);
return 0;
}
}
@@ -123,7 +137,7 @@ public static class InvertedIndexSearcher
protected override SearchResultToken? GetLastToken() => Result.Tokens[^1];
protected override double GetMatchRatio() => Result.MatchRatio;
protected override int? GetMatchRatioLevel() => Result.MatchRatioLevel;
protected override int FallbackCompareTo(FinalResult other) => string.Compare(Result.DocumentText, other.Result.DocumentText, StringComparison.InvariantCulture);
// protected override int FallbackCompareTo(FinalResult other) => string.Compare(Result.DocumentText, other.Result.DocumentText, StringComparison.InvariantCulture);
}
/// <summary>
/// Returns <c>true</c> when <paramref name="codePoint"/> is whitespace according to
/// <c>CommonUtils.IsWhitespace</c>, or is one of the combining sound marks
/// U+3099 (voiced) / U+309A (semi-voiced).
/// </summary>
private static bool IsIgnorableCodePoint(int codePoint)
{
    if (CommonUtils.IsWhitespace(codePoint)) return true;
    return codePoint is 0x3099 or 0x309A;
}
@@ -151,7 +165,7 @@ public static class InvertedIndexSearcher
/// <summary>
/// Checks whether the code points taken from <paramref name="documentCodePoints"/>,
/// starting at <paramref name="start"/> and spanning <c>end - start</c> elements,
/// contain at least one that <c>CommonUtils.IsWhitespace</c> does not treat as whitespace.
/// </summary>
/// <returns>
/// <c>false</c> when <paramref name="start"/> equals <paramref name="end"/> or when every
/// code point in the span is whitespace; <c>true</c> otherwise.
/// </returns>
private static bool HasNonEmptyCharacters(int[] documentCodePoints, int start, int end)
{
    // An empty span trivially has no content.
    if (start == end) return false;

    return documentCodePoints
        .Skip(start)
        .Take(end - start)
        .Any(codePoint => !CommonUtils.IsWhitespace(codePoint));
}
public static SearchResult[] Search(LoadedInvertedIndex invertedIndex, string text)
public static SearchResult[] Search(LoadedInvertedIndex invertedIndex, string text, InvertedIndexSearcherOptions? options = null)
{
var documents = invertedIndex.Documents;
var documentCodePoints = invertedIndex.DocumentCodePoints;
@@ -183,6 +197,7 @@ public static class InvertedIndexSearcher
];
foreach (var tokenId in matchingTokenIds) foreach (var reference in tokenDefinitions[tokenId].References)
{
if (options?.FilterDocument != null && !options.FilterDocument(reference.DocumentId)) continue;
var isTokenPrefixMatching = !romajiNode.IsTokenExactMatch(tokenId) && !kanaNode.IsTokenExactMatch(tokenId) && !otherNode.IsTokenExactMatch(tokenId);
var previousMatchesOfDocument = l != 0 && dp[l - 1].TryGetValue(reference.DocumentId, out var previousMatches) ? previousMatches : null;
if (l != 0 && previousMatchesOfDocument == null) continue;
@@ -265,6 +280,13 @@ public static class InvertedIndexSearcher
MatchRatioLevel = matchRatioLevel,
}
};
}).OrderBy(result => result).Select(result => result.Result).ToArray();
}).OrderBy(result => result, Comparer<FinalResult>.Create((a, b) =>
{
var compareResult = a.CompareTo(b);
if (compareResult != 0) return compareResult;
return options?.NextComparer == null
? string.Compare(a.Result.DocumentText, b.Result.DocumentText, StringComparison.InvariantCulture)
: options.NextComparer(a.Result.DocumentId, b.Result.DocumentId);
})).Select(result => result.Result).ToArray();
}
}
@@ -139,3 +139,54 @@ public sealed class Search_BundleDocumentsOption_ThrowsWhenNoneProvidedTest : Ne
Assert.Throws<ArgumentException>(() => InvertedIndexLoader.Load(compressed));
}
}
/// <summary>
/// Verifies that a document rejected by <see cref="InvertedIndexSearcherOptions.FilterDocument"/>
/// no longer appears in the search results, while the same query without the filter finds it.
/// </summary>
public sealed class Search_FilterDocumentOption_ExcludesFilteredDocumentsTest : NeedleTestBase
{
    private const string Query = "yoi";
    private const string ExcludedDocument = "宵の鳥";

    private static readonly string[] TestDocuments =
    [
        "ミーティア",
        "エンドマークに希望と涙を添えて",
        ExcludedDocument,
        "僕の和風本当上手",
    ];

    [Fact]
    public void Execute()
    {
        var compressed = InvertedIndexBuilder.BuildInvertedIndex(TestDocuments, TokenizerOptions);
        var invertedIndex = InvertedIndexLoader.Load(compressed);

        // Baseline: without a filter the query must find the target document.
        var resultsWithoutFilter = InvertedIndexSearcher.Search(invertedIndex, Query);
        Assert.Contains(ExcludedDocument, resultsWithoutFilter.Select(r => r.DocumentText));

        // Take the ID from the baseline result instead of hard-coding it, so the test
        // does not depend on how the index builder assigns document IDs.
        var excludedDocumentId = resultsWithoutFilter
            .First(r => r.DocumentText == ExcludedDocument)
            .DocumentId;

        // Same query, but the filter rejects the target document.
        var resultsWithFilter = InvertedIndexSearcher.Search(invertedIndex, Query, new InvertedIndexSearcherOptions
        {
            FilterDocument = id => id != excludedDocumentId,
        });

        Assert.DoesNotContain(ExcludedDocument, resultsWithFilter.Select(r => r.DocumentText));
        Assert.DoesNotContain(excludedDocumentId, resultsWithFilter.Select(r => r.DocumentId));
    }
}
/// <summary>
/// Verifies that <see cref="InvertedIndexSearcherOptions.NextComparer"/> decides the order of
/// results that the built-in ranking considers equal.
/// </summary>
public sealed class Search_NextComparerOption_UsesCustomComparerTest : NeedleTestBase
{
    [Fact]
    public void Execute()
    {
        // Documents that differ only in their last character; they are expected to
        // tie on the built-in ranking criteria for the query below.
        var similarDocs = new[] { "テストA", "テストB", "テストC" };
        var compressed = InvertedIndexBuilder.BuildInvertedIndex(similarDocs, TokenizerOptions);
        var invertedIndex = InvertedIndexLoader.Load(compressed);

        // Descending document ID. CompareTo is used rather than `b - a` because
        // subtraction-based comparers can overflow and report the wrong sign.
        var results = InvertedIndexSearcher.Search(invertedIndex, "テスト", new InvertedIndexSearcherOptions
        {
            NextComparer = (a, b) => b.CompareTo(a),
        });

        // Should be in reverse documentId order (2, 1, 0) when other criteria are equal.
        Assert.Equal([2, 1, 0], results.Select(r => r.DocumentId).ToArray());
    }
}