Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Numerics.Tensors;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.DataIngestion.Chunkers;

/// <summary>
/// Splits an <see cref="IngestionDocument"/> into chunks at the points where consecutive elements are
/// semantically dissimilar, measured as the cosine distance between their embeddings.
/// </summary>
public sealed class SemanticSimilarityChunker : IngestionChunker<string>
{
    /// <summary>Percentile used when the caller does not supply one.</summary>
    private const float DefaultThresholdPercentile = 95.0f;

    private readonly ElementsChunker _elementsChunker;
    private readonly IEmbeddingGenerator<string, Embedding<float>> _embeddingGenerator;
    private readonly float _thresholdPercentile;

    /// <summary>
    /// Initializes a new instance of the <see cref="SemanticSimilarityChunker"/> class.
    /// </summary>
    /// <param name="embeddingGenerator">Embedding generator used to embed the content of each document element.</param>
    /// <param name="options">The options for the chunker.</param>
    /// <param name="thresholdPercentile">
    /// Percentile (0-100) of the distances between consecutive elements above which a chunk boundary is introduced.
    /// The 95th percentile is used if not specified.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="embeddingGenerator"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="thresholdPercentile"/> is not a number, is less than 0, or is greater than 100.
    /// </exception>
    public SemanticSimilarityChunker(
        IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator,
        IngestionChunkerOptions options,
        float? thresholdPercentile = null)
    {
        // Fail at construction time instead of with a NullReferenceException on first use.
        _embeddingGenerator = Throw.IfNull(embeddingGenerator);
        _elementsChunker = new(options);

        // NaN compares false against both bounds, so it has to be rejected explicitly.
        if (thresholdPercentile is float percentile
            && (float.IsNaN(percentile) || percentile < 0f || percentile > 100f))
        {
            Throw.ArgumentOutOfRangeException(nameof(thresholdPercentile), "Threshold percentile must be between 0 and 100.");
        }

        _thresholdPercentile = thresholdPercentile ?? DefaultThresholdPercentile;
    }

    /// <inheritdoc/>
    public override async IAsyncEnumerable<IngestionChunk<string>> ProcessAsync(IngestionDocument document,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(document);

        List<(IngestionDocumentElement, float)> distances = await CalculateDistancesAsync(document, cancellationToken).ConfigureAwait(false);
        foreach (var chunk in MakeChunks(document, distances))
        {
            yield return chunk;
        }
    }

    /// <summary>
    /// Pairs every element that has non-empty semantic content with the cosine distance to the next such element.
    /// </summary>
    /// <param name="document">The document whose content elements are embedded.</param>
    /// <param name="cancellationToken">Token forwarded to the embedding generator.</param>
    /// <returns>
    /// The retained elements in enumeration order. The last entry has no successor, so its distance stays at
    /// the default value of 0.
    /// </returns>
    private async Task<List<(IngestionDocumentElement element, float distance)>> CalculateDistancesAsync(IngestionDocument document, CancellationToken cancellationToken)
    {
        List<(IngestionDocumentElement element, float distance)> elementDistances = [];
        List<string> semanticContents = [];

        foreach (IngestionDocumentElement element in document.EnumerateContent())
        {
            // Images are represented by their alternative text, falling back to their text.
            string? semanticContent = element is IngestionDocumentImage img
                ? img.AlternativeText ?? img.Text
                : element.GetMarkdown();

            // Elements without any text cannot be embedded, so they are left out entirely.
            if (!string.IsNullOrEmpty(semanticContent))
            {
                elementDistances.Add((element, default));
                semanticContents.Add(semanticContent!);
            }
        }

        if (elementDistances.Count > 0)
        {
            var embeddings = await _embeddingGenerator.GenerateAsync(semanticContents, cancellationToken: cancellationToken).ConfigureAwait(false);

            if (embeddings.Count != elementDistances.Count)
            {
                Throw.InvalidOperationException("The number of embeddings returned does not match the number of document elements.");
            }

            // Distance i describes the transition from element i to element i + 1.
            for (int i = 0; i < elementDistances.Count - 1; i++)
            {
                float distance = 1 - TensorPrimitives.CosineSimilarity(embeddings[i].Vector.Span, embeddings[i + 1].Vector.Span);
                elementDistances[i] = (elementDistances[i].element, distance);
            }
        }

        return elementDistances;
    }

