diff --git a/core/src/services/huggingface/backend.rs b/core/src/services/huggingface/backend.rs index c5b5e4c8b806..96ab0716363c 100644 --- a/core/src/services/huggingface/backend.rs +++ b/core/src/services/huggingface/backend.rs @@ -106,6 +106,17 @@ impl HuggingfaceBuilder { } self } + + /// configure the Hub base url. You might want to set this variable if your + /// organization is using a Private Hub https://huggingface.co/enterprise + /// + /// Default is "https://huggingface.co" + pub fn endpoint(mut self, endpoint: &str) -> Self { + if !endpoint.is_empty() { + self.config.endpoint = Some(endpoint.to_string()); + } + self + } } impl Builder for HuggingfaceBuilder { @@ -151,6 +162,20 @@ impl Builder for HuggingfaceBuilder { let token = self.config.token.as_ref().cloned(); + let endpoint = match &self.config.endpoint { + Some(endpoint) => endpoint.clone(), + None => { + // Try to read from HF_ENDPOINT env var which is used + // by the official huggingface clients. + if let Ok(env_endpoint) = std::env::var("HF_ENDPOINT") { + env_endpoint + } else { + "https://huggingface.co".to_string() + } + } + }; + debug!("backend use endpoint: {}", &endpoint); + Ok(HuggingfaceBackend { core: Arc::new(HuggingfaceCore { info: { @@ -158,14 +183,10 @@ impl Builder for HuggingfaceBuilder { am.set_scheme(HUGGINGFACE_SCHEME) .set_native_capability(Capability { stat: true, - read: true, - list: true, list_with_recursive: true, - shared: true, - ..Default::default() }); am.into() @@ -175,6 +196,7 @@ impl Builder for HuggingfaceBuilder { revision, root, token, + endpoint, }), }) } diff --git a/core/src/services/huggingface/config.rs b/core/src/services/huggingface/config.rs index 096490560560..6b4034bfbfc8 100644 --- a/core/src/services/huggingface/config.rs +++ b/core/src/services/huggingface/config.rs @@ -50,6 +50,10 @@ pub struct HuggingfaceConfig { /// /// This is optional. pub token: Option, + /// Endpoint of the Huggingface Hub. + /// + /// Default is "https://huggingface.co". + pub endpoint: Option, } impl Debug for HuggingfaceConfig { diff --git a/core/src/services/huggingface/core.rs b/core/src/services/huggingface/core.rs index 5f3a65dff399..e852d0a6ccfb 100644 --- a/core/src/services/huggingface/core.rs +++ b/core/src/services/huggingface/core.rs @@ -22,12 +22,17 @@ use bytes::Bytes; use http::Request; use http::Response; use http::header; +use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode}; use serde::Deserialize; use super::backend::RepoType; use crate::raw::*; use crate::*; +fn percent_encode_revision(revision: &str) -> String { + utf8_percent_encode(revision, NON_ALPHANUMERIC).to_string() +} + pub struct HuggingfaceCore { pub info: Arc, @@ -36,6 +41,7 @@ pub struct HuggingfaceCore { pub revision: String, pub root: String, pub token: Option, + pub endpoint: String, } impl Debug for HuggingfaceCore { @@ -45,6 +51,7 @@ impl Debug for HuggingfaceCore { .field("repo_id", &self.repo_id) .field("revision", &self.revision) .field("root", &self.root) + .field("endpoint", &self.endpoint) .finish_non_exhaustive() } } @@ -57,12 +64,16 @@ impl HuggingfaceCore { let url = match self.repo_type { RepoType::Model => format!( - "https://huggingface.co/api/models/{}/paths-info/{}", - &self.repo_id, &self.revision + "{}/api/models/{}/paths-info/{}", + &self.endpoint, + &self.repo_id, + percent_encode_revision(&self.revision) ), RepoType::Dataset => format!( - "https://huggingface.co/api/datasets/{}/paths-info/{}", - &self.repo_id, &self.revision + "{}/api/datasets/{}/paths-info/{}", + &self.endpoint, + &self.repo_id, + percent_encode_revision(&self.revision) ), }; @@ -92,15 +103,17 @@ impl HuggingfaceCore { let mut url = match self.repo_type { RepoType::Model => format!( - "https://huggingface.co/api/models/{}/tree/{}/{}?expand=True", + "{}/api/models/{}/tree/{}/{}?expand=True", + &self.endpoint, &self.repo_id, - &self.revision, + percent_encode_revision(&self.revision), percent_encode_path(&p) ), RepoType::Dataset => format!( - "https://huggingface.co/api/datasets/{}/tree/{}/{}?expand=True", + "{}/api/datasets/{}/tree/{}/{}?expand=True", + &self.endpoint, &self.repo_id, - &self.revision, + percent_encode_revision(&self.revision), percent_encode_path(&p) ), }; @@ -134,15 +147,17 @@ impl HuggingfaceCore { let url = match self.repo_type { RepoType::Model => format!( - "https://huggingface.co/{}/resolve/{}/{}", + "{}/{}/resolve/{}/{}", + &self.endpoint, &self.repo_id, - &self.revision, + percent_encode_revision(&self.revision), percent_encode_path(&p) ), RepoType::Dataset => format!( - "https://huggingface.co/datasets/{}/resolve/{}/{}", + "{}/datasets/{}/resolve/{}/{}", + &self.endpoint, &self.repo_id, - &self.revision, + percent_encode_revision(&self.revision), percent_encode_path(&p) ), };