feat: pre-filter and custom comparer for documents
This commit is contained in:
@@ -24,6 +24,21 @@ public class SearchResult
|
|||||||
public required int MatchRatioLevel { get; set; }
|
public required int MatchRatioLevel { get; set; }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
/// Optional hooks accepted by <c>InvertedIndexSearcher.Search</c>: a per-document
/// pre-filter and a tie-breaking comparer. Both properties may be left null.
/// </summary>
public class InvertedIndexSearcherOptions
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Tie-breaker called only when all other comparisons between two results are equal.
/// It receives the two document IDs and returns a negative, zero, or positive value.
/// When null, tied results are ordered by document text instead.
|
||||||
|
/// <code>(int documentIdA, int documentIdB) => int</code>
|
||||||
|
/// </summary>
|
||||||
|
public Func<int, int, int>? NextComparer { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// If this returns false for a document, that document is excluded from the final results.
/// When null, no document is filtered out.
|
||||||
|
/// <code>(int documentId) => bool</code>
|
||||||
|
/// </summary>
|
||||||
|
public Func<int, bool>? FilterDocument { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
public static class InvertedIndexSearcher
|
public static class InvertedIndexSearcher
|
||||||
{
|
{
|
||||||
public abstract class ComparableStateBase<T> : IComparable<T>
|
public abstract class ComparableStateBase<T> : IComparable<T>
|
||||||
@@ -36,7 +51,6 @@ public static class InvertedIndexSearcher
|
|||||||
protected virtual SearchResultToken? GetLastToken() => null; // Not on intermediate results
|
protected virtual SearchResultToken? GetLastToken() => null; // Not on intermediate results
|
||||||
protected virtual int? GetMatchRatioLevel() => null; // Not on intermediate/candidate results
|
protected virtual int? GetMatchRatioLevel() => null; // Not on intermediate/candidate results
|
||||||
protected abstract double GetMatchRatio();
|
protected abstract double GetMatchRatio();
|
||||||
protected virtual int FallbackCompareTo(T other) => 0; // Called when all other comparisons are equal
|
|
||||||
|
|
||||||
public int CompareTo(T other)
|
public int CompareTo(T other)
|
||||||
{
|
{
|
||||||
@@ -73,7 +87,7 @@ public static class InvertedIndexSearcher
|
|||||||
double aMatchRatio = GetMatchRatio(), bMatchRatio = other.GetMatchRatio();
|
double aMatchRatio = GetMatchRatio(), bMatchRatio = other.GetMatchRatio();
|
||||||
if (aMatchRatio != bMatchRatio) return bMatchRatio < aMatchRatio ? -1 : bMatchRatio > aMatchRatio ? 1 : 0;
|
if (aMatchRatio != bMatchRatio) return bMatchRatio < aMatchRatio ? -1 : bMatchRatio > aMatchRatio ? 1 : 0;
|
||||||
|
|
||||||
return FallbackCompareTo(other);
|
return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,7 +137,7 @@ public static class InvertedIndexSearcher
|
|||||||
protected override SearchResultToken? GetLastToken() => Result.Tokens[^1];
|
protected override SearchResultToken? GetLastToken() => Result.Tokens[^1];
|
||||||
protected override double GetMatchRatio() => Result.MatchRatio;
|
protected override double GetMatchRatio() => Result.MatchRatio;
|
||||||
protected override int? GetMatchRatioLevel() => Result.MatchRatioLevel;
|
protected override int? GetMatchRatioLevel() => Result.MatchRatioLevel;
|
||||||
protected override int FallbackCompareTo(FinalResult other) => string.Compare(Result.DocumentText, other.Result.DocumentText, StringComparison.InvariantCulture);
|
// protected override int FallbackCompareTo(FinalResult other) => string.Compare(Result.DocumentText, other.Result.DocumentText, StringComparison.InvariantCulture);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static bool IsIgnorableCodePoint(int codePoint) => CommonUtils.IsWhitespace(codePoint) || codePoint == 0x3099 || codePoint == 0x309A;
|
private static bool IsIgnorableCodePoint(int codePoint) => CommonUtils.IsWhitespace(codePoint) || codePoint == 0x3099 || codePoint == 0x309A;
|
||||||
@@ -151,7 +165,7 @@ public static class InvertedIndexSearcher
|
|||||||
private static bool HasNonEmptyCharacters(int[] documentCodePoints, int start, int end) =>
|
private static bool HasNonEmptyCharacters(int[] documentCodePoints, int start, int end) =>
|
||||||
start != end && !documentCodePoints.Skip(start).Take(end - start).All(CommonUtils.IsWhitespace);
|
start != end && !documentCodePoints.Skip(start).Take(end - start).All(CommonUtils.IsWhitespace);
|
||||||
|
|
||||||
public static SearchResult[] Search(LoadedInvertedIndex invertedIndex, string text)
|
public static SearchResult[] Search(LoadedInvertedIndex invertedIndex, string text, InvertedIndexSearcherOptions? options = null)
|
||||||
{
|
{
|
||||||
var documents = invertedIndex.Documents;
|
var documents = invertedIndex.Documents;
|
||||||
var documentCodePoints = invertedIndex.DocumentCodePoints;
|
var documentCodePoints = invertedIndex.DocumentCodePoints;
|
||||||
@@ -183,6 +197,7 @@ public static class InvertedIndexSearcher
|
|||||||
];
|
];
|
||||||
foreach (var tokenId in matchingTokenIds) foreach (var reference in tokenDefinitions[tokenId].References)
|
foreach (var tokenId in matchingTokenIds) foreach (var reference in tokenDefinitions[tokenId].References)
|
||||||
{
|
{
|
||||||
|
if (options?.FilterDocument != null && !options.FilterDocument(reference.DocumentId)) continue;
|
||||||
var isTokenPrefixMatching = !romajiNode.IsTokenExactMatch(tokenId) && !kanaNode.IsTokenExactMatch(tokenId) && !otherNode.IsTokenExactMatch(tokenId);
|
var isTokenPrefixMatching = !romajiNode.IsTokenExactMatch(tokenId) && !kanaNode.IsTokenExactMatch(tokenId) && !otherNode.IsTokenExactMatch(tokenId);
|
||||||
var previousMatchesOfDocument = l != 0 && dp[l - 1].TryGetValue(reference.DocumentId, out var previousMatches) ? previousMatches : null;
|
var previousMatchesOfDocument = l != 0 && dp[l - 1].TryGetValue(reference.DocumentId, out var previousMatches) ? previousMatches : null;
|
||||||
if (l != 0 && previousMatchesOfDocument == null) continue;
|
if (l != 0 && previousMatchesOfDocument == null) continue;
|
||||||
@@ -265,6 +280,13 @@ public static class InvertedIndexSearcher
|
|||||||
MatchRatioLevel = matchRatioLevel,
|
MatchRatioLevel = matchRatioLevel,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}).OrderBy(result => result).Select(result => result.Result).ToArray();
|
}).OrderBy(result => result, Comparer<FinalResult>.Create((a, b) =>
|
||||||
|
{
|
||||||
|
var compareResult = a.CompareTo(b);
|
||||||
|
if (compareResult != 0) return compareResult;
|
||||||
|
return options?.NextComparer == null
|
||||||
|
? string.Compare(a.Result.DocumentText, b.Result.DocumentText, StringComparison.InvariantCulture)
|
||||||
|
: options.NextComparer(a.Result.DocumentId, b.Result.DocumentId);
|
||||||
|
})).Select(result => result.Result).ToArray();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -139,3 +139,54 @@ public sealed class Search_BundleDocumentsOption_ThrowsWhenNoneProvidedTest : Ne
|
|||||||
Assert.Throws<ArgumentException>(() => InvertedIndexLoader.Load(compressed));
|
Assert.Throws<ArgumentException>(() => InvertedIndexLoader.Load(compressed));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
/// Verifies that a document rejected by <c>FilterDocument</c> does not appear in the
/// search results, while the same query without options still returns it.
/// </summary>
public sealed class Search_FilterDocumentOption_ExcludesFilteredDocumentsTest : NeedleTestBase
|
||||||
|
{
|
||||||
|
// Fixture documents; the test below treats "宵の鳥" as documentId 2
// (presumably IDs are the array indices -- confirm against the index builder).
private static readonly string[] TestDocuments =
|
||||||
|
[
|
||||||
|
"ミーティア",
|
||||||
|
"エンドマークに希望と涙を添えて",
|
||||||
|
"宵の鳥",
|
||||||
|
"僕の和風本当上手",
|
||||||
|
];
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Execute()
|
||||||
|
{
|
||||||
|
var compressed = InvertedIndexBuilder.BuildInvertedIndex(TestDocuments, TokenizerOptions);
|
||||||
|
var invertedIndex = InvertedIndexLoader.Load(compressed);
|
||||||
|
|
||||||
|
// Baseline: search without a filter - should find "宵の鳥" (documentId 2)
|
||||||
|
var resultsWithoutFilter = InvertedIndexSearcher.Search(invertedIndex, "yoi");
|
||||||
|
Assert.Contains("宵の鳥", resultsWithoutFilter.Select(r => r.DocumentText));
|
||||||
|
|
||||||
|
// Same query, but the filter rejects documentId 2, so it must disappear from the results
|
||||||
|
var resultsWithFilter = InvertedIndexSearcher.Search(invertedIndex, "yoi", new InvertedIndexSearcherOptions
|
||||||
|
{
|
||||||
|
FilterDocument = id => id != 2
|
||||||
|
});
|
||||||
|
Assert.DoesNotContain("宵の鳥", resultsWithFilter.Select(r => r.DocumentText));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
/// Verifies that <c>NextComparer</c> decides the order of results that tie on every
/// other ranking criterion.
/// </summary>
public sealed class Search_NextComparerOption_UsesCustomComparerTest : NeedleTestBase
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public void Execute()
|
||||||
|
{
|
||||||
|
// Create documents that would have similar match scores
// NOTE(review): the exact-order assertion below assumes all three documents tie on
// every earlier criterion for this query -- confirm if the ranking logic changes.
|
||||||
|
var similarDocs = new[] { "テストA", "テストB", "テストC" };
|
||||||
|
var compressed = InvertedIndexBuilder.BuildInvertedIndex(similarDocs, TokenizerOptions);
|
||||||
|
var invertedIndex = InvertedIndexLoader.Load(compressed);
|
||||||
|
|
||||||
|
// Search with a comparer that reverses the order of document IDs
|
||||||
|
var results = InvertedIndexSearcher.Search(invertedIndex, "テスト", new InvertedIndexSearcherOptions
|
||||||
|
{
|
||||||
|
NextComparer = (a, b) => b - a // Reverse by documentId
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should be in reverse documentId order (2, 1, 0) when other criteria are equal
|
||||||
|
Assert.Equal([2, 1, 0], results.Select(r => r.DocumentId).ToArray());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -100,4 +100,38 @@ describe('search options', () => {
|
|||||||
expect(() => loadInvertedIndex(compressed)).toThrow();
|
expect(() => loadInvertedIndex(compressed)).toThrow();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Checks that a document rejected by the `filterDocument` callback is dropped from the
// results, while the same query without options still returns it.
describe('filterDocument option', () => {
|
||||||
|
it('should exclude filtered documents from results', () => {
|
||||||
|
const compressed = buildInvertedIndex(testDocuments, { kuromoji });
|
||||||
|
const invertedIndex = loadInvertedIndex(compressed);
|
||||||
|
|
||||||
|
// Baseline: search without a filter - should find "宵の鳥" (documentId 2)
|
||||||
|
const resultsWithoutFilter = searchInvertedIndex(invertedIndex, 'yoi');
|
||||||
|
expect(resultsWithoutFilter.map(r => r.documentText)).toContain('宵の鳥');
|
||||||
|
|
||||||
|
// Same query, but the filter rejects documentId 2, so it must disappear from the results
|
||||||
|
const resultsWithFilter = searchInvertedIndex(invertedIndex, 'yoi', {
|
||||||
|
filterDocument: id => id !== 2,
|
||||||
|
});
|
||||||
|
expect(resultsWithFilter.map(r => r.documentText)).not.toContain('宵の鳥');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// Checks that the `nextComparer` callback decides the order of results that tie on
// every other ranking criterion.
describe('nextComparer option', () => {
|
||||||
|
it('should use custom comparer for final sorting when other criteria are equal', () => {
|
||||||
|
// Create documents that would have similar match scores
// NOTE(review): the exact-order assertion below assumes all three documents tie on
// every earlier criterion for this query -- confirm if the ranking logic changes.
|
||||||
|
const similarDocs = ['テストA', 'テストB', 'テストC'];
|
||||||
|
const compressed = buildInvertedIndex(similarDocs, { kuromoji });
|
||||||
|
const invertedIndex = loadInvertedIndex(compressed);
|
||||||
|
|
||||||
|
// Search with a comparer that reverses the order of document IDs
|
||||||
|
const results = searchInvertedIndex(invertedIndex, 'テスト', {
|
||||||
|
nextComparer: (a, b) => b - a, // Reverse by documentId
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should be in reverse documentId order (2, 1, 0) when other criteria are equal
|
||||||
|
expect(results.map(r => r.documentId)).toEqual([2, 1, 0]);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -133,12 +133,24 @@ const compareFinalResult = getComparerForTraits<SearchResult>({
|
|||||||
getLastToken: state => state.tokens[state.tokens.length - 1]!,
|
getLastToken: state => state.tokens[state.tokens.length - 1]!,
|
||||||
getMatchRatio: state => state.matchRatio,
|
getMatchRatio: state => state.matchRatio,
|
||||||
getMatchRatioLevel: state => Math.round(state.matchRatio * 5),
|
getMatchRatioLevel: state => Math.round(state.matchRatio * 5),
|
||||||
nextComparer: (a, b) => a.documentText === b.documentText ? 0 : a.documentText < b.documentText ? -1 : 1,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const hasNonEmptyCharacters = (documentCodePoints: string[], start: number, end: number) => start !== end && !documentCodePoints.slice(start, end).every(char => /\s/.test(char));
|
const hasNonEmptyCharacters = (documentCodePoints: string[], start: number, end: number) => start !== end && !documentCodePoints.slice(start, end).every(char => /\s/.test(char));
|
||||||
|
|
||||||
export const searchInvertedIndex = (invertedIndex: LoadedInvertedIndex, text: string): SearchResult[] => {
|
export const searchInvertedIndex = (
|
||||||
|
invertedIndex: LoadedInvertedIndex,
|
||||||
|
text: string,
|
||||||
|
options?: {
|
||||||
|
/**
|
||||||
|
* Called when all other comparisons are equal.
|
||||||
|
*/
|
||||||
|
nextComparer?: (documentIdA: number, documentIdB: number) => number;
|
||||||
|
/**
|
||||||
|
* If this returns a falsy value for a document, that document is excluded from the final results.
|
||||||
|
*/
|
||||||
|
filterDocument?: (documentId: number) => unknown;
|
||||||
|
},
|
||||||
|
): SearchResult[] => {
|
||||||
const { documents, documentCodePoints, tokenDefinitions, tries } = invertedIndex;
|
const { documents, documentCodePoints, tokenDefinitions, tries } = invertedIndex;
|
||||||
|
|
||||||
const codePoints = [...toKatakana(normalizeByCodePoint(text))];
|
const codePoints = [...toKatakana(normalizeByCodePoint(text))];
|
||||||
@@ -162,6 +174,7 @@ export const searchInvertedIndex = (invertedIndex: LoadedInvertedIndex, text: st
|
|||||||
...getTrieNodeTokenIds(otherNode, reachingInputEnd),
|
...getTrieNodeTokenIds(otherNode, reachingInputEnd),
|
||||||
]);
|
]);
|
||||||
for (const tokenId of matchingTokenIds) for (const { documentId, offsets } of tokenDefinitions[tokenId]!.references) {
|
for (const tokenId of matchingTokenIds) for (const { documentId, offsets } of tokenDefinitions[tokenId]!.references) {
|
||||||
|
if (options?.filterDocument && !options.filterDocument(documentId)) continue;
|
||||||
const isTokenPrefixMatching = !romajiNode?.tokenIds.includes(tokenId) && !kanaNode?.tokenIds.includes(tokenId) && !otherNode?.tokenIds.includes(tokenId);
|
const isTokenPrefixMatching = !romajiNode?.tokenIds.includes(tokenId) && !kanaNode?.tokenIds.includes(tokenId) && !otherNode?.tokenIds.includes(tokenId);
|
||||||
const previousMatchesOfDocument = dp[l - 1]?.get(documentId);
|
const previousMatchesOfDocument = dp[l - 1]?.get(documentId);
|
||||||
if (l !== 0 && !previousMatchesOfDocument) continue;
|
if (l !== 0 && !previousMatchesOfDocument) continue;
|
||||||
@@ -231,7 +244,13 @@ export const searchInvertedIndex = (invertedIndex: LoadedInvertedIndex, text: st
|
|||||||
matchRatio,
|
matchRatio,
|
||||||
matchRatioLevel,
|
matchRatioLevel,
|
||||||
};
|
};
|
||||||
}).sort(compareFinalResult);
|
}).sort((a, b) => {
|
||||||
|
const compareResult = compareFinalResult(a, b);
|
||||||
|
if (compareResult !== 0) return compareResult;
|
||||||
|
return options?.nextComparer
|
||||||
|
? options.nextComparer(a.documentId, b.documentId)
|
||||||
|
: a.documentText === b.documentText ? 0 : a.documentText < b.documentText ? -1 : 1;
|
||||||
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
// For debugging
|
// For debugging
|
||||||
|
|||||||
Reference in New Issue
Block a user