Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions core/src/services/huggingface/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ impl HuggingfaceBuilder {
}
self
}

/// configure the Hub base url. You might want to set this variable if your
/// organization is using a Private Hub https://huggingface.co/enterprise
///
/// Default is "https://huggingface.co"
pub fn endpoint(mut self, endpoint: &str) -> Self {
if !endpoint.is_empty() {
self.config.endpoint = Some(endpoint.to_string());
}
self
}
}

impl Builder for HuggingfaceBuilder {
Expand Down Expand Up @@ -151,21 +162,31 @@ impl Builder for HuggingfaceBuilder {

let token = self.config.token.as_ref().cloned();

let endpoint = match &self.config.endpoint {
Some(endpoint) => endpoint.clone(),
None => {
// Try to read from HF_ENDPOINT env var which is used
// by the official huggingface clients.
if let Ok(env_endpoint) = std::env::var("HF_ENDPOINT") {
env_endpoint
} else {
"https://huggingface.co".to_string()
}
}
};
debug!("backend use endpoint: {}", &endpoint);

Ok(HuggingfaceBackend {
core: Arc::new(HuggingfaceCore {
info: {
let am = AccessorInfo::default();
am.set_scheme(HUGGINGFACE_SCHEME)
.set_native_capability(Capability {
stat: true,

read: true,

list: true,
list_with_recursive: true,

shared: true,

..Default::default()
});
am.into()
Expand All @@ -175,6 +196,7 @@ impl Builder for HuggingfaceBuilder {
revision,
root,
token,
endpoint,
}),
})
}
Expand Down
4 changes: 4 additions & 0 deletions core/src/services/huggingface/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ pub struct HuggingfaceConfig {
///
/// This is optional.
pub token: Option<String>,
/// Endpoint of the Huggingface Hub.
///
/// Default is "https://huggingface.co".
pub endpoint: Option<String>,
}

impl Debug for HuggingfaceConfig {
Expand Down
39 changes: 27 additions & 12 deletions core/src/services/huggingface/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,17 @@ use bytes::Bytes;
use http::Request;
use http::Response;
use http::header;
use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
use serde::Deserialize;

use super::backend::RepoType;
use crate::raw::*;
use crate::*;

fn percent_encode_revision(revision: &str) -> String {
utf8_percent_encode(revision, NON_ALPHANUMERIC).to_string()
}

pub struct HuggingfaceCore {
pub info: Arc<AccessorInfo>,

Expand All @@ -36,6 +41,7 @@ pub struct HuggingfaceCore {
pub revision: String,
pub root: String,
pub token: Option<String>,
pub endpoint: String,
}

impl Debug for HuggingfaceCore {
Expand All @@ -45,6 +51,7 @@ impl Debug for HuggingfaceCore {
.field("repo_id", &self.repo_id)
.field("revision", &self.revision)
.field("root", &self.root)
.field("endpoint", &self.endpoint)
.finish_non_exhaustive()
}
}
Expand All @@ -57,12 +64,16 @@ impl HuggingfaceCore {

let url = match self.repo_type {
RepoType::Model => format!(
"https://huggingface.co/api/models/{}/paths-info/{}",
&self.repo_id, &self.revision
"{}/api/models/{}/paths-info/{}",
&self.endpoint,
&self.repo_id,
percent_encode_revision(&self.revision)
),
RepoType::Dataset => format!(
"https://huggingface.co/api/datasets/{}/paths-info/{}",
&self.repo_id, &self.revision
"{}/api/datasets/{}/paths-info/{}",
&self.endpoint,
&self.repo_id,
percent_encode_revision(&self.revision)
),
};

Expand Down Expand Up @@ -92,15 +103,17 @@ impl HuggingfaceCore {

let mut url = match self.repo_type {
RepoType::Model => format!(
"https://huggingface.co/api/models/{}/tree/{}/{}?expand=True",
"{}/api/models/{}/tree/{}/{}?expand=True",
&self.endpoint,
&self.repo_id,
&self.revision,
percent_encode_revision(&self.revision),
percent_encode_path(&p)
),
RepoType::Dataset => format!(
"https://huggingface.co/api/datasets/{}/tree/{}/{}?expand=True",
"{}/api/datasets/{}/tree/{}/{}?expand=True",
&self.endpoint,
&self.repo_id,
&self.revision,
percent_encode_revision(&self.revision),
percent_encode_path(&p)
),
};
Expand Down Expand Up @@ -134,15 +147,17 @@ impl HuggingfaceCore {

let url = match self.repo_type {
RepoType::Model => format!(
"https://huggingface.co/{}/resolve/{}/{}",
"{}/{}/resolve/{}/{}",
&self.endpoint,
&self.repo_id,
&self.revision,
percent_encode_revision(&self.revision),
percent_encode_path(&p)
),
RepoType::Dataset => format!(
"https://huggingface.co/datasets/{}/resolve/{}/{}",
"{}/datasets/{}/resolve/{}/{}",
&self.endpoint,
&self.repo_id,
&self.revision,
percent_encode_revision(&self.revision),
percent_encode_path(&p)
),
};
Expand Down
Loading