    /// <summary>
    /// Groups consecutive elements and hands each group to the elements chunker. A group is closed after an
    /// element whose distance to its successor exceeds the percentile threshold, and after the final element.
    /// </summary>
    /// <param name="document">The document the elements belong to.</param>
    /// <param name="elementDistances">Elements paired with the distance to their successor.</param>
    /// <returns>The chunks produced for every group, in document order.</returns>
    private IEnumerable<IngestionChunk<string>> MakeChunks(IngestionDocument document, List<(IngestionDocumentElement element, float distance)> elementDistances)
    {
        float distanceThreshold = Percentile(elementDistances);

        List<IngestionDocumentElement> elementAccumulator = [];
        string context = string.Empty;
        for (int i = 0; i < elementDistances.Count; i++)
        {
            var (element, distance) = elementDistances[i];

            elementAccumulator.Add(element);

            // Always flush on the last element so that trailing content is never dropped.
            if (distance > distanceThreshold || i == elementDistances.Count - 1)
            {
                foreach (var chunk in _elementsChunker.Process(document, context, elementAccumulator))
                {
                    yield return chunk;
                }

                elementAccumulator.Clear();
            }
        }
    }

    /// <summary>
    /// Computes the configured percentile of the distances using linear interpolation between the two closest ranks.
    /// </summary>
    /// <param name="elementDistances">Elements paired with the distance to their successor.</param>
    /// <returns>0 for an empty list, the only distance for a single entry, otherwise the interpolated percentile.</returns>
    /// <remarks>
    /// The final entry's distance is the default 0 and is included in the sorted values.
    /// NOTE(review): confirm this is intended, since the extra value can lower the resulting threshold.
    /// </remarks>
    private float Percentile(List<(IngestionDocumentElement element, float distance)> elementDistances)
    {
        if (elementDistances.Count == 0)
        {
            return 0f;
        }
        else if (elementDistances.Count == 1)
        {
            return elementDistances[0].distance;
        }

        float[] sorted = new float[elementDistances.Count];
        for (int elementIndex = 0; elementIndex < elementDistances.Count; elementIndex++)
        {
            sorted[elementIndex] = elementDistances[elementIndex].distance;
        }

        Array.Sort(sorted);

        // Fractional rank within the sorted values; its integer part and the next index bracket the result.
        float rank = (_thresholdPercentile / 100f) * (sorted.Length - 1);
        int lower = (int)rank;
        int upper = Math.Min(lower + 1, sorted.Length - 1);
        return sorted[lower] + ((rank - lower) * (sorted[upper] - sorted[lower]));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
<!-- we are not ready to publish yet -->
<IsPackable>false</IsPackable>
<Stage>preview</Stage>
<EnablePackageValidation>false</EnablePackageValidation>
<EnablePackageValidation>false</EnablePackageValidation>
</PropertyGroup>

<ItemGroup>
Expand All @@ -21,6 +21,7 @@
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
<PackageReference Include="Microsoft.Extensions.VectorData.Abstractions" />
<PackageReference Include="Microsoft.ML.Tokenizers" />
<PackageReference Include="System.Numerics.Tensors" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFrameworkIdentifier)' != '.NETCoreApp'">
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Xunit;

namespace Microsoft.Extensions.DataIngestion.Chunkers.Tests
{
    /// <summary>
    /// Base class holding tests that run against whichever chunker the derived class creates.
    /// </summary>
    public abstract class DocumentChunkerTests
    {
        /// <summary>Creates the chunker instance exercised by the tests in this class.</summary>
        /// <param name="maxTokensPerChunk">Value passed to the derived class's factory; defaults to 2,000.</param>
        /// <param name="overlapTokens">Value passed to the derived class's factory; defaults to 500.</param>
        /// <returns>The chunker under test.</returns>
        protected abstract IngestionChunker<string> CreateDocumentChunker(int maxTokensPerChunk = 2_000, int overlapTokens = 500);

        /// <summary>Enumerating the result for a null document must throw for the <c>document</c> argument.</summary>
        [Fact]
        public async Task ProcessAsync_ThrowsArgumentNullException_WhenDocumentIsNull()
        {
            IngestionChunker<string> sut = CreateDocumentChunker();

            // The sequence is enumerated inside the delegate so the exception surfaces within the assertion.
            async Task EnumerateNullDocumentAsync()
            {
                _ = await sut.ProcessAsync(null!).ToListAsync();
            }

            await Assert.ThrowsAsync<ArgumentNullException>("document", EnumerateNullDocumentAsync);
        }

        /// <summary>A document without any content must produce no chunks.</summary>
        [Fact]
        public async Task EmptyDocument()
        {
            var document = new IngestionDocument("emptyDoc");
            var sut = CreateDocumentChunker();

            var produced = await sut.ProcessAsync(document).ToListAsync();

            Assert.Empty(produced);
        }
    }
}
Loading
Loading