Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions core/src/services/memcached/MIT-ascii.txt

This file was deleted.

171 changes: 99 additions & 72 deletions core/src/services/memcached/ascii.rs
Original file line number Diff line number Diff line change
@@ -1,145 +1,172 @@
// Copyright 2017 vavrusa <marek@vavrusa.com>
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// Licensed under the MIT License (see MIT-ascii.txt);

use core::fmt::Display;
use std::io::Error;
use std::io::ErrorKind;
use std::marker::Unpin;

use futures::io::AsyncBufReadExt;
use futures::io::AsyncRead;
use futures::io::AsyncReadExt;
use futures::io::AsyncWrite;
use futures::io::AsyncWriteExt;
use futures::io::BufReader;

/// Memcache ASCII protocol implementation.
pub struct Protocol<S> {
io: BufReader<S>,
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::*;

use super::backend::parse_io_error;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::net::TcpStream;

pub struct Connection {
io: BufReader<TcpStream>,
buf: Vec<u8>,
}

impl<S> Protocol<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
/// Creates the ASCII protocol on a stream.
pub fn new(io: S) -> Self {
impl Connection {
pub fn new(io: TcpStream) -> Self {
Self {
io: BufReader::new(io),
buf: Vec::new(),
}
}

/// Returns the value for given key as bytes. If the value doesn't exist, [`ErrorKind::NotFound`] is returned.
pub async fn get<K: AsRef<[u8]>>(&mut self, key: K) -> Result<Vec<u8>, Error> {
pub async fn get(&mut self, key: &str) -> Result<Option<Vec<u8>>> {
// Send command
let writer = self.io.get_mut();
writer
.write_all(&[b"get ", key.as_ref(), b"\r\n"].concat())
.await?;
writer.flush().await?;
.write_all(&[b"get ", key.as_bytes(), b"\r\n"].concat())
.await
.map_err(parse_io_error)?;
writer.flush().await.map_err(parse_io_error)?;

// Read response header
let header = self.read_line().await?;
let header = std::str::from_utf8(header).map_err(|_| ErrorKind::InvalidData)?;
let header = self.read_header().await?;

// Check response header and parse value length
if header.contains("ERROR") {
return Err(Error::new(ErrorKind::Other, header));
return Err(
Error::new(ErrorKind::Unexpected, "unexpected data received")
.with_context("message", header),
);
} else if header.starts_with("END") {
return Err(ErrorKind::NotFound.into());
return Ok(None);
}

// VALUE <key> <flags> <bytes> [<cas unique>]\r\n
let length: usize = header
.split(' ')
.nth(3)
.and_then(|len| len.trim_end().parse().ok())
.ok_or(ErrorKind::InvalidData)?;
.ok_or_else(|| Error::new(ErrorKind::Unexpected, "invalid data received"))?;

// Read value
let mut buffer: Vec<u8> = vec![0; length];
self.io.read_exact(&mut buffer).await?;
self.io
.read_exact(&mut buffer)
.await
.map_err(parse_io_error)?;

// Read the trailing header
self.read_line().await?; // \r\n
self.read_line().await?; // END\r\n

Ok(buffer)
Ok(Some(buffer))
}

/// Set key to given value and don't wait for response.
pub async fn set<K: Display>(
&mut self,
key: K,
val: &[u8],
expiration: u32,
) -> Result<(), Error> {
pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> {
let header = format!("set {} 0 {} {}\r\n", key, expiration, val.len());
self.io.write_all(header.as_bytes()).await?;
self.io.write_all(val).await?;
self.io.write_all(b"\r\n").await?;
self.io.flush().await?;
self.io
.write_all(header.as_bytes())
.await
.map_err(parse_io_error)?;
self.io.write_all(val).await.map_err(parse_io_error)?;
self.io.write_all(b"\r\n").await.map_err(parse_io_error)?;
self.io.flush().await.map_err(parse_io_error)?;

// Read response header
let header = self.read_line().await?;
let header = std::str::from_utf8(header).map_err(|_| ErrorKind::InvalidData)?;
let header = self.read_header().await?;

// Check response header and make sure we got a `STORED`
if header.contains("STORED") {
return Ok(());
} else if header.contains("ERROR") {
return Err(Error::new(ErrorKind::Other, header));
return Err(
Error::new(ErrorKind::Unexpected, "unexpected data received")
.with_context("message", header),
);
}
Ok(())
}

/// Delete a key and don't wait for response.
pub async fn delete<K: Display>(&mut self, key: K) -> Result<(), Error> {
pub async fn delete(&mut self, key: &str) -> Result<()> {
let header = format!("delete {}\r\n", key);
self.io.write_all(header.as_bytes()).await?;
self.io.flush().await?;
self.io
.write_all(header.as_bytes())
.await
.map_err(parse_io_error)?;
self.io.flush().await.map_err(parse_io_error)?;

// Read response header
let header = self.read_line().await?;
let header = std::str::from_utf8(header).map_err(|_| ErrorKind::InvalidData)?;
let header = self.read_header().await?;

// Check response header and parse value length
if header.contains("NOT_FOUND") {
if header.contains("NOT_FOUND") || header.starts_with("END") {
return Ok(());
} else if header.starts_with("END") {
return Err(ErrorKind::NotFound.into());
} else if header.contains("ERROR") || !header.contains("DELETED") {
return Err(Error::new(ErrorKind::Other, header));
return Err(
Error::new(ErrorKind::Unexpected, "unexpected data received")
.with_context("message", header),
);
}
Ok(())
}

/// Return the version of the remote server.
pub async fn version(&mut self) -> Result<String, Error> {
self.io.write_all(b"version\r\n").await?;
self.io.flush().await?;
pub async fn version(&mut self) -> Result<String> {
self.io
.write_all(b"version\r\n")
.await
.map_err(parse_io_error)?;
self.io.flush().await.map_err(parse_io_error)?;

// Read response header
let header = {
let buf = self.read_line().await?;
std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
};
let header = self.read_header().await?;

if !header.starts_with("VERSION") {
return Err(Error::new(ErrorKind::Other, header));
return Err(
Error::new(ErrorKind::Unexpected, "unexpected data received")
.with_context("message", header),
);
}
let version = header.trim_start_matches("VERSION ").trim_end();
Ok(version.to_string())
}

async fn read_line(&mut self) -> Result<&[u8], Error> {
async fn read_line(&mut self) -> Result<&[u8]> {
let Self { io, buf } = self;
buf.clear();
io.read_until(b'\n', buf).await?;
io.read_until(b'\n', buf).await.map_err(parse_io_error)?;
if buf.last().copied() != Some(b'\n') {
return Err(ErrorKind::UnexpectedEof.into());
return Err(Error::new(
ErrorKind::ContentIncomplete,
"unexpected eof, the response must be incomplete",
));
}
Ok(&buf[..])
}

async fn read_header(&mut self) -> Result<&str> {
let header = self.read_line().await?;
let header = std::str::from_utf8(header).map_err(|err| {
Error::new(ErrorKind::Unexpected, "invalid data received").set_source(err)
})?;

Ok(header)
}
}
54 changes: 14 additions & 40 deletions core/src/services/memcached/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
use std::collections::HashMap;
use std::time::Duration;

use async_compat::Compat;
use async_trait::async_trait;
use bb8::RunError;
use tokio::net::TcpStream;
Expand Down Expand Up @@ -220,7 +219,7 @@ impl Adapter {
RunError::TimedOut => {
Error::new(ErrorKind::Unexpected, "get connection from pool failed").set_temporary()
}
RunError::User(err) => parse_io_error(err),
RunError::User(err) => err,
})
}
}
Expand All @@ -243,12 +242,8 @@ impl kv::Adapter for Adapter {

async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
let mut conn = self.conn().await?;
// TODO: memcache-async have `Sized` limit on key, can we remove it?
match conn.get(&percent_encode_path(key)).await {
Ok(bs) => Ok(Some(bs)),
Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
Err(err) => Err(parse_io_error(err)),
}

conn.get(&percent_encode_path(key)).await
}

async fn set(&self, key: &str, value: &[u8]) -> Result<()> {
Expand All @@ -263,40 +258,13 @@ impl kv::Adapter for Adapter {
.unwrap_or_default(),
)
.await
.map_err(parse_io_error)?;

Ok(())
}

async fn delete(&self, key: &str) -> Result<()> {
let mut conn = self.conn().await?;

let _: () = conn
.delete(&percent_encode_path(key))
.await
.map_err(parse_io_error)?;
Ok(())
}
}

fn parse_io_error(err: std::io::Error) -> Error {
use std::io::ErrorKind::*;

let (kind, retryable) = match err.kind() {
NotFound => (ErrorKind::NotFound, false),
AlreadyExists => (ErrorKind::NotFound, false),
PermissionDenied => (ErrorKind::PermissionDenied, false),
Interrupted | UnexpectedEof | TimedOut | WouldBlock => (ErrorKind::Unexpected, true),
_ => (ErrorKind::Unexpected, true),
};

let mut err = Error::new(kind, &err.kind().to_string()).set_source(err);

if retryable {
err = err.set_temporary();
conn.delete(&percent_encode_path(key)).await
}

err
}

/// A `bb8::ManageConnection` for `memcache_async::ascii::Protocol`.
Expand All @@ -317,13 +285,15 @@ impl MemcacheConnectionManager {

#[async_trait]
impl bb8::ManageConnection for MemcacheConnectionManager {
type Connection = ascii::Protocol<Compat<TcpStream>>;
type Error = std::io::Error;
type Connection = ascii::Connection;
type Error = Error;

/// TODO: Implement unix stream support.
async fn connect(&self) -> std::result::Result<Self::Connection, Self::Error> {
let sock = TcpStream::connect(&self.address).await?;
Ok(ascii::Protocol::new(Compat::new(sock)))
let conn = TcpStream::connect(&self.address)
.await
.map_err(parse_io_error)?;
Ok(ascii::Connection::new(conn))
}

async fn is_valid(&self, conn: &mut Self::Connection) -> std::result::Result<(), Self::Error> {
Expand All @@ -334,3 +304,7 @@ impl bb8::ManageConnection for MemcacheConnectionManager {
false
}
}

pub fn parse_io_error(err: std::io::Error) -> Error {
Error::new(ErrorKind::Unexpected, &err.kind().to_string()).set_source(err)
}