Skip to content
Open
91 changes: 91 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,97 @@ Downloads a specified encrypted file, decrypts it and then behaves identically t
The request body for this route is the same as for
`POST /_matrix/media_proxy/unstable/download_encrypted`.

### `POST /_matrix/media_proxy/unstable/scan_file`

Scans a file directly without downloading it from a Matrix homeserver. The file
content is sent in the request body as a `multipart/form-data` upload.

#### Request

The request must use `Content-Type: multipart/form-data` with the following parts:

| Part name | Required | Type | Description |
|-----------|----------|------|-------------|
| `body` | **Yes** | Binary (file content) | The raw file to scan. |
| `file` | No | JSON string | Decryption metadata for an encrypted file. Follows the [`EncryptedFile`](https://spec.matrix.org/v1.2/client-server-api/#extensions-to-mroommessage-msgtypes) structure from the Matrix specification. Only needed when the file in `body` is encrypted. |

#### Request examples

Scan an unencrypted file with `curl`:

```bash
curl -X POST \
http://localhost:8080/_matrix/media_proxy/unstable/scan_file \
-F "body=@document.pdf;type=application/pdf"
```

Scan an encrypted file (provide decryption metadata via the `file` part):

```bash
curl -X POST \
http://localhost:8080/_matrix/media_proxy/unstable/scan_file \
-F "body=@encrypted_file.bin;type=application/octet-stream" \
-F "file={\"v\":\"v2\",\"key\":{...},\"iv\":\"...\",\"hashes\":{...}};type=application/json"
```

Scan a file with Python (`requests`):

```python
import requests

resp = requests.post(
"http://localhost:8080/_matrix/media_proxy/unstable/scan_file",
files={"body": ("image.png", open("image.png", "rb"), "image/png")},
)
print(resp.json()) # {"clean": true, "info": "File is clean"}
```

Scan an encrypted file with Python (`requests`), providing decryption metadata via the `file` part:

```python
import json
import requests

encrypted_file_metadata = {
"v": "v2",
"key": {
"alg": "A256CTR",
"ext": True,
"k": "base64-encoded-key",
"key_ops": ["encrypt", "decrypt"],
"kty": "oct",
},
"iv": "base64-encoded-iv",
"hashes": {
"sha256": "base64-encoded-hash",
},
}

resp = requests.post(
"http://localhost:8080/_matrix/media_proxy/unstable/scan_file",
files={
"body": ("encrypted.bin", open("encrypted.bin", "rb"), "application/octet-stream"),
"file": ("metadata.json", json.dumps(encrypted_file_metadata), "application/json"),
},
)
print(resp.json()) # {"clean": true, "info": "File is clean"}
```

#### Response

| Parameter | Type | Description |
|-----------|------|-------------|
| `clean` | bool | `true` if the file passed the scan, `false` otherwise. |
| `info` | str | Human-readable result description. |

Example response:

```json
{
"clean": false,
"info": "***VIRUS DETECTED***"
}
```

### `GET /_matrix/media_proxy/unstable/public_key`

Expand Down
52 changes: 51 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ humanfriendly = ">=10.0"
# Required for calculating cache keys deterministically. Type annotations aren't
# discoverable in versions older than 1.6.3.
canonicaljson = ">=1.6.3"
# Required for non-blocking file I/O.
aiofile = ">=3.8.0"
setuptools_rust = ">=1.3"

[tool.poetry.dev-dependencies]
Expand Down
1 change: 1 addition & 0 deletions src/matrix_content_scanner/httpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def _build_app(self) -> web.Application:
[
web.get("/scan" + _MEDIA_PATH_REGEXP, scan_handler.handle_plain),
web.post("/scan_encrypted", scan_handler.handle_encrypted),
web.post("/scan_file", scan_handler.handle_file),
web.get(
"/download" + _MEDIA_PATH_REGEXP, download_handler.handle_plain
),
Expand Down
118 changes: 88 additions & 30 deletions src/matrix_content_scanner/scanner/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import logging
import os
import subprocess
import uuid
from asyncio import Future
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import attr
import magic
from aiofile import async_open
from cachetools import TTLCache
from canonicaljson import encode_canonical_json
from humanfriendly import format_size
Expand Down Expand Up @@ -320,6 +322,76 @@ async def _scan_file(

return media

async def scan_content(
    self, content: bytes, metadata: Optional[JsonDict] = None
) -> None:
    """Scan a file that was supplied directly as bytes.

    The bytes are handed to `_do_scan`, which returns the exit code of the scan.
    Neither the result cache nor concurrent-request deduplication is used on this
    path.

    Args:
        content: The raw file bytes (possibly still encrypted).
        metadata: The metadata attached to the file (e.g. decryption key), or None
            if the file isn't encrypted.

    Raises:
        FileDirtyError if the scan finished with a non-zero exit code.
    """
    scan_exit_code = await self._do_scan(content, metadata)

    # Exit codes listed in `_exit_codes_to_ignore` mark the failure as not
    # cacheable; the flag is carried on the raised error.
    is_cacheable = scan_exit_code not in self._exit_codes_to_ignore

    # A zero exit code means the file is clean and there is nothing to report.
    if scan_exit_code != 0:
        raise FileDirtyError(cacheable=is_cacheable)

async def _do_scan(
    self,
    content: bytes,
    metadata: Optional[JsonDict] = None,
    file_id: Optional[str] = None,
) -> int:
    """Core scan pipeline shared by all request paths.

    Handles: decrypt (if needed) → write to disk → mimetype check → scan → cleanup.

    Args:
        content: The raw file bytes (encrypted or plaintext).
        metadata: Decryption metadata, or None if the file is unencrypted.
        file_id: Identifier used as the temp filename on disk. If None (or an
            empty string), a random UUID is generated. Passing the media_path
            (server_name/media_id) preserves the original directory structure
            for traceability.

    Returns:
        The exit code from the scan script (0 = clean).

    Raises:
        Whatever `_decrypt_file`, `_write_file_to_disk`, `_check_mimetype` or
        `_run_scan` raise. Once the file has been written, it is removed from
        disk before any such error propagates.
    """
    # If the file is encrypted, we need to decrypt it before we can scan it.
    if metadata is not None:
        content = self._decrypt_file(content, metadata)

    # Write the file to disk. This happens outside the `try` below: if the write
    # itself fails there is no path to hand to the removal command.
    file_path = await self._write_file_to_disk(
        file_id or str(uuid.uuid4()), content
    )

    try:
        # Check the on-disk file's MIME type to see if it's allowed.
        self._check_mimetype(file_path)
        # Scan the file and see if the result is positive or negative.
        exit_code = await self._run_scan(file_path)
        logger.info("Scan has finished")
    finally:
        # Always remove the file, whether the checks above passed or raised.
        await self._remove_scanned_file(file_path)

    return exit_code

async def _remove_scanned_file(self, file_path: str) -> None:
    """Delete a scanned file from disk using the configured removal command.

    The command is run in the event loop's default executor so that a slow
    removal does not block other requests while it runs.

    Args:
        file_path: The full path of the file to remove.
    """
    # Imported locally so this helper does not depend on a module-level import.
    import asyncio

    logger.info("Removing file")
    removal_command_parts = self._removal_command.split()
    removal_command_parts.append(file_path)

    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, subprocess.run, removal_command_parts)

async def _scan_media(
self,
media: MediaDescription,
Expand All @@ -344,21 +416,7 @@ async def _scan_media(
FileDirtyError if the result of the scan said that the file is dirty, or if
the media path is malformed.
"""

# Decrypt the content if necessary.
media_content = media.content
if metadata is not None:
# If the file is encrypted, we need to decrypt it before we can scan it.
media_content = self._decrypt_file(media_content, metadata)

# Check the file's MIME type to see if it's allowed.
self._check_mimetype(media_content)

# Write the file to disk.
file_path = self._write_file_to_disk(media_path, media_content)

# Scan the file and see if the result is positive or negative.
exit_code = await self._run_scan(file_path)
exit_code = await self._do_scan(media.content, metadata, file_id=media_path)
result = exit_code == 0

# If the exit code isn't part of the ones we should ignore, cache the result.
Expand All @@ -369,13 +427,6 @@ async def _scan_media(
)
cacheable = False

# Delete the file now that we've scanned it.
logger.info("Scan has finished, removing file")
removal_command_parts = self._removal_command.split()
removal_command_parts.append(file_path)
subprocess.run(removal_command_parts)

# Raise an error if the result isn't clean.
if result is False:
raise FileDirtyError(cacheable=cacheable)

Expand Down Expand Up @@ -445,7 +496,7 @@ def _decrypt_file(self, body: bytes, metadata: JsonDict) -> bytes:
info=str(e),
)

def _write_file_to_disk(self, media_path: str, body: bytes) -> str:
async def _write_file_to_disk(self, media_path: str, body: bytes) -> str:
"""Writes the given content to disk. The final file name will be a concatenation
of `temp_directory` and the media's `server_name/media_id` path.

Expand Down Expand Up @@ -475,8 +526,16 @@ def _write_file_to_disk(self, media_path: str, body: bytes) -> str:
# Create any directory we need.
os.makedirs(full_path.parent, exist_ok=True)

with open(full_path, "wb") as fp:
fp.write(body)
try:
async with async_open(full_path, "wb") as fp:
await fp.write(body if isinstance(body, bytes) else bytes(body))
except Exception:
# Delete the file if the write fails.
try:
os.unlink(full_path)
except OSError:
pass
raise

return str(full_path)

Expand Down Expand Up @@ -506,16 +565,15 @@ async def _run_scan(self, file_name: str) -> int:

return retcode

def _check_mimetype(self, media_content: bytes) -> None:
"""Detects the MIME type of the provided bytes, and checks that this type is allowed
def _check_mimetype(self, filepath: str) -> None:
"""Detects the MIME type of the provided file, and checks that this type is allowed
(if an allow list is provided in the configuration)
Args:
media_content: The file's content. If the file is encrypted, this is its
decrypted content.
filepath: The full file path.
Raises:
FileMimeTypeForbiddenError if one of the checks fail.
"""
detected_mimetype = magic.from_buffer(media_content, mime=True)
detected_mimetype = magic.from_file(filepath, mime=True)
logger.debug("Detected MIME type for file is %s", detected_mimetype)

# If there's an allow list for MIME types, check that the MIME type that's been
Expand Down
26 changes: 26 additions & 0 deletions src/matrix_content_scanner/servlets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,32 @@ async def get_media_metadata_from_request(
return media_path, metadata


async def get_media_metadata_from_filebody(
    file_body: JsonDict,
    crypto_handler: crypto.CryptoHandler,
) -> JsonDict:
    """Extracts, optionally decrypts, and validates encrypted file metadata from an
    already-parsed body.

    Args:
        file_body: The JSON body to extract the metadata from.
        crypto_handler: The crypto handler to use if we need to decrypt an
            Olm-encrypted body.

    Returns:
        The metadata, after it has passed validation.

    Raises:
        ContentScannerRestError(400) if the body is None or if the metadata didn't
        pass schema validation.
    """
    # Unlike get_media_metadata_from_request, no file URL is extracted from the
    # metadata here: the caller already has the media content.
    file_metadata = _metadata_from_body(file_body, crypto_handler)
    validate_encrypted_file_metadata(file_metadata)
    return file_metadata


def _metadata_from_body(
body: JsonDict, crypto_handler: crypto.CryptoHandler
) -> JsonDict:
Expand Down
Loading
Loading