From 63dcd3fb756a97fefd49d43f4a8c8c93dd008d76 Mon Sep 17 00:00:00 2001
From: Menci
Date: Tue, 6 Jan 2026 22:34:25 +0800
Subject: [PATCH] feat: pre-filter and custom comparer for documents

---
 .../InvertedIndexSearcher.cs                  | 32 ++++++++++--
 .../MaigoLabs.NeedLe.Tests/E2E/SearchTests.cs | 51 +++++++++++++++++++
 packages/needle/src/e2e/search.test.ts        | 34 +++++++++++++
 packages/needle/src/searcher/search.ts        | 25 +++++++--
 4 files changed, 134 insertions(+), 8 deletions(-)

diff --git a/dotnet/MaigoLabs.NeedLe.Searcher/InvertedIndexSearcher.cs b/dotnet/MaigoLabs.NeedLe.Searcher/InvertedIndexSearcher.cs
index 04e8ff1..e7c8956 100644
--- a/dotnet/MaigoLabs.NeedLe.Searcher/InvertedIndexSearcher.cs
+++ b/dotnet/MaigoLabs.NeedLe.Searcher/InvertedIndexSearcher.cs
@@ -24,6 +24,21 @@ public class SearchResult
     public required int MatchRatioLevel { get; set; }
 }
 
+public class InvertedIndexSearcherOptions
+{
+    /// <summary>
+    /// Called when all other comparisons are equal.
+    /// (int documentIdA, int documentIdB) => int
+    /// </summary>
+    public Func<int, int, int>? NextComparer { get; set; }
+
+    /// <summary>
+    /// If this returns false for a document, the document is excluded from the final results.
+    /// (int documentId) => bool
+    /// </summary>
+    public Func<int, bool>? FilterDocument { get; set; }
+}
+
 public static class InvertedIndexSearcher
 {
     public abstract class ComparableStateBase<T> : IComparable<T>
@@ -36,7 +51,6 @@ public static class InvertedIndexSearcher
         protected virtual SearchResultToken? GetLastToken() => null; // Not on intermediate results
         protected virtual int? GetMatchRatioLevel() => null; // Not on intermediate/candidate results
         protected abstract double GetMatchRatio();
-        protected virtual int FallbackCompareTo(T other) => 0; // Called when all other comparisons are equal
 
         public int CompareTo(T other)
         {
@@ -73,7 +87,7 @@ public static class InvertedIndexSearcher
 
             double aMatchRatio = GetMatchRatio(), bMatchRatio = other.GetMatchRatio();
             if (aMatchRatio != bMatchRatio) return bMatchRatio < aMatchRatio ? -1 : bMatchRatio > aMatchRatio ? 1 : 0;
-            return FallbackCompareTo(other);
+            return 0;
         }
     }
 
@@ -123,7 +137,7 @@ public static class InvertedIndexSearcher
         protected override SearchResultToken? GetLastToken() => Result.Tokens[^1];
         protected override double GetMatchRatio() => Result.MatchRatio;
         protected override int? GetMatchRatioLevel() => Result.MatchRatioLevel;
-        protected override int FallbackCompareTo(FinalResult other) => string.Compare(Result.DocumentText, other.Result.DocumentText, StringComparison.InvariantCulture);
+        // protected override int FallbackCompareTo(FinalResult other) => string.Compare(Result.DocumentText, other.Result.DocumentText, StringComparison.InvariantCulture);
     }
 
     private static bool IsIgnorableCodePoint(int codePoint) => CommonUtils.IsWhitespace(codePoint) || codePoint == 0x3099 || codePoint == 0x309A;
@@ -151,7 +165,7 @@ public static class InvertedIndexSearcher
 
     private static bool HasNonEmptyCharacters(int[] documentCodePoints, int start, int end) => start != end && !documentCodePoints.Skip(start).Take(end - start).All(CommonUtils.IsWhitespace);
 
