From 6f5533d7d1754122429bb72431100234557539e2 Mon Sep 17 00:00:00 2001 From: taobo Date: Mon, 9 Oct 2023 22:32:06 +0800 Subject: [PATCH 1/7] feat(service/d1): Support d1 for opendal --- .env.example | 5 + core/Cargo.toml | 2 + core/src/services/d1/backend.rs | 263 ++++++++++++++++++++++++++++++++ core/src/services/d1/docs.md | 48 ++++++ core/src/services/d1/error.rs | 55 +++++++ core/src/services/d1/mod.rs | 20 +++ core/src/services/mod.rs | 4 + core/src/types/scheme.rs | 4 + core/tests/behavior/main.rs | 2 + 9 files changed, 403 insertions(+) create mode 100644 core/src/services/d1/backend.rs create mode 100644 core/src/services/d1/docs.md create mode 100644 core/src/services/d1/error.rs create mode 100644 core/src/services/d1/mod.rs diff --git a/.env.example b/.env.example index 79211cb97a46..f82a58b5ebf6 100644 --- a/.env.example +++ b/.env.example @@ -178,3 +178,8 @@ OPENDAL_SQLITE_CONNECTION_STRING=file:///tmp/opendal/test.db OPENDAL_SQLITE_TABLE=data OPENDAL_SQLITE_KEY_FIELD=key OPENDAL_SQLITE_VALUE_FIELD=data +# d1 +OPENDAL_D1_TEST=false +OPENDAL_D1_SQL= +OPENDAL_D1_PARAMS= +OPENDAL_D1_TOKEN= \ No newline at end of file diff --git a/core/Cargo.toml b/core/Cargo.toml index c5ca1b20cd75..1cdc281078d0 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -50,6 +50,7 @@ default = [ "services-s3", "services-webdav", "services-webhdfs", + "services-d1", ] # Build docs or not. @@ -120,6 +121,7 @@ services-cos = [ "reqsign?/services-tencent", "reqsign?/reqwest_request", ] +services-d1 = [] services-dashmap = ["dep:dashmap"] services-dropbox = [] services-etcd = ["dep:etcd-client", "dep:bb8"] diff --git a/core/src/services/d1/backend.rs b/core/src/services/d1/backend.rs new file mode 100644 index 000000000000..14b8d9c39d28 --- /dev/null +++ b/core/src/services/d1/backend.rs @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::fmt::Debug; +use std::fmt::Formatter; + +use async_trait::async_trait; +use http::header; +use http::Request; +use http::StatusCode; + +use crate::raw::adapters::kv; +use crate::raw::*; +use crate::*; + +use super::error::parse_error; + +#[doc = include_str!("docs.md")] +#[derive(Default)] +pub struct D1Builder { + root: Option, + endpoint: Option, + sql: Option, + params: Option>, + token: Option, + http_client: Option, +} + +impl Debug for D1Builder { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut ds = f.debug_struct("D1Builder"); + ds.field("endpoint", &self.endpoint); + ds.field("sql", &self.sql); + ds.field("params", &self.params); + ds.field("root", &self.root); + ds.finish() + } +} + +impl D1Builder { + /// Set endpoint for http backend. + /// + /// For more information, please refer to [D1 Database API](https://developers.cloudflare.com/api/operations/cloudflare-d1-query-database) + /// default: "https://api.cloudflare.com/client/v4" + pub fn endpoint(&mut self, v: &str) -> &mut Self { + if !v.is_empty() { + self.endpoint = Some(v.trim_end_matches('/').to_string()); + } + self + } + + /// set the working directory, all operations will be performed under it. + /// + /// default: "/" + pub fn root(&mut self, root: &str) -> &mut Self { + if !root.is_empty() { + self.root = Some(root.to_owned()); + } + self + } + + /// Set D1 execution sql. + pub fn sql(&mut self, sql: &str) -> &mut Self { + if !sql.is_empty() { + self.sql = Some(sql.to_string()); + } + self + } + + /// Set the sql value field of the d1 service. + /// + /// default: vec![] + pub fn params(&mut self, params: Vec) -> &mut Self { + if !params.is_empty() { + self.params = Some(params); + } + self + } + + /// Set the bearer token for the d1 service. + /// create a bearer token from [here](https://dash.cloudflare.com/profile/api-tokens) + pub fn token(&mut self, token: &str) -> &mut Self { + if !token.is_empty() { + self.token = Some(token.to_string()); + } + self + } +} + +impl Builder for D1Builder { + const SCHEME: Scheme = Scheme::D1; + type Accessor = D1Backend; + + fn from_map(map: HashMap) -> Self { + let mut builder = D1Builder::default(); + map.get("endpoint").map(|v| builder.endpoint(v)); + map.get("sql").map(|v| builder.sql(v)); + map.get("params") + .map(|v| builder.params(v.split(",").map(|s| s.to_string()).collect())); + map.get("root").map(|v| builder.root(v)); + map.get("token").map(|v| builder.token(v)); + builder + } + + fn build(&mut self) -> Result { + let endpoint = self + .endpoint + .clone() + .unwrap_or_else(|| "https://api.cloudflare.com/client/v4".to_string()); + + let sql = match self.sql.clone() { + Some(v) => v, + None => "".to_string(), + }; + + let params = match self.params.clone() { + Some(v) => v, + None => vec![], + }; + + let mut auth = None; + if let Some(token) = &self.token { + auth = Some(format_authorization_by_bearer(token)?) + } + + let client = if let Some(client) = self.http_client.take() { + client + } else { + HttpClient::new().map_err(|err| { + err.with_operation("Builder::build") + .with_context("service", Scheme::D1) + })? + }; + + let root = normalize_root( + self.root + .clone() + .unwrap_or_else(|| "/".to_string()) + .as_str(), + ); + Ok(D1Backend::new(Adapter { + root: root.clone(), + endpoint, + sql, + params, + authorization: auth, + client, + }) + .with_root(&root)) + } +} + +pub type D1Backend = kv::Backend; + +#[derive(Clone)] +pub struct Adapter { + root: String, + endpoint: String, + sql: String, + params: Vec, + authorization: Option, + client: HttpClient, +} + +impl Debug for Adapter { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut ds = f.debug_struct("D1Adapter"); + ds.field("endpoint", &self.endpoint); + ds.field("sql", &self.sql); + ds.field("params", &self.params); + ds.finish() + } +} + +impl Adapter { + fn create_d1_query_request(&self, path: &str) -> Result> { + let p = build_rooted_abs_path(&self.root, path); + let url: String = format!("{}{}", self.endpoint, percent_encode_path(&p)); + + let mut req = Request::post(&url); + if let Some(auth) = &self.authorization { + req = req.header(header::AUTHORIZATION, auth); + } + req = req.header(header::CONTENT_TYPE, "application/json"); + + let json = serde_json::json!({ + "sql": self.sql.clone(), + "params": self.params.clone(), + }); + let body_string = serde_json::to_string(&json).map_err(new_json_serialize_error)?; + let body_bytes = body_string.as_bytes().to_owned(); + + let req = req + .body(AsyncBody::Bytes(body_bytes.into())) + .map_err(new_request_build_error); + req + } +} + +#[async_trait] +impl kv::Adapter for Adapter { + fn metadata(&self) -> kv::Metadata { + kv::Metadata::new( + Scheme::D1, + "D1", + Capability { + stat: true, + read: true, + write: true, + delete: true, + ..Default::default() + }, + ) + } + + async fn get(&self, path: &str) -> Result>> { + let req = self.create_d1_query_request(path)?; + let resp = self.client.send(req).await?; + let status = resp.status(); + match status { + StatusCode::OK | StatusCode::PARTIAL_CONTENT => { + let body = resp.into_body().bytes().await?; + Ok(Some(body.into())) + } + _ => Err(parse_error(resp).await?), + } + } + + async fn set(&self, path: &str, _: &[u8]) -> Result<()> { + let req = self.create_d1_query_request(path)?; + let resp = self.client.send(req).await?; + let status = resp.status(); + match status { + StatusCode::OK | StatusCode::PARTIAL_CONTENT => Ok(()), + _ => Err(parse_error(resp).await?), + } + } + + async fn delete(&self, path: &str) -> Result<()> { + let req = self.create_d1_query_request(path)?; + let resp = self.client.send(req).await?; + let status = resp.status(); + match status { + StatusCode::OK | StatusCode::PARTIAL_CONTENT => Ok(()), + _ => Err(parse_error(resp).await?), + } + } +} diff --git a/core/src/services/d1/docs.md b/core/src/services/d1/docs.md new file mode 100644 index 000000000000..4032ba290288 --- /dev/null +++ b/core/src/services/d1/docs.md @@ -0,0 +1,48 @@ +## Capabilities + +This service can be used to: + +- [x] stat +- [x] read +- [x] write +- [x] create_dir +- [x] delete +- [ ] copy +- [ ] rename +- [ ] ~~list~~ +- [ ] scan +- [ ] ~~presign~~ +- [ ] blocking + +## Configuration + +- `root`: Set the working directory of `OpenDAL` +- `connection_string`: Set the connection string of postgres server +- `table`: Set the table of sqlite +- `key_field`: Set the key field of sqlite +- `value_field`: Set the value field of sqlite + +## Example + +### Via Builder + +```rust +use anyhow::Result; +use opendal::services::Sqlite; +use opendal::Operator; + +#[tokio::main] +async fn main() -> Result<()> { + let mut builder = Sqlite::default(); + builder.root("/"); + builder.connection_string("file//abc.db"); + builder.table("your_table"); + // key field type in the table should be compatible with Rust's &str like text + builder.key_field("key"); + // value field type in the table should be compatible with Rust's Vec like bytea + builder.value_field("value"); + + let op = Operator::new(builder)?.finish(); + Ok(()) +} +``` diff --git a/core/src/services/d1/error.rs b/core/src/services/d1/error.rs new file mode 100644 index 000000000000..967dc19f52a6 --- /dev/null +++ b/core/src/services/d1/error.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use http::Response; +use http::StatusCode; + +use crate::raw::*; +use crate::Error; +use crate::ErrorKind; +use crate::Result; + +/// Parse error response into Error. +pub async fn parse_error(resp: Response) -> Result { + let (parts, body) = resp.into_parts(); + let bs = body.bytes().await?; + + let (kind, retryable) = match parts.status { + StatusCode::NOT_FOUND => (ErrorKind::NotFound, false), + // Some services (like owncloud) return 403 while file locked. + StatusCode::FORBIDDEN => (ErrorKind::PermissionDenied, true), + // Allowing retry for resource locked. + StatusCode::LOCKED => (ErrorKind::Unexpected, true), + StatusCode::INTERNAL_SERVER_ERROR + | StatusCode::BAD_GATEWAY + | StatusCode::SERVICE_UNAVAILABLE + | StatusCode::GATEWAY_TIMEOUT => (ErrorKind::Unexpected, true), + _ => (ErrorKind::Unexpected, false), + }; + + let message = String::from_utf8_lossy(&bs); + + let mut err = Error::new(kind, &message); + + err = with_error_response_context(err, parts); + + if retryable { + err = err.set_temporary(); + } + + Ok(err) +} diff --git a/core/src/services/d1/mod.rs b/core/src/services/d1/mod.rs new file mode 100644 index 000000000000..d5ed11895490 --- /dev/null +++ b/core/src/services/d1/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod backend; +mod error; +pub use backend::D1Builder as D1; diff --git a/core/src/services/mod.rs b/core/src/services/mod.rs index 5293aaa443e9..c58a2d04b0af 100644 --- a/core/src/services/mod.rs +++ b/core/src/services/mod.rs @@ -228,3 +228,7 @@ pub use self::mysql::Mysql; mod sqlite; #[cfg(feature = "services-sqlite")] pub use self::sqlite::Sqlite; + +#[cfg(feature = "services-d1")] +mod d1; +pub use self::d1::D1; diff --git a/core/src/types/scheme.rs b/core/src/types/scheme.rs index ec3298b7bd59..fad9d7f6e207 100644 --- a/core/src/types/scheme.rs +++ b/core/src/types/scheme.rs @@ -41,6 +41,8 @@ pub enum Scheme { Cacache, /// [cos][crate::services::Cos]: Tencent Cloud Object Storage services. Cos, + /// [d1][crate::services::D1]: D1 services + D1, /// [dashmap][crate::services::Dashmap]: dashmap backend support. Dashmap, /// [etcd][crate::services::Etcd]: Etcd Services @@ -159,6 +161,7 @@ impl FromStr for Scheme { "azdls" | "azdfs" | "abfs" => Ok(Scheme::Azdls), "cacache" => Ok(Scheme::Cacache), "cos" => Ok(Scheme::Cos), + "d1" => Ok(Scheme::D1), "dashmap" => Ok(Scheme::Dashmap), "dropbox" => Ok(Scheme::Dropbox), "etcd" => Ok(Scheme::Etcd), @@ -208,6 +211,7 @@ impl From for &'static str { Scheme::Azdls => "Azdls", Scheme::Cacache => "cacache", Scheme::Cos => "cos", + Scheme::D1 => "d1", Scheme::Dashmap => "dashmap", Scheme::Etcd => "etcd", Scheme::Fs => "fs", diff --git a/core/tests/behavior/main.rs b/core/tests/behavior/main.rs index de38efd149da..cc0bf667f154 100644 --- a/core/tests/behavior/main.rs +++ b/core/tests/behavior/main.rs @@ -185,6 +185,8 @@ fn main() -> anyhow::Result<()> { tests.extend(behavior_test::()); #[cfg(feature = "services-sqlite")] tests.extend(behavior_test::()); + #[cfg(feature = "services-d1")] + tests.extend(behavior_test::()); // Don't init logging while building operator which may break cargo // nextest output From 07114731b8be545194c893a3a2bcff56105f7632 Mon Sep 17 00:00:00 2001 From: taobo Date: Wed, 11 Oct 2023 12:04:59 +0800 Subject: [PATCH 2/7] refactor: optimize code --- .env.example | 9 +- core/src/services/d1/backend.rs | 274 ++++++++++++++++++++++++-------- core/src/services/d1/docs.md | 39 +++-- core/src/services/d1/mod.rs | 1 + core/src/services/d1/model.rs | 85 ++++++++++ core/tests/behavior/utils.rs | 2 +- 6 files changed, 326 insertions(+), 84 deletions(-) create mode 100644 core/src/services/d1/model.rs diff --git a/.env.example b/.env.example index f82a58b5ebf6..3355ade2f73f 100644 --- a/.env.example +++ b/.env.example @@ -180,6 +180,9 @@ OPENDAL_SQLITE_KEY_FIELD=key OPENDAL_SQLITE_VALUE_FIELD=data # d1 OPENDAL_D1_TEST=false -OPENDAL_D1_SQL= -OPENDAL_D1_PARAMS= -OPENDAL_D1_TOKEN= \ No newline at end of file +OPENDAL_D1_TOKEN= +OPENDAL_D1_ACCOUNT_IDENTIFIER= +OPENDAL_D1_DATABASE_IDENTIFIER= +OPENDAL_D1_TABLE= +OPENDAL_D1_KEY_FIELD= +OPENDAL_D1_VALUE_FIELD= diff --git a/core/src/services/d1/backend.rs b/core/src/services/d1/backend.rs index 14b8d9c39d28..88679b755c9d 100644 --- a/core/src/services/d1/backend.rs +++ b/core/src/services/d1/backend.rs @@ -23,43 +23,83 @@ use async_trait::async_trait; use http::header; use http::Request; use http::StatusCode; +use serde_json::{de, Value}; use crate::raw::adapters::kv; use crate::raw::*; use crate::*; use super::error::parse_error; +use super::model::D1Response; #[doc = include_str!("docs.md")] #[derive(Default)] pub struct D1Builder { - root: Option, - endpoint: Option, - sql: Option, - params: Option>, token: Option, + account_identifier: Option, + database_identifier: Option, + + endpoint: Option, http_client: Option, + root: Option, + + table: Option, + key_field: Option, + value_field: Option, } impl Debug for D1Builder { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let mut ds = f.debug_struct("D1Builder"); ds.field("endpoint", &self.endpoint); - ds.field("sql", &self.sql); - ds.field("params", &self.params); ds.field("root", &self.root); + ds.field("table", &self.table); + ds.field("key_field", &self.key_field); + ds.field("value_field", &self.value_field); ds.finish() } } impl D1Builder { + /// Set api token for the cloudflare d1 service. + /// + /// create a api token from [here](https://dash.cloudflare.com/profile/api-tokens) + pub fn token(&mut self, token: &str) -> &mut Self { + if !token.is_empty() { + self.token = Some(token.to_string()); + } + self + } + + /// Set the account identifier for the cloudflare d1 service. + /// + /// get the account identifier from Workers & Pages -> Overview -> Account ID + /// default: "account-identifier" + pub fn account_identifier(&mut self, account_identifier: &str) -> &mut Self { + if !account_identifier.is_empty() { + self.account_identifier = Some(account_identifier.to_string()); + } + self + } + + /// Set the database identifier for the cloudflare d1 service. + /// + /// get the database identifier from Workers & Pages -> D1 -> [Your Database] -> Database ID + /// default: "database-identifier" + pub fn database_identifier(&mut self, database_identifier: &str) -> &mut Self { + if !database_identifier.is_empty() { + self.database_identifier = Some(database_identifier.to_string()); + } + self + } + /// Set endpoint for http backend. /// /// For more information, please refer to [D1 Database API](https://developers.cloudflare.com/api/operations/cloudflare-d1-query-database) /// default: "https://api.cloudflare.com/client/v4" pub fn endpoint(&mut self, v: &str) -> &mut Self { if !v.is_empty() { - self.endpoint = Some(v.trim_end_matches('/').to_string()); + self.endpoint = Some(v.to_string()); } self } @@ -74,29 +114,32 @@ impl D1Builder { self } - /// Set D1 execution sql. - pub fn sql(&mut self, sql: &str) -> &mut Self { - if !sql.is_empty() { - self.sql = Some(sql.to_string()); + /// Set the table name of the d1 service to read/write. + /// + /// Default to `kv` if not specified. + pub fn table(&mut self, table: &str) -> &mut Self { + if !table.is_empty() { + self.table = Some(table.to_owned()); } self } - /// Set the sql value field of the d1 service. + /// Set the key field name of the d1 service to read/write. /// - /// default: vec![] - pub fn params(&mut self, params: Vec) -> &mut Self { - if !params.is_empty() { - self.params = Some(params); + /// Default to `key` if not specified. + pub fn key_field(&mut self, key_field: &str) -> &mut Self { + if !key_field.is_empty() { + self.key_field = Some(key_field.to_string()); } self } - /// Set the bearer token for the d1 service. - /// create a bearer token from [here](https://dash.cloudflare.com/profile/api-tokens) - pub fn token(&mut self, token: &str) -> &mut Self { - if !token.is_empty() { - self.token = Some(token.to_string()); + /// Set the value field name of the d1 service to read/write. + /// + /// Default to `value` if not specified. + pub fn value_field(&mut self, value_field: &str) -> &mut Self { + if !value_field.is_empty() { + self.value_field = Some(value_field.to_string()); } self } @@ -108,36 +151,42 @@ impl Builder for D1Builder { fn from_map(map: HashMap) -> Self { let mut builder = D1Builder::default(); + map.get("token").map(|v| builder.token(v)); + map.get("account_identifier") + .map(|v| builder.account_identifier(v)); + map.get("database_identifier") + .map(|v| builder.database_identifier(v)); + map.get("endpoint").map(|v| builder.endpoint(v)); - map.get("sql").map(|v| builder.sql(v)); - map.get("params") - .map(|v| builder.params(v.split(",").map(|s| s.to_string()).collect())); map.get("root").map(|v| builder.root(v)); - map.get("token").map(|v| builder.token(v)); + + map.get("table").map(|v| builder.table(v)); + map.get("key_field").map(|v| builder.key_field(v)); + map.get("value_field").map(|v| builder.value_field(v)); builder } fn build(&mut self) -> Result { + let mut authorization = None; + if let Some(token) = &self.token { + authorization = Some(format_authorization_by_bearer(token)?) + } + + let account_identifier = self + .account_identifier + .clone() + .unwrap_or_else(|| "account-identifier".to_string()); + + let database_identifier = self + .database_identifier + .clone() + .unwrap_or_else(|| "database-identifier".to_string()); + let endpoint = self .endpoint .clone() .unwrap_or_else(|| "https://api.cloudflare.com/client/v4".to_string()); - let sql = match self.sql.clone() { - Some(v) => v, - None => "".to_string(), - }; - - let params = match self.params.clone() { - Some(v) => v, - None => vec![], - }; - - let mut auth = None; - if let Some(token) = &self.token { - auth = Some(format_authorization_by_bearer(token)?) - } - let client = if let Some(client) = self.http_client.take() { client } else { @@ -147,6 +196,15 @@ impl Builder for D1Builder { })? }; + let table = self.table.clone().unwrap_or_else(|| "kv".to_string()); + + let key_field = self.key_field.clone().unwrap_or_else(|| "key".to_string()); + + let value_field = self + .value_field + .clone() + .unwrap_or_else(|| "value".to_string()); + let root = normalize_root( self.root .clone() @@ -154,12 +212,14 @@ impl Builder for D1Builder { .as_str(), ); Ok(D1Backend::new(Adapter { - root: root.clone(), + authorization, + account_identifier, + database_identifier, endpoint, - sql, - params, - authorization: auth, client, + table, + key_field, + value_field, }) .with_root(&root)) } @@ -169,27 +229,35 @@ pub type D1Backend = kv::Backend; #[derive(Clone)] pub struct Adapter { - root: String, - endpoint: String, - sql: String, - params: Vec, authorization: Option, + account_identifier: String, + database_identifier: String, + + endpoint: String, client: HttpClient, + + table: String, + key_field: String, + value_field: String, } impl Debug for Adapter { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let mut ds = f.debug_struct("D1Adapter"); ds.field("endpoint", &self.endpoint); - ds.field("sql", &self.sql); - ds.field("params", &self.params); + ds.field("table", &self.table); + ds.field("key_field", &self.key_field); + ds.field("value_field", &self.value_field); ds.finish() } } impl Adapter { - fn create_d1_query_request(&self, path: &str) -> Result> { - let p = build_rooted_abs_path(&self.root, path); + fn create_d1_query_request(&self, sql: &str, params: Vec) -> Result> { + let p = format!( + "/accounts/{}/d1/database/{}/query", + self.account_identifier, self.database_identifier + ); let url: String = format!("{}{}", self.endpoint, percent_encode_path(&p)); let mut req = Request::post(&url); @@ -199,16 +267,14 @@ impl Adapter { req = req.header(header::CONTENT_TYPE, "application/json"); let json = serde_json::json!({ - "sql": self.sql.clone(), - "params": self.params.clone(), + "sql": sql, + "params": params, }); + let body_string = serde_json::to_string(&json).map_err(new_json_serialize_error)?; let body_bytes = body_string.as_bytes().to_owned(); - - let req = req - .body(AsyncBody::Bytes(body_bytes.into())) - .map_err(new_request_build_error); - req + req.body(AsyncBody::Bytes(body_bytes.into())) + .map_err(new_request_build_error) } } @@ -217,32 +283,74 @@ impl kv::Adapter for Adapter { fn metadata(&self) -> kv::Metadata { kv::Metadata::new( Scheme::D1, - "D1", + &self.table, Capability { - stat: true, read: true, write: true, - delete: true, ..Default::default() }, ) } async fn get(&self, path: &str) -> Result>> { - let req = self.create_d1_query_request(path)?; + let query = format!( + "SELECT {} FROM {} WHERE {} = ? LIMIT 1", + self.value_field, self.table, self.key_field + ); + let req = self.create_d1_query_request(&query, vec![path.into()])?; + let resp = self.client.send(req).await?; let status = resp.status(); match status { StatusCode::OK | StatusCode::PARTIAL_CONTENT => { let body = resp.into_body().bytes().await?; - Ok(Some(body.into())) + let body = de::from_slice::(&body); + if let Ok(body) = body { + if body.success { + if let Some(result) = body.result.get(0) { + if let Some(value) = result.results.get(0) { + match value { + Value::Object(s) => { + let value = s.get(&self.value_field); + match value { + Some(Value::Array(s)) => { + let mut v = Vec::new(); + for i in s { + if let Value::Number(n) = i { + v.push(n.as_u64().unwrap() as u8); + } + } + return Ok(Some(v)); + } + _ => return Ok(None), + } + } + _ => return Ok(None), + } + } + } + } + } + Ok(None) } _ => Err(parse_error(resp).await?), } } - async fn set(&self, path: &str, _: &[u8]) -> Result<()> { - let req = self.create_d1_query_request(path)?; + async fn set(&self, path: &str, value: &[u8]) -> Result<()> { + let table = &self.table; + let key_field = &self.key_field; + let value_field = &self.value_field; + let query = format!( + "INSERT INTO {table} ({key_field}, {value_field}) \ + VALUES (?, ?) \ + ON CONFLICT ({key_field}) \ + DO UPDATE SET {value_field} = EXCLUDED.{value_field}", + ); + + let params = vec![path.into(), value.into()]; + let req = self.create_d1_query_request(&query, params)?; + let resp = self.client.send(req).await?; let status = resp.status(); match status { @@ -252,7 +360,9 @@ impl kv::Adapter for Adapter { } async fn delete(&self, path: &str) -> Result<()> { - let req = self.create_d1_query_request(path)?; + let query = format!("DELETE FROM {} WHERE {} = ?", self.table, self.key_field); + let req = self.create_d1_query_request(&query, vec![path.into()])?; + let resp = self.client.send(req).await?; let status = resp.status(); match status { @@ -261,3 +371,33 @@ impl kv::Adapter for Adapter { } } } + +#[cfg(test)] +mod test { + use super::*; + + #[tokio::test] + async fn test_something_async() -> Result<()> { + let mut builder = D1Builder::default(); + builder + .token("AvVw_T7HbYZz-tWpVV7ytwQqFkD0IPv60grGLA_v") + .account_identifier("b386f5d906b87949002b545dec889cd5") + .database_identifier("16aba954-2a17-4dd5-94bc-bbea9232a889") + .table("Customers") + .key_field("CustomerID") + .value_field("CompanyName"); + + let op = Operator::new(builder)?.finish(); + let source_path = "ALFKI"; + // set value to d1 "opendal test value" as Vec + let value = "opendal test value".as_bytes(); + // write value to d1, the key is source_path + op.write(source_path, value).await?; + // read value from d1, the key is source_path + let v = op.read(source_path).await?; + assert_eq!(v, value); + // delete value from d1, the key is source_path + op.delete(source_path).await?; + Ok(()) + } +} diff --git a/core/src/services/d1/docs.md b/core/src/services/d1/docs.md index 4032ba290288..5c2bcbeba24c 100644 --- a/core/src/services/d1/docs.md +++ b/core/src/services/d1/docs.md @@ -17,10 +17,13 @@ This service can be used to: ## Configuration - `root`: Set the working directory of `OpenDAL` -- `connection_string`: Set the connection string of postgres server -- `table`: Set the table of sqlite -- `key_field`: Set the key field of sqlite -- `value_field`: Set the value field of sqlite +- `token`: Set the token of cloudflare api +- `account_identifier`: Set the account identifier of d1 +- `database_identifier`: Set the database identifier of d1 +- `endpoint`: Set the endpoint of d1 service +- `table`: Set the table name of the d1 service to read/write +- `key_field`: Set the key field of d1 +- `value_field`: Set the value field of d1 ## Example @@ -28,21 +31,31 @@ This service can be used to: ```rust use anyhow::Result; -use opendal::services::Sqlite; +use opendal::services::D1; use opendal::Operator; #[tokio::main] async fn main() -> Result<()> { - let mut builder = Sqlite::default(); - builder.root("/"); - builder.connection_string("file//abc.db"); - builder.table("your_table"); - // key field type in the table should be compatible with Rust's &str like text - builder.key_field("key"); - // value field type in the table should be compatible with Rust's Vec like bytea - builder.value_field("value"); + let mut builder = D1::default(); + builder + .token("token") + .account_identifier("account_identifier") + .database_identifier("database_identifier") + .table("table") + .key_field("key_field") + .value_field("value_field"); let op = Operator::new(builder)?.finish(); + let source_path = "ALFKI"; + // set value to d1 "opendal test value" as Vec + let value = "opendal test value".as_bytes(); + // write value to d1, the key is source_path + op.write(source_path, value).await?; + // read value from d1, the key is source_path + let v = op.read(source_path).await?; + assert_eq!(v, value); + // delete value from d1, the key is source_path + op.delete(source_path).await?; Ok(()) } ``` diff --git a/core/src/services/d1/mod.rs b/core/src/services/d1/mod.rs index d5ed11895490..9163a135c6c9 100644 --- a/core/src/services/d1/mod.rs +++ b/core/src/services/d1/mod.rs @@ -17,4 +17,5 @@ mod backend; mod error; +mod model; pub use backend::D1Builder as D1; diff --git a/core/src/services/d1/model.rs b/core/src/services/d1/model.rs new file mode 100644 index 000000000000..03b6df829c90 --- /dev/null +++ b/core/src/services/d1/model.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use serde::Deserialize; +use serde_json::Value; +use std::fmt::Debug; + +/// response data from d1 +#[derive(Deserialize, Debug)] +pub struct D1Response { + pub result: Vec, + pub success: bool, +} + +#[derive(Deserialize, Debug)] +pub struct D1Result { + pub meta: Meta, + pub results: Vec, + pub success: bool, +} + +#[derive(Deserialize, Debug)] +pub struct Meta { + pub served_by: String, + pub duration: f64, + pub changes: i32, + pub last_row_id: i32, + pub changed_db: bool, + pub size_after: i32, + pub rows_read: i32, + pub rows_written: i32, +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_deserialize_get_object_json_response() { + let data = r#" + { + "result": [ + { + "results": [ + { + "CustomerId": "4", + "CompanyName": "Around the Horn", + "ContactName": "Thomas Hardy" + } + ], + "success": true, + "meta": { + "served_by": "v3-prod", + "duration": 0.2147, + "changes": 0, + "last_row_id": 0, + "changed_db": false, + "size_after": 2162688, + "rows_read": 3, + "rows_written": 2 + } + } + ], + "success": true, + "errors": [], + "messages": [] + }"#; + let response: D1Response = serde_json::from_str(data).unwrap(); + println!("{:?}", response.result[0].results[0]); + } +} diff --git a/core/tests/behavior/utils.rs b/core/tests/behavior/utils.rs index 7cdd032a1b80..acad6c87a41e 100644 --- a/core/tests/behavior/utils.rs +++ b/core/tests/behavior/utils.rs @@ -107,7 +107,7 @@ pub fn gen_bytes_with_range(range: impl SampleRange) -> (Vec, usize) } pub fn gen_bytes() -> (Vec, usize) { - gen_bytes_with_range(1..4 * 1024 * 1024) + gen_bytes_with_range(1..4 * 20) } pub fn gen_fixed_bytes(size: usize) -> Vec { From f75c7a8d78a24ad5c8264afef27298c3f53c1753 Mon Sep 17 00:00:00 2001 From: taobo Date: Thu, 12 Oct 2023 17:25:28 +0800 Subject: [PATCH 3/7] fix: delete extra code --- core/Cargo.toml | 1 - core/src/services/d1/backend.rs | 73 +++++++++++---------------------- core/src/services/mod.rs | 1 + core/tests/behavior/utils.rs | 2 +- 4 files changed, 25 insertions(+), 52 deletions(-) diff --git a/core/Cargo.toml b/core/Cargo.toml index 1cdc281078d0..5e20e949e731 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -50,7 +50,6 @@ default = [ "services-s3", "services-webdav", "services-webhdfs", - "services-d1", ] # Build docs or not. diff --git a/core/src/services/d1/backend.rs b/core/src/services/d1/backend.rs index 88679b755c9d..218a8aa7aeb3 100644 --- a/core/src/services/d1/backend.rs +++ b/core/src/services/d1/backend.rs @@ -304,31 +304,34 @@ impl kv::Adapter for Adapter { match status { StatusCode::OK | StatusCode::PARTIAL_CONTENT => { let body = resp.into_body().bytes().await?; - let body = de::from_slice::(&body); - if let Ok(body) = body { - if body.success { - if let Some(result) = body.result.get(0) { - if let Some(value) = result.results.get(0) { - match value { - Value::Object(s) => { - let value = s.get(&self.value_field); - match value { - Some(Value::Array(s)) => { - let mut v = Vec::new(); - for i in s { - if let Value::Number(n) = i { - v.push(n.as_u64().unwrap() as u8); - } - } - return Ok(Some(v)); - } - _ => return Ok(None), + if let Ok(body) = de::from_slice::(&body) { + if !body.success { + return Ok(None); + } + let Some(result) = body.result.get(0) else { + return Ok(None); + }; + let Some(value) = result.results.get(0) else { + return Ok(None); + }; + + match value { + Value::Object(s) => { + let value = s.get(&self.value_field); + match value { + Some(Value::Array(s)) => { + let mut v = Vec::new(); + for i in s { + if let Value::Number(n) = i { + v.push(n.as_u64().unwrap() as u8); } } - _ => return Ok(None), + return Ok(Some(v)); } + _ => return Ok(None), } } + _ => return Ok(None), } } Ok(None) @@ -371,33 +374,3 @@ impl kv::Adapter for Adapter { } } } - -#[cfg(test)] -mod test { - use super::*; - - #[tokio::test] - async fn test_something_async() -> Result<()> { - let mut builder = D1Builder::default(); - builder - .token("AvVw_T7HbYZz-tWpVV7ytwQqFkD0IPv60grGLA_v") - .account_identifier("b386f5d906b87949002b545dec889cd5") - .database_identifier("16aba954-2a17-4dd5-94bc-bbea9232a889") - .table("Customers") - .key_field("CustomerID") - .value_field("CompanyName"); - - let op = Operator::new(builder)?.finish(); - let source_path = "ALFKI"; - // set value to d1 "opendal test value" as Vec - let value = "opendal test value".as_bytes(); - // write value to d1, the key is source_path - op.write(source_path, value).await?; - // read value from d1, the key is source_path - let v = op.read(source_path).await?; - assert_eq!(v, value); - // delete value from d1, the key is source_path - op.delete(source_path).await?; - Ok(()) - } -} diff --git a/core/src/services/mod.rs b/core/src/services/mod.rs index c58a2d04b0af..93da8b917b0a 100644 --- a/core/src/services/mod.rs +++ b/core/src/services/mod.rs @@ -231,4 +231,5 @@ pub use self::sqlite::Sqlite; #[cfg(feature = "services-d1")] mod d1; +#[cfg(feature = "services-d1")] pub use self::d1::D1; diff --git a/core/tests/behavior/utils.rs b/core/tests/behavior/utils.rs index acad6c87a41e..7cdd032a1b80 100644 --- a/core/tests/behavior/utils.rs +++ b/core/tests/behavior/utils.rs @@ -107,7 +107,7 @@ pub fn gen_bytes_with_range(range: impl SampleRange) -> (Vec, usize) } pub fn gen_bytes() -> (Vec, usize) { - gen_bytes_with_range(1..4 * 20) + gen_bytes_with_range(1..4 * 1024 * 1024) } pub fn gen_fixed_bytes(size: usize) -> Vec { From bbe1b4907e8418ec923b470e5d78702d5a69d11c Mon Sep 17 00:00:00 2001 From: taobo Date: Thu, 12 Oct 2023 22:21:50 +0800 Subject: [PATCH 4/7] fix: optimize code --- .env.example | 4 +- core/src/services/d1/backend.rs | 157 +++++++++++++++----------------- core/src/services/d1/docs.md | 4 +- core/src/services/d1/error.rs | 25 ++++- core/src/services/d1/model.rs | 6 +- 5 files changed, 104 insertions(+), 92 deletions(-) diff --git a/.env.example b/.env.example index 3355ade2f73f..b42f03bddcf0 100644 --- a/.env.example +++ b/.env.example @@ -181,8 +181,8 @@ OPENDAL_SQLITE_VALUE_FIELD=data # d1 OPENDAL_D1_TEST=false OPENDAL_D1_TOKEN= -OPENDAL_D1_ACCOUNT_IDENTIFIER= -OPENDAL_D1_DATABASE_IDENTIFIER= +OPENDAL_D1_ACCOUNT_ID= +OPENDAL_D1_DATABASE_ID= OPENDAL_D1_TABLE=
OPENDAL_D1_KEY_FIELD= OPENDAL_D1_VALUE_FIELD= diff --git a/core/src/services/d1/backend.rs b/core/src/services/d1/backend.rs index 218a8aa7aeb3..5eafba855a02 100644 --- a/core/src/services/d1/backend.rs +++ b/core/src/services/d1/backend.rs @@ -27,19 +27,19 @@ use serde_json::{de, Value}; use crate::raw::adapters::kv; use crate::raw::*; +use crate::ErrorKind; use crate::*; -use super::error::parse_error; +use super::error::{parse_d1_error, parse_error}; use super::model::D1Response; #[doc = include_str!("docs.md")] #[derive(Default)] pub struct D1Builder { token: Option, - account_identifier: Option, - database_identifier: Option, + account_id: Option, + database_id: Option, - endpoint: Option, http_client: Option, root: Option, @@ -51,7 +51,6 @@ pub struct D1Builder { impl Debug for D1Builder { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let mut ds = f.debug_struct("D1Builder"); - ds.field("endpoint", &self.endpoint); ds.field("root", &self.root); ds.field("table", &self.table); ds.field("key_field", &self.key_field); @@ -74,10 +73,10 @@ impl D1Builder { /// Set the account identifier for the cloudflare d1 service. /// /// get the account identifier from Workers & Pages -> Overview -> Account ID - /// default: "account-identifier" - pub fn account_identifier(&mut self, account_identifier: &str) -> &mut Self { - if !account_identifier.is_empty() { - self.account_identifier = Some(account_identifier.to_string()); + /// If not specified, it will return an error when building. + pub fn account_id(&mut self, account_id: &str) -> &mut Self { + if !account_id.is_empty() { + self.account_id = Some(account_id.to_string()); } self } @@ -85,21 +84,10 @@ impl D1Builder { /// Set the database identifier for the cloudflare d1 service. /// /// get the database identifier from Workers & Pages -> D1 -> [Your Database] -> Database ID - /// default: "database-identifier" - pub fn database_identifier(&mut self, database_identifier: &str) -> &mut Self { - if !database_identifier.is_empty() { - self.database_identifier = Some(database_identifier.to_string()); - } - self - } - - /// Set endpoint for http backend. - /// - /// For more information, please refer to [D1 Database API](https://developers.cloudflare.com/api/operations/cloudflare-d1-query-database) - /// default: "https://api.cloudflare.com/client/v4" - pub fn endpoint(&mut self, v: &str) -> &mut Self { - if !v.is_empty() { - self.endpoint = Some(v.to_string()); + /// If not specified, it will return an error when building. + pub fn database_id(&mut self, database_id: &str) -> &mut Self { + if !database_id.is_empty() { + self.database_id = Some(database_id.to_string()); } self } @@ -116,7 +104,7 @@ impl D1Builder { /// Set the table name of the d1 service to read/write. /// - /// Default to `kv` if not specified. + /// If not specified, it will return an error when building. pub fn table(&mut self, table: &str) -> &mut Self { if !table.is_empty() { self.table = Some(table.to_owned()); @@ -152,14 +140,10 @@ impl Builder for D1Builder { fn from_map(map: HashMap) -> Self { let mut builder = D1Builder::default(); map.get("token").map(|v| builder.token(v)); - map.get("account_identifier") - .map(|v| builder.account_identifier(v)); - map.get("database_identifier") - .map(|v| builder.database_identifier(v)); + map.get("account_id").map(|v| builder.account_id(v)); + map.get("database_id").map(|v| builder.database_id(v)); - map.get("endpoint").map(|v| builder.endpoint(v)); map.get("root").map(|v| builder.root(v)); - map.get("table").map(|v| builder.table(v)); map.get("key_field").map(|v| builder.key_field(v)); map.get("value_field").map(|v| builder.value_field(v)); @@ -172,20 +156,19 @@ impl Builder for D1Builder { authorization = Some(format_authorization_by_bearer(token)?) } - let account_identifier = self - .account_identifier - .clone() - .unwrap_or_else(|| "account-identifier".to_string()); - - let database_identifier = self - .database_identifier - .clone() - .unwrap_or_else(|| "database-identifier".to_string()); + let Some(account_id) = self.account_id.clone() else { + return Err(Error::new( + ErrorKind::ConfigInvalid, + "account_id is required", + )); + }; - let endpoint = self - .endpoint - .clone() - .unwrap_or_else(|| "https://api.cloudflare.com/client/v4".to_string()); + let Some(database_id) = self.database_id.clone() else { + return Err(Error::new( + ErrorKind::ConfigInvalid, + "database_id is required", + )); + }; let client = if let Some(client) = self.http_client.take() { client @@ -196,7 +179,9 @@ impl Builder for D1Builder { })? }; - let table = self.table.clone().unwrap_or_else(|| "kv".to_string()); + let Some(table) = self.table.clone() else { + return Err(Error::new(ErrorKind::ConfigInvalid, "table is required")); + }; let key_field = self.key_field.clone().unwrap_or_else(|| "key".to_string()); @@ -213,9 +198,8 @@ impl Builder for D1Builder { ); Ok(D1Backend::new(Adapter { authorization, - account_identifier, - database_identifier, - endpoint, + account_id, + database_id, client, table, key_field, @@ -230,12 +214,10 @@ pub type D1Backend = kv::Backend; #[derive(Clone)] pub struct Adapter { authorization: Option, - account_identifier: String, - database_identifier: String, + account_id: String, + database_id: String, - endpoint: String, client: HttpClient, - table: String, key_field: String, value_field: String, @@ -244,7 +226,6 @@ pub struct Adapter { impl Debug for Adapter { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let mut ds = f.debug_struct("D1Adapter"); - ds.field("endpoint", &self.endpoint); ds.field("table", &self.table); ds.field("key_field", &self.key_field); ds.field("value_field", &self.value_field); @@ -256,9 +237,13 @@ impl Adapter { fn create_d1_query_request(&self, sql: &str, params: Vec) -> Result> { let p = format!( "/accounts/{}/d1/database/{}/query", - self.account_identifier, self.database_identifier + self.account_id, self.database_id + ); + let url: String = format!( + "{}{}", + "https://api.cloudflare.com/client/v4", + percent_encode_path(&p) ); - let url: String = format!("{}{}", self.endpoint, percent_encode_path(&p)); let mut req = Request::post(&url); if let Some(auth) = &self.authorization { @@ -271,9 +256,8 @@ impl Adapter { "params": params, }); - let body_string = serde_json::to_string(&json).map_err(new_json_serialize_error)?; - let body_bytes = body_string.as_bytes().to_owned(); - req.body(AsyncBody::Bytes(body_bytes.into())) + let body = serde_json::to_vec(&json).map_err(new_json_serialize_error)?; + req.body(AsyncBody::Bytes(body.into())) .map_err(new_request_build_error) } } @@ -304,37 +288,40 @@ impl kv::Adapter for Adapter { match status { StatusCode::OK | StatusCode::PARTIAL_CONTENT => { let body = resp.into_body().bytes().await?; - if let Ok(body) = de::from_slice::(&body) { - if !body.success { - return Ok(None); - } - let Some(result) = body.result.get(0) else { - return Ok(None); - }; - let Some(value) = result.results.get(0) else { - return Ok(None); - }; - - match value { - Value::Object(s) => { - let value = s.get(&self.value_field); - match value { - Some(Value::Array(s)) => { - let mut v = Vec::new(); - for i in s { - if let Value::Number(n) = i { - v.push(n.as_u64().unwrap() as u8); - } - } - return Ok(Some(v)); - } - _ => return Ok(None), + let Ok(body) = de::from_slice::(&body) else { + return Err(Error::new( + ErrorKind::Unexpected, + "failed to parse response", + )); + }; + if !body.success { + return Err(parse_d1_error(&body).await?); + } + let Some(result) = body.result.get(0) else { + return Ok(None); + }; + let Some(value) = result.results.get(0) else { + return Ok(None); + }; + + let value = value.get(&self.value_field); + match value { + Some(Value::Array(s)) => { + let mut v = Vec::new(); + for i in s { + if let Value::Number(n) = i { + v.push(n.as_u64().unwrap() as u8); } } - _ => return Ok(None), + return Ok(Some(v)); + } + _ => { + return Err(Error::new( + ErrorKind::Unexpected, + "failed to parse response", + )) } } - Ok(None) } _ => Err(parse_error(resp).await?), } diff --git a/core/src/services/d1/docs.md b/core/src/services/d1/docs.md index 5c2bcbeba24c..17e3c71232cc 100644 --- a/core/src/services/d1/docs.md +++ b/core/src/services/d1/docs.md @@ -39,8 +39,8 @@ async fn main() -> Result<()> { let mut builder = D1::default(); builder .token("token") - .account_identifier("account_identifier") - .database_identifier("database_identifier") + .account_id("account_id") + .database_id("database_id") .table("table") .key_field("key_field") .value_field("value_field"); diff --git a/core/src/services/d1/error.rs b/core/src/services/d1/error.rs index 967dc19f52a6..76ddbc9fbcb0 100644 --- a/core/src/services/d1/error.rs +++ b/core/src/services/d1/error.rs @@ -23,6 +23,10 @@ use crate::Error; use crate::ErrorKind; use crate::Result; +use serde_json::de; + +use super::model::D1Response; + /// Parse error response into Error. pub async fn parse_error(resp: Response) -> Result { let (parts, body) = resp.into_parts(); @@ -41,8 +45,16 @@ pub async fn parse_error(resp: Response) -> Result { _ => (ErrorKind::Unexpected, false), }; - let message = String::from_utf8_lossy(&bs); + let message = "failed to parse error response"; + let Ok(body) = de::from_slice::(&bs) else { + return Ok(Error::new(kind, &message)); + }; + let message = body.errors.get(0).map_or(message.to_string(), |e| { + e.get("message").map_or(message.to_string(), |m| { + m.as_str().unwrap_or(message).to_string() + }) + }); let mut err = Error::new(kind, &message); err = with_error_response_context(err, parts); @@ -53,3 +65,14 @@ pub async fn parse_error(resp: Response) -> Result { Ok(err) } + +/// Parse error D1Response into Error. +pub async fn parse_d1_error(resp: &D1Response) -> Result { + let message = "failed to parse error response"; + let message = resp.errors.get(0).map_or(message.to_string(), |e| { + e.get("message").map_or(message.to_string(), |m| { + m.as_str().unwrap_or(message).to_string() + }) + }); + Ok(Error::new(ErrorKind::Unexpected, &message)) +} diff --git a/core/src/services/d1/model.rs b/core/src/services/d1/model.rs index 03b6df829c90..0b5e462e53f7 100644 --- a/core/src/services/d1/model.rs +++ b/core/src/services/d1/model.rs @@ -16,7 +16,7 @@ // under the License. use serde::Deserialize; -use serde_json::Value; +use serde_json::{Map, Value}; use std::fmt::Debug; /// response data from d1 @@ -24,12 +24,14 @@ use std::fmt::Debug; pub struct D1Response { pub result: Vec, pub success: bool, + pub errors: Vec>, + pub messages: Vec>, } #[derive(Deserialize, Debug)] pub struct D1Result { pub meta: Meta, - pub results: Vec, + pub results: Vec>, pub success: bool, } From e8b3575bc7712778a6a621a63c67144809a1a64b Mon Sep 17 00:00:00 2001 From: taobo Date: Fri, 13 Oct 2023 13:43:47 +0800 Subject: [PATCH 5/7] refactor: errors and get method --- core/src/services/d1/backend.rs | 40 +++------------------ core/src/services/d1/error.rs | 43 +++++++++++++---------- core/src/services/d1/model.rs | 62 +++++++++++++++++++++++++++++++-- 3 files changed, 87 insertions(+), 58 deletions(-) diff --git a/core/src/services/d1/backend.rs b/core/src/services/d1/backend.rs index 5eafba855a02..faf0c472efe2 100644 --- a/core/src/services/d1/backend.rs +++ b/core/src/services/d1/backend.rs @@ -23,14 +23,14 @@ use async_trait::async_trait; use http::header; use http::Request; use http::StatusCode; -use serde_json::{de, Value}; +use serde_json::Value; use crate::raw::adapters::kv; use crate::raw::*; use crate::ErrorKind; use crate::*; -use super::error::{parse_d1_error, parse_error}; +use super::error::parse_error; use super::model::D1Response; #[doc = include_str!("docs.md")] @@ -288,40 +288,8 @@ impl kv::Adapter for Adapter { match status { StatusCode::OK | StatusCode::PARTIAL_CONTENT => { let body = resp.into_body().bytes().await?; - let Ok(body) = de::from_slice::(&body) else { - return Err(Error::new( - ErrorKind::Unexpected, - "failed to parse response", - )); - }; - if !body.success { - return Err(parse_d1_error(&body).await?); - } - let Some(result) = body.result.get(0) else { - return Ok(None); - }; - let Some(value) = result.results.get(0) else { - return Ok(None); - }; - - let value = value.get(&self.value_field); - match value { - Some(Value::Array(s)) => { - let mut v = Vec::new(); - for i in s { - if let Value::Number(n) = i { - v.push(n.as_u64().unwrap() as u8); - } - } - return Ok(Some(v)); - } - _ => { - return Err(Error::new( - ErrorKind::Unexpected, - "failed to parse response", - )) - } - } + let d1_response = D1Response::parse(&body)?; + Ok(d1_response.get_result(&self.value_field)) } _ => Err(parse_error(resp).await?), } diff --git a/core/src/services/d1/error.rs b/core/src/services/d1/error.rs index 76ddbc9fbcb0..2a2f75de49bc 100644 --- a/core/src/services/d1/error.rs +++ b/core/src/services/d1/error.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use bytes::Buf; use http::Response; use http::StatusCode; @@ -25,6 +26,7 @@ use crate::Result; use serde_json::de; +use super::model::D1Error; use super::model::D1Response; /// Parse error response into Error. @@ -32,7 +34,7 @@ pub async fn parse_error(resp: Response) -> Result { let (parts, body) = resp.into_parts(); let bs = body.bytes().await?; - let (kind, retryable) = match parts.status { + let (mut kind, mut retryable) = match parts.status { StatusCode::NOT_FOUND => (ErrorKind::NotFound, false), // Some services (like owncloud) return 403 while file locked. StatusCode::FORBIDDEN => (ErrorKind::PermissionDenied, true), @@ -45,16 +47,14 @@ pub async fn parse_error(resp: Response) -> Result { _ => (ErrorKind::Unexpected, false), }; - let message = "failed to parse error response"; - let Ok(body) = de::from_slice::(&bs) else { - return Ok(Error::new(kind, &message)); - }; + let (message, d1_err) = de::from_reader::<_, D1Response>(bs.clone().reader()) + .map(|d1_err| (format!("{d1_err:?}"), Some(d1_err))) + .unwrap_or_else(|_| (String::from_utf8_lossy(&bs).into_owned(), None)); + + if let Some(d1_err) = d1_err { + (kind, retryable) = parse_d1_error_code(d1_err.errors).unwrap_or((kind, retryable)); + } - let message = body.errors.get(0).map_or(message.to_string(), |e| { - e.get("message").map_or(message.to_string(), |m| { - m.as_str().unwrap_or(message).to_string() - }) - }); let mut err = Error::new(kind, &message); err = with_error_response_context(err, parts); @@ -66,13 +66,18 @@ pub async fn parse_error(resp: Response) -> Result { Ok(err) } -/// Parse error D1Response into Error. -pub async fn parse_d1_error(resp: &D1Response) -> Result { - let message = "failed to parse error response"; - let message = resp.errors.get(0).map_or(message.to_string(), |e| { - e.get("message").map_or(message.to_string(), |m| { - m.as_str().unwrap_or(message).to_string() - }) - }); - Ok(Error::new(ErrorKind::Unexpected, &message)) +pub fn parse_d1_error_code(errors: Vec) -> Option<(ErrorKind, bool)> { + if errors.is_empty() { + return None; + } + + match errors[0].code { + // The request is malformed: failed to decode id. + 7400 => Some((ErrorKind::Unexpected, false)), + // no such column: Xxxx. + 7500 => Some((ErrorKind::NotFound, false)), + // Authentication error. + 10000 => Some((ErrorKind::PermissionDenied, false)), + _ => None, + } } diff --git a/core/src/services/d1/model.rs b/core/src/services/d1/model.rs index 0b5e462e53f7..4d22319746c5 100644 --- a/core/src/services/d1/model.rs +++ b/core/src/services/d1/model.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use serde::Deserialize; +use crate::Error; +use bytes::Bytes; +use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use std::fmt::Debug; @@ -24,8 +26,50 @@ use std::fmt::Debug; pub struct D1Response { pub result: Vec, pub success: bool, - pub errors: Vec>, - pub messages: Vec>, + pub errors: Vec, + pub messages: Vec, +} + +impl D1Response { + pub fn parse(bs: &Bytes) -> Result { + let response: D1Response = serde_json::from_slice(bs).map_err(|e| { + Error::new( + crate::ErrorKind::Unexpected, + &format!("failed to parse error response: {}", e), + ) + })?; + + if !response.success { + return Err(Error::new( + crate::ErrorKind::Unexpected, + &String::from_utf8_lossy(&bs).into_owned(), + )); + } + Ok(response) + } + + pub fn get_result(&self, path: &str) -> Option> { + if self.result.len() > 0 { + let result = &self.result[0]; + if result.results.len() > 0 { + let result = &result.results[0]; + let value = result.get(path); + match value { + Some(Value::Array(s)) => { + let mut v = Vec::new(); + for i in s { + if let Value::Number(n) = i { + v.push(n.as_u64().unwrap() as u8); + } + } + return Some(v); + } + _ => return None, + } + } + } + None + } } #[derive(Deserialize, Debug)] @@ -47,6 +91,18 @@ pub struct Meta { pub rows_written: i32, } +#[derive(Clone, Deserialize, Debug, Serialize)] +pub struct D1Error { + pub message: String, + pub code: i32, +} + +#[derive(Deserialize, Debug)] +pub struct D1Message { + pub message: String, + pub code: i32, +} + #[cfg(test)] mod test { use super::*; From bd98dea3caa42947e3b24e5dbdeb71b693acbdd2 Mon Sep 17 00:00:00 2001 From: taobo Date: Fri, 13 Oct 2023 14:34:27 +0800 Subject: [PATCH 6/7] fix: let clippy http --- core/src/services/d1/model.rs | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/core/src/services/d1/model.rs b/core/src/services/d1/model.rs index 4d22319746c5..8892f27b84e9 100644 --- a/core/src/services/d1/model.rs +++ b/core/src/services/d1/model.rs @@ -42,33 +42,31 @@ impl D1Response { if !response.success { return Err(Error::new( crate::ErrorKind::Unexpected, - &String::from_utf8_lossy(&bs).into_owned(), + &String::from_utf8_lossy(bs), )); } Ok(response) } - pub fn get_result(&self, path: &str) -> Option> { - if self.result.len() > 0 { - let result = &self.result[0]; - if result.results.len() > 0 { - let result = &result.results[0]; - let value = result.get(path); - match value { - Some(Value::Array(s)) => { - let mut v = Vec::new(); - for i in s { - if let Value::Number(n) = i { - v.push(n.as_u64().unwrap() as u8); - } - } - return Some(v); + pub fn get_result(&self, key: &str) -> Option> { + if self.result.is_empty() || self.result[0].results.is_empty() { + return None; + } + let result = &self.result[0].results[0]; + let value = result.get(key); + + match value { + Some(Value::Array(s)) => { + let mut v = Vec::new(); + for i in s { + if let Value::Number(n) = i { + v.push(n.as_u64().unwrap() as u8); } - _ => return None, } + return Some(v); } + _ => return None, } - None } } From e61c99551fd0809cf1549912447a360c3c06c345 Mon Sep 17 00:00:00 2001 From: taobo Date: Fri, 13 Oct 2023 16:17:52 +0800 Subject: [PATCH 7/7] fix: let clippy happy --- core/src/services/d1/model.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/services/d1/model.rs b/core/src/services/d1/model.rs index 8892f27b84e9..c086af0d52f4 100644 --- a/core/src/services/d1/model.rs +++ b/core/src/services/d1/model.rs @@ -63,9 +63,9 @@ impl D1Response { v.push(n.as_u64().unwrap() as u8); } } - return Some(v); + Some(v) } - _ => return None, + _ => None, } } }