diff --git a/models/segment-text/coreml/README.md b/models/segment-text/coreml/README.md index a2a1ecc..d4ba2f8 100644 --- a/models/segment-text/coreml/README.md +++ b/models/segment-text/coreml/README.md @@ -3,6 +3,10 @@ Segment Any Text is state-of-the-art sentence segmentation with 3 Transfomer lay If you wish to skip the CoreML conversion, you can download a precompiled `SaT.mlmodelc` from [Hugging Face](https://huggingface.co/smdesai/SaT). +## Swift Usage + +Swift sample code for testing and integrating the Core ML model is available at [smdesai/SegmentText](https://github.com/smdesai/SegmentText). + # CoreML Conversion @@ -51,7 +55,7 @@ Usage: convert_sat.py [OPTIONS] Run the following to compile the model. ```bash -python compile_mlmodelc.py --coreml-dir sat_coreml +python compile_mlmodelc.py --coreml-dir sat_coreml --output-dir compiled ``` This produces `SaT.mlmodelc` in the `compiled` directory. @@ -61,6 +65,8 @@ Here is the complete usage: Usage: compile_mlmodelc.py [OPTIONS] Options - --coreml-dir PATH Directory where mlpackages and metadata are written - [default: sat_coreml] + --coreml-dir PATH Directory where the mlpackage is + [default: sat_coreml] + --output-dir PATH Directory where the compiled model is written + [default: compiled] ``` diff --git a/models/segment-text/coreml/compile_mlmodelc.py b/models/segment-text/coreml/compile_mlmodelc.py index 884c654..8edea79 100644 --- a/models/segment-text/coreml/compile_mlmodelc.py +++ b/models/segment-text/coreml/compile_mlmodelc.py @@ -41,13 +41,12 @@ def gather_packages(dir: str) -> list[Path]: return packages -def compile_package(package: Path) -> None: +def compile_package(package: Path, output_dir: Path) -> None: """Compile a single ``.mlpackage`` bundle using ``xcrun coremlcompiler``.""" relative_pkg = package.relative_to(BASE_DIR) - #output_dir = OUTPUT_ROOT / relative_pkg.parent - output_dir = OUTPUT_ROOT - output_dir.mkdir(parents=True, exist_ok=True) - output_path = output_dir / f"{package.stem}.mlmodelc" + resolved_output_dir = output_dir if output_dir.is_absolute() else BASE_DIR / output_dir + resolved_output_dir.mkdir(parents=True, exist_ok=True) + output_path = resolved_output_dir / f"{package.stem}.mlmodelc" if output_path.exists(): shutil.rmtree(output_path) @@ -57,10 +56,15 @@ def compile_package(package: Path) -> None: "coremlcompiler", "compile", str(package), - str(output_dir), + str(resolved_output_dir), ] - print(f"Compiling {relative_pkg} -> {output_path.relative_to(BASE_DIR)}") + try: + relative_output = output_path.relative_to(BASE_DIR) + except ValueError: + relative_output = output_path + + print(f"Compiling {relative_pkg} -> {relative_output}") subprocess.run(cmd, check=True) @@ -68,7 +72,11 @@ def compile_package(package: Path) -> None: def compile( coreml_dir: Path = typer.Option( Path("sat_coreml"), - help="Directory where mlpackages and metadata are written", + help="Directory where the mlpackage is", + ), + output_dir: Path = typer.Option( + Path("compiled"), + help="Directory where the compiled model is written", ), ): ensure_coremlcompiler() @@ -80,7 +88,7 @@ def compile( for package in packages: try: - compile_package(package) + compile_package(package, output_dir) except subprocess.CalledProcessError as exc: print(f"Failed to compile {package}: {exc}", file=sys.stderr) sys.exit(exc.returncode) diff --git a/models/segment-text/coreml/convert_sat.py b/models/segment-text/coreml/convert_sat.py index 59ad399..62dca0a 100644 --- a/models/segment-text/coreml/convert_sat.py +++ b/models/segment-text/coreml/convert_sat.py @@ -83,7 +83,24 @@ def convert( return_dict=False, torchscript=True, trust_remote_code=True, - ).eval() + ).eval().to("cpu") + + class WrappedModel(torch.nn.Module): + def __init__(self, base_model: torch.nn.Module): + super().__init__() + self.base_model = base_model + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + input_ids = input_ids.to(dtype=torch.long) + attention_mask = attention_mask.to(dtype=torch.long) + outputs = self.base_model(input_ids, attention_mask) + if isinstance(outputs, (tuple, list)): + logits = outputs[0] + else: + logits = outputs + return logits.to(dtype=torch.float32) + + wrapped_model = WrappedModel(model).eval() tokenizer = AutoTokenizer.from_pretrained("facebookAI/xlm-roberta-base") tokenized = tokenizer( @@ -93,12 +110,17 @@ def convert( padding="max_length", ) - traced_model = torch.jit.trace( - model, - (tokenized["input_ids"], tokenized["attention_mask"]) + example_inputs = ( + tokenized["input_ids"].to(torch.int32), + tokenized["attention_mask"].to(torch.int32), ) + traced_model = torch.jit.trace(wrapped_model, example_inputs, strict=False) + traced_model.eval() + + with torch.no_grad(): + sample_output = wrapped_model(*example_inputs) - outputs = [ct.TensorType(name="output")] + output_spec = ct.TensorType(name="logits", dtype=np.float32) mlpackage = ct.convert( traced_model, @@ -111,8 +133,9 @@ def convert( ) for name, tensor in tokenized.items() ], - outputs=outputs, + outputs=[output_spec], compute_units=ct.ComputeUnit.ALL, + compute_precision=ct.precision.FLOAT32, minimum_deployment_target=ct.target.iOS18, )