-    public static SearchResult[] Search(LoadedInvertedIndex invertedIndex, string text)
+    public static SearchResult[] Search(LoadedInvertedIndex invertedIndex, string text, InvertedIndexSearcherOptions? options = null)
     {
         var documents = invertedIndex.Documents;
         var documentCodePoints = invertedIndex.DocumentCodePoints;
@@ -183,6 +197,7 @@ public static class InvertedIndexSearcher
                 ];
                 foreach (var tokenId in matchingTokenIds) foreach (var reference in tokenDefinitions[tokenId].References)
                 {
+                    if (options?.FilterDocument != null && !options.FilterDocument(reference.DocumentId)) continue;
                     var isTokenPrefixMatching = !romajiNode.IsTokenExactMatch(tokenId) && !kanaNode.IsTokenExactMatch(tokenId) && !otherNode.IsTokenExactMatch(tokenId);
                     var previousMatchesOfDocument = l != 0 && dp[l - 1].TryGetValue(reference.DocumentId, out var previousMatches) ? previousMatches : null;
                     if (l != 0 && previousMatchesOfDocument == null) continue;
@@ -265,6 +280,13 @@ public static class InvertedIndexSearcher
                     MatchRatioLevel = matchRatioLevel,
                 }
             };
-        }).OrderBy(result => result).Select(result => result.Result).ToArray();
+        }).OrderBy(result => result, Comparer<FinalResult>.Create((a, b) =>
+        {
+            var compareResult = a.CompareTo(b);
+            if (compareResult != 0) return compareResult;
+            return options?.NextComparer == null
+                ? string.Compare(a.Result.DocumentText, b.Result.DocumentText, StringComparison.InvariantCulture)
+                : options.NextComparer(a.Result.DocumentId, b.Result.DocumentId);
+        })).Select(result => result.Result).ToArray();
     }
 }
diff --git a/dotnet/MaigoLabs.NeedLe.Tests/E2E/SearchTests.cs b/dotnet/MaigoLabs.NeedLe.Tests/E2E/SearchTests.cs
index 66d8e51..817d8be 100644
--- a/dotnet/MaigoLabs.NeedLe.Tests/E2E/SearchTests.cs
+++ b/dotnet/MaigoLabs.NeedLe.Tests/E2E/SearchTests.cs
@@ -139,3 +139,54 @@ public sealed class Search_BundleDocumentsOption_ThrowsWhenNoneProvidedTest : Ne
         Assert.Throws(() => InvertedIndexLoader.Load(compressed));
     }
 }
+
+public sealed class Search_FilterDocumentOption_ExcludesFilteredDocumentsTest : NeedleTestBase
+{
+    private static readonly string[] TestDocuments =
+    [
+        "ミーティア",
+        "エンドマークに希望と涙を添えて",
+        "宵の鳥",
+        "僕の和風本当上手",
+    ];
+
+    [Fact]
+    public void Execute()
+    {
+        var compressed = InvertedIndexBuilder.BuildInvertedIndex(TestDocuments, TokenizerOptions);
+        var invertedIndex = InvertedIndexLoader.Load(compressed);
+
+        // Search without filter - should find "宵の鳥" (documentId 2)
+        var resultsWithoutFilter = InvertedIndexSearcher.Search(invertedIndex, "yoi");
+        Assert.Contains("宵の鳥", resultsWithoutFilter.Select(r => r.DocumentText));
+
+        // Search with filter excluding documentId 2
+        var resultsWithFilter = InvertedIndexSearcher.Search(invertedIndex, "yoi", new InvertedIndexSearcherOptions
+        {
+            FilterDocument = id => id != 2
+        });
+        Assert.DoesNotContain("宵の鳥", resultsWithFilter.Select(r => r.DocumentText));
+    }
+}
+
+public sealed class Search_NextComparerOption_UsesCustomComparerTest : NeedleTestBase
+{
+    [Fact]
+    public void Execute()
+    {
+        // Create documents that would have similar match scores
+        var similarDocs = new[] { "テストA", "テストB", "テストC" };
+        var compressed = InvertedIndexBuilder.BuildInvertedIndex(similarDocs, TokenizerOptions);
+        var invertedIndex = InvertedIndexLoader.Load(compressed);
+
+        // Search with reverse order comparer
+        var results = InvertedIndexSearcher.Search(invertedIndex, "テスト", new InvertedIndexSearcherOptions
+        {
+            NextComparer = (a, b) => b - a // Reverse by documentId
+        });
+
+        // Should be in reverse documentId order (2, 1, 0) when other criteria equal
+        Assert.Equal([2, 1, 0], results.Select(r => r.DocumentId).ToArray());
+    }
+}
+
diff --git a/packages/needle/src/e2e/search.test.ts b/packages/needle/src/e2e/search.test.ts
index b089d31..bee68d9 100644
--- a/packages/needle/src/e2e/search.test.ts
+++ b/packages/needle/src/e2e/search.test.ts
@@ -100,4 +100,38 @@ describe('search options', () => {
       expect(() => loadInvertedIndex(compressed)).toThrow();
     });
   });
