diff --git a/dotnet/MaigoLabs.NeedLe.Searcher/InvertedIndexSearcher.cs b/dotnet/MaigoLabs.NeedLe.Searcher/InvertedIndexSearcher.cs
index 04e8ff1..e7c8956 100644
--- a/dotnet/MaigoLabs.NeedLe.Searcher/InvertedIndexSearcher.cs
+++ b/dotnet/MaigoLabs.NeedLe.Searcher/InvertedIndexSearcher.cs
@@ -24,6 +24,21 @@ public class SearchResult
public required int MatchRatioLevel { get; set; }
}
+public class InvertedIndexSearcherOptions
+{
+    /// <summary>
+    /// Tie-breaker invoked by <c>Search</c> only when its built-in comparison of two results returns 0; when null, results are ordered by document text instead.
+    /// Signature: (int documentIdA, int documentIdB) => int, following the usual negative/zero/positive comparison convention.
+    /// </summary>
+    public Func<int, int, int>? NextComparer { get; set; }
+
+    /// <summary>
+    /// Predicate evaluated with a document id; when it returns false, that document is excluded from the final results. Null means no filtering.
+    /// Signature: (int documentId) => bool.
+    /// </summary>
+    public Func<int, bool>? FilterDocument { get; set; }
+}
+
public static class InvertedIndexSearcher
{
public abstract class ComparableStateBase : IComparable
@@ -36,7 +51,6 @@ public static class InvertedIndexSearcher
protected virtual SearchResultToken? GetLastToken() => null; // Not on intermediate results
protected virtual int? GetMatchRatioLevel() => null; // Not on intermediate/candidate results
protected abstract double GetMatchRatio();
- protected virtual int FallbackCompareTo(T other) => 0; // Called when all other comparisons are equal
public int CompareTo(T other)
{
@@ -73,7 +87,7 @@ public static class InvertedIndexSearcher
double aMatchRatio = GetMatchRatio(), bMatchRatio = other.GetMatchRatio();
if (aMatchRatio != bMatchRatio) return bMatchRatio < aMatchRatio ? -1 : bMatchRatio > aMatchRatio ? 1 : 0;
- return FallbackCompareTo(other);
+ return 0;
}
}
@@ -123,7 +137,7 @@ public static class InvertedIndexSearcher
protected override SearchResultToken? GetLastToken() => Result.Tokens[^1];
protected override double GetMatchRatio() => Result.MatchRatio;
protected override int? GetMatchRatioLevel() => Result.MatchRatioLevel;
- protected override int FallbackCompareTo(FinalResult other) => string.Compare(Result.DocumentText, other.Result.DocumentText, StringComparison.InvariantCulture);
+ // Tie-breaking on DocumentText no longer lives in FinalResult: Search() applies it (or options.NextComparer) after CompareTo returns 0.
}
private static bool IsIgnorableCodePoint(int codePoint) => CommonUtils.IsWhitespace(codePoint) || codePoint == 0x3099 || codePoint == 0x309A;
@@ -151,7 +165,7 @@ public static class InvertedIndexSearcher
private static bool HasNonEmptyCharacters(int[] documentCodePoints, int start, int end) =>
start != end && !documentCodePoints.Skip(start).Take(end - start).All(CommonUtils.IsWhitespace);
- public static SearchResult[] Search(LoadedInvertedIndex invertedIndex, string text)
+ public static SearchResult[] Search(LoadedInvertedIndex invertedIndex, string text, InvertedIndexSearcherOptions? options = null)
{
var documents = invertedIndex.Documents;
var documentCodePoints = invertedIndex.DocumentCodePoints;
@@ -183,6 +197,7 @@ public static class InvertedIndexSearcher
];
foreach (var tokenId in matchingTokenIds) foreach (var reference in tokenDefinitions[tokenId].References)
{
+ if (options?.FilterDocument != null && !options.FilterDocument(reference.DocumentId)) continue;
var isTokenPrefixMatching = !romajiNode.IsTokenExactMatch(tokenId) && !kanaNode.IsTokenExactMatch(tokenId) && !otherNode.IsTokenExactMatch(tokenId);
var previousMatchesOfDocument = l != 0 && dp[l - 1].TryGetValue(reference.DocumentId, out var previousMatches) ? previousMatches : null;
if (l != 0 && previousMatchesOfDocument == null) continue;
@@ -265,6 +280,13 @@ public static class InvertedIndexSearcher
MatchRatioLevel = matchRatioLevel,
}
};
- }).OrderBy(result => result).Select(result => result.Result).ToArray();
+ }).OrderBy(result => result, Comparer.Create((a, b) =>
+ {
+ var compareResult = a.CompareTo(b);
+ if (compareResult != 0) return compareResult;
+ return options?.NextComparer == null
+ ? string.Compare(a.Result.DocumentText, b.Result.DocumentText, StringComparison.InvariantCulture)
+ : options.NextComparer(a.Result.DocumentId, b.Result.DocumentId);
+ })).Select(result => result.Result).ToArray();
}
}
diff --git a/dotnet/MaigoLabs.NeedLe.Tests/E2E/SearchTests.cs b/dotnet/MaigoLabs.NeedLe.Tests/E2E/SearchTests.cs
index 66d8e51..817d8be 100644
--- a/dotnet/MaigoLabs.NeedLe.Tests/E2E/SearchTests.cs
+++ b/dotnet/MaigoLabs.NeedLe.Tests/E2E/SearchTests.cs
@@ -139,3 +139,54 @@ public sealed class Search_BundleDocumentsOption_ThrowsWhenNoneProvidedTest : Ne
Assert.Throws(() => InvertedIndexLoader.Load(compressed));
}
}
+
+public sealed class Search_FilterDocumentOption_ExcludesFilteredDocumentsTest : NeedleTestBase
+{
+    // Corpus for this test. The filter below rejects id 2, which the test expects to be "宵の鳥".
+    private static readonly string[] TestDocuments =
+    [
+        "ミーティア",
+        "エンドマークに希望と涙を添えて",
+        "宵の鳥",
+        "僕の和風本当上手",
+    ];
+
+    // A document rejected by FilterDocument must be absent from the results,
+    // even though the same query returns it when no filter is supplied.
+    [Fact]
+    public void Execute()
+    {
+        var index = InvertedIndexLoader.Load(InvertedIndexBuilder.BuildInvertedIndex(TestDocuments, TokenizerOptions));
+        var rejectDocumentTwo = new InvertedIndexSearcherOptions { FilterDocument = documentId => documentId != 2 };
+
+        // Baseline: with no options the query reaches "宵の鳥".
+        var unfilteredTexts = InvertedIndexSearcher.Search(index, "yoi").Select(hit => hit.DocumentText);
+        Assert.Contains("宵の鳥", unfilteredTexts);
+
+        // Same query, now with a filter that returns false for document id 2.
+        var filteredTexts = InvertedIndexSearcher.Search(index, "yoi", rejectDocumentTwo).Select(hit => hit.DocumentText);
+        Assert.DoesNotContain("宵の鳥", filteredTexts);
+    }
+}
+
+public sealed class Search_NextComparerOption_UsesCustomComparerTest : NeedleTestBase
+{
+    // NextComparer must decide the order of results that the built-in comparison cannot separate.
+    [Fact]
+    public void Execute()
+    {
+        // Three documents expected to score alike for the query below.
+        string[] tiedDocuments = ["テストA", "テストB", "テストC"];
+        var index = InvertedIndexLoader.Load(InvertedIndexBuilder.BuildInvertedIndex(tiedDocuments, TokenizerOptions));
+
+        // Tie-breaker that puts the larger document id first.
+        var descendingById = new InvertedIndexSearcherOptions
+        {
+            NextComparer = (documentIdA, documentIdB) => documentIdB - documentIdA,
+        };
+        var orderedIds = InvertedIndexSearcher.Search(index, "テスト", descendingById).Select(hit => hit.DocumentId).ToArray();
+        // With every other criterion tied, ids come back descending: 2, 1, 0.
+        Assert.Equal([2, 1, 0], orderedIds);
+    }
+}
+
diff --git a/packages/needle/src/e2e/search.test.ts b/packages/needle/src/e2e/search.test.ts
index b089d31..bee68d9 100644
--- a/packages/needle/src/e2e/search.test.ts
+++ b/packages/needle/src/e2e/search.test.ts
@@ -100,4 +100,38 @@ describe('search options', () => {
expect(() => loadInvertedIndex(compressed)).toThrow();
});
});
+
+  describe('filterDocument option', () => {
+    it('should exclude filtered documents from results', () => {
+      const invertedIndex = loadInvertedIndex(buildInvertedIndex(testDocuments, { kuromoji }));
+      // Rejects document id 2, which the test expects to be "宵の鳥".
+      const rejectDocumentTwo = (documentId: number) => documentId !== 2;
+
+      // Baseline: with no options the query reaches "宵の鳥".
+      const unfilteredTexts = searchInvertedIndex(invertedIndex, 'yoi').map(hit => hit.documentText);
+      expect(unfilteredTexts).toContain('宵の鳥');
+
+      // Same query, now with the filter applied.
+      const filteredHits = searchInvertedIndex(invertedIndex, 'yoi', { filterDocument: rejectDocumentTwo });
+      const filteredTexts = filteredHits.map(hit => hit.documentText);
+      expect(filteredTexts).not.toContain('宵の鳥');
+    });
+  });
+
+  describe('nextComparer option', () => {
+    it('should use custom comparer for final sorting when other criteria are equal', () => {
+      // Three documents expected to score alike for the query below.
+      const tiedDocuments = ['テストA', 'テストB', 'テストC'];
+      const invertedIndex = loadInvertedIndex(buildInvertedIndex(tiedDocuments, { kuromoji }));
+
+      // Tie-breaker that puts the larger document id first;
+      // it is consulted only when the built-in comparison reports a tie.
+      const descendingById = (documentIdA: number, documentIdB: number) => documentIdB - documentIdA;
+      const hits = searchInvertedIndex(invertedIndex, 'テスト', { nextComparer: descendingById });
+      const orderedIds = hits.map(hit => hit.documentId);
+
+      // With every other criterion tied, ids come back descending: 2, 1, 0.
+      expect(orderedIds).toEqual([2, 1, 0]);
+    });
+  });
});
diff --git a/packages/needle/src/searcher/search.ts b/packages/needle/src/searcher/search.ts
index 9499916..575fec7 100644
--- a/packages/needle/src/searcher/search.ts
+++ b/packages/needle/src/searcher/search.ts
@@ -133,12 +133,24 @@ const compareFinalResult = getComparerForTraits({
getLastToken: state => state.tokens[state.tokens.length - 1]!,
getMatchRatio: state => state.matchRatio,
getMatchRatioLevel: state => Math.round(state.matchRatio * 5),
- nextComparer: (a, b) => a.documentText === b.documentText ? 0 : a.documentText < b.documentText ? -1 : 1,
});
const hasNonEmptyCharacters = (documentCodePoints: string[], start: number, end: number) => start !== end && !documentCodePoints.slice(start, end).every(char => /\s/.test(char));
-export const searchInvertedIndex = (invertedIndex: LoadedInvertedIndex, text: string): SearchResult[] => {
+export const searchInvertedIndex = (
+ invertedIndex: LoadedInvertedIndex,
+ text: string,
+ options?: {
+ /**
+ * Called when all other comparisons are equal.
+ */
+ nextComparer?: (documentIdA: number, documentIdB: number) => number;
+ /**
+ * If this returns a falsy value for a document, that document is excluded from the final results.
+ */
+ filterDocument?: (documentId: number) => unknown;
+ },
+): SearchResult[] => {
const { documents, documentCodePoints, tokenDefinitions, tries } = invertedIndex;
const codePoints = [...toKatakana(normalizeByCodePoint(text))];
@@ -162,6 +174,7 @@ export const searchInvertedIndex = (invertedIndex: LoadedInvertedIndex, text: st
...getTrieNodeTokenIds(otherNode, reachingInputEnd),
]);
for (const tokenId of matchingTokenIds) for (const { documentId, offsets } of tokenDefinitions[tokenId]!.references) {
+ if (options?.filterDocument && !options.filterDocument(documentId)) continue;
const isTokenPrefixMatching = !romajiNode?.tokenIds.includes(tokenId) && !kanaNode?.tokenIds.includes(tokenId) && !otherNode?.tokenIds.includes(tokenId);
const previousMatchesOfDocument = dp[l - 1]?.get(documentId);
if (l !== 0 && !previousMatchesOfDocument) continue;
@@ -231,7 +244,13 @@ export const searchInvertedIndex = (invertedIndex: LoadedInvertedIndex, text: st
matchRatio,
matchRatioLevel,
};
- }).sort(compareFinalResult);
+ }).sort((a, b) => {
+ const compareResult = compareFinalResult(a, b);
+ if (compareResult !== 0) return compareResult;
+ return options?.nextComparer
+ ? options.nextComparer(a.documentId, b.documentId)
+ : a.documentText === b.documentText ? 0 : a.documentText < b.documentText ? -1 : 1;
+ });
};
// For debugging