Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions models/segment-text/coreml/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ Segment Any Text is state-of-the-art sentence segmentation with 3 Transformer lay

If you wish to skip the CoreML conversion, you can download a precompiled `SaT.mlmodelc` from [Hugging Face](https://huggingface.co/smdesai/SaT).

## Swift Usage

Swift sample code for testing and integrating the Core ML model is available at [smdesai/SegmentText](https://github.com/smdesai/SegmentText).


# CoreML Conversion

Expand Down Expand Up @@ -51,7 +55,7 @@ Usage: convert_sat.py [OPTIONS]

Run the following to compile the model.
```bash
python compile_mlmodelc.py --coreml-dir sat_coreml
python compile_mlmodelc.py --coreml-dir sat_coreml --output-dir compiled
```

This produces `SaT.mlmodelc` in the `compiled` directory.
Expand All @@ -61,6 +65,8 @@ Here is the complete usage:
Usage: compile_mlmodelc.py [OPTIONS]

Options
--coreml-dir PATH Directory where mlpackages and metadata are written
[default: sat_coreml]
--coreml-dir PATH Directory where the mlpackage is
[default: sat_coreml]
--output-dir PATH Directory where the compiled model is written
[default: compiled]
```
26 changes: 17 additions & 9 deletions models/segment-text/coreml/compile_mlmodelc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@ def gather_packages(dir: str) -> list[Path]:
return packages


def compile_package(package: Path) -> None:
def compile_package(package: Path, output_dir: Path) -> None:
"""Compile a single ``.mlpackage`` bundle using ``xcrun coremlcompiler``."""
relative_pkg = package.relative_to(BASE_DIR)
#output_dir = OUTPUT_ROOT / relative_pkg.parent
output_dir = OUTPUT_ROOT
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / f"{package.stem}.mlmodelc"
resolved_output_dir = output_dir if output_dir.is_absolute() else BASE_DIR / output_dir
resolved_output_dir.mkdir(parents=True, exist_ok=True)
output_path = resolved_output_dir / f"{package.stem}.mlmodelc"

if output_path.exists():
shutil.rmtree(output_path)
Expand All @@ -57,18 +56,27 @@ def compile_package(package: Path) -> None:
"coremlcompiler",
"compile",
str(package),
str(output_dir),
str(resolved_output_dir),
]

print(f"Compiling {relative_pkg} -> {output_path.relative_to(BASE_DIR)}")
try:
relative_output = output_path.relative_to(BASE_DIR)
except ValueError:
relative_output = output_path

print(f"Compiling {relative_pkg} -> {relative_output}")
subprocess.run(cmd, check=True)


@app.command()
def compile(
coreml_dir: Path = typer.Option(
Path("sat_coreml"),
help="Directory where mlpackages and metadata are written",
help="Directory where the mlpackage is",
),
output_dir: Path = typer.Option(
Path("compiled"),
help="Directory where the compiled model is written",
),
):
ensure_coremlcompiler()
Expand All @@ -80,7 +88,7 @@ def compile(

for package in packages:
try:
compile_package(package)
compile_package(package, output_dir)
except subprocess.CalledProcessError as exc:
print(f"Failed to compile {package}: {exc}", file=sys.stderr)
sys.exit(exc.returncode)
Expand Down
35 changes: 29 additions & 6 deletions models/segment-text/coreml/convert_sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,24 @@ def convert(
return_dict=False,
torchscript=True,
trust_remote_code=True,
).eval()
).eval().to("cpu")

class WrappedModel(torch.nn.Module):
def __init__(self, base_model: torch.nn.Module):
super().__init__()
self.base_model = base_model

def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
input_ids = input_ids.to(dtype=torch.long)
attention_mask = attention_mask.to(dtype=torch.long)
outputs = self.base_model(input_ids, attention_mask)
if isinstance(outputs, (tuple, list)):
logits = outputs[0]
else:
logits = outputs
return logits.to(dtype=torch.float32)

wrapped_model = WrappedModel(model).eval()

tokenizer = AutoTokenizer.from_pretrained("facebookAI/xlm-roberta-base")
tokenized = tokenizer(
Expand All @@ -93,12 +110,17 @@ def convert(
padding="max_length",
)

traced_model = torch.jit.trace(
model,
(tokenized["input_ids"], tokenized["attention_mask"])
example_inputs = (
tokenized["input_ids"].to(torch.int32),
tokenized["attention_mask"].to(torch.int32),
)
traced_model = torch.jit.trace(wrapped_model, example_inputs, strict=False)
traced_model.eval()

with torch.no_grad():
sample_output = wrapped_model(*example_inputs)

outputs = [ct.TensorType(name="output")]
output_spec = ct.TensorType(name="logits", dtype=np.float32)

mlpackage = ct.convert(
traced_model,
Expand All @@ -111,8 +133,9 @@ def convert(
)
for name, tensor in tokenized.items()
],
outputs=outputs,
outputs=[output_spec],
compute_units=ct.ComputeUnit.ALL,
compute_precision=ct.precision.FLOAT32,
minimum_deployment_target=ct.target.iOS18,
)

Expand Down