+
+  describe('filterDocument option', () => {
+    it('should exclude filtered documents from results', () => {
+      const compressed = buildInvertedIndex(testDocuments, { kuromoji });
+      const invertedIndex = loadInvertedIndex(compressed);
+
+      // Search without filter - should find "宵の鳥" (documentId 2)
+      const resultsWithoutFilter = searchInvertedIndex(invertedIndex, 'yoi');
+      expect(resultsWithoutFilter.map(r => r.documentText)).toContain('宵の鳥');
+
+      // Search with filter excluding documentId 2
+      const resultsWithFilter = searchInvertedIndex(invertedIndex, 'yoi', {
+        filterDocument: id => id !== 2,
+      });
+      expect(resultsWithFilter.map(r => r.documentText)).not.toContain('宵の鳥');
+    });
+  });
+
+  describe('nextComparer option', () => {
+    it('should use custom comparer for final sorting when other criteria are equal', () => {
+      // Create documents that would have similar match scores
+      const similarDocs = ['テストA', 'テストB', 'テストC'];
+      const compressed = buildInvertedIndex(similarDocs, { kuromoji });
+      const invertedIndex = loadInvertedIndex(compressed);
+
+      // Search with reverse order comparer
+      const results = searchInvertedIndex(invertedIndex, 'テスト', {
+        nextComparer: (a, b) => b - a, // Reverse by documentId
+      });
+
+      // Should be in reverse documentId order (2, 1, 0) when other criteria equal
+      expect(results.map(r => r.documentId)).toEqual([2, 1, 0]);
+    });
+  });
 });
diff --git a/packages/needle/src/searcher/search.ts b/packages/needle/src/searcher/search.ts
index 9499916..575fec7 100644
--- a/packages/needle/src/searcher/search.ts
+++ b/packages/needle/src/searcher/search.ts
@@ -133,12 +133,24 @@ const compareFinalResult = getComparerForTraits({
   getLastToken: state => state.tokens[state.tokens.length - 1]!,
   getMatchRatio: state => state.matchRatio,
   getMatchRatioLevel: state => Math.round(state.matchRatio * 5),
-  nextComparer: (a, b) => a.documentText === b.documentText ? 0 : a.documentText < b.documentText ? -1 : 1,
 });
 
 const hasNonEmptyCharacters = (documentCodePoints: string[], start: number, end: number) => start !== end && !documentCodePoints.slice(start, end).every(char => /\s/.test(char));
 
-export const searchInvertedIndex = (invertedIndex: LoadedInvertedIndex, text: string): SearchResult[] => {
+export const searchInvertedIndex = (
+  invertedIndex: LoadedInvertedIndex,
+  text: string,
+  options?: {
+    /**
+     * Called when all other comparisons are equal.
+     */
+    nextComparer?: (documentIdA: number, documentIdB: number) => number;
+    /**
+     * If this returns a falsy value for a document, the document is excluded from the final results.
+     */
+    filterDocument?: (documentId: number) => unknown;
+  },
+): SearchResult[] => {
   const { documents, documentCodePoints, tokenDefinitions, tries } = invertedIndex;
 
   const codePoints = [...toKatakana(normalizeByCodePoint(text))];
@@ -162,6 +174,7 @@ export const searchInvertedIndex = (invertedIndex: LoadedInvertedIndex, text: st
         ...getTrieNodeTokenIds(otherNode, reachingInputEnd),
       ]);
       for (const tokenId of matchingTokenIds) for (const { documentId, offsets } of tokenDefinitions[tokenId]!.references) {
+        if (options?.filterDocument && !options.filterDocument(documentId)) continue;
         const isTokenPrefixMatching = !romajiNode?.tokenIds.includes(tokenId) && !kanaNode?.tokenIds.includes(tokenId) && !otherNode?.tokenIds.includes(tokenId);
         const previousMatchesOfDocument = dp[l - 1]?.get(documentId);
         if (l !== 0 && !previousMatchesOfDocument) continue;
@@ -231,7 +244,13 @@ export const searchInvertedIndex = (invertedIndex: LoadedInvertedIndex, text: st
       matchRatio,
       matchRatioLevel,
     };
-  }).sort(compareFinalResult);
+  }).sort((a, b) => {
+    const compareResult = compareFinalResult(a, b);
+    if (compareResult !== 0) return compareResult;
+    return options?.nextComparer
+      ? options.nextComparer(a.documentId, b.documentId)
+      : a.documentText === b.documentText ? 0 : a.documentText < b.documentText ? -1 : 1;
+  });
 };
 
 // For debugging