diff --git a/server/bleep/sqlx-data.json b/server/bleep/sqlx-data.json index b8673de415..c7a6f1485b 100644 --- a/server/bleep/sqlx-data.json +++ b/server/bleep/sqlx-data.json @@ -278,6 +278,30 @@ }, "query": "SELECT id, index_status, name, url, favicon, description, modified_at FROM docs WHERE id = ?" }, + "291848ee7ef54ea247a2f83e89d2dd8e96024ba2fe65d0e443ecc94942aa6fa9": { + "describe": { + "columns": [ + { + "name": "repo_ref", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "branch", + "ordinal": 1, + "type_info": "Text" + } + ], + "nullable": [ + false, + true + ], + "parameters": { + "Right": 1 + } + }, + "query": "SELECT repo_ref, branch\n FROM project_repos\n WHERE project_id = ?" + }, "2d33f9119b3b56c55378080c5c95aa91fcb495ceb39caaa4f2541d8b2aa408ae": { "describe": { "columns": [], @@ -1074,6 +1098,24 @@ }, "query": "INSERT INTO studio_snapshots (studio_id, context, doc_context, messages)\n VALUES (?, ?, ?, ?)" }, + "95eaff0006df7f12604154c46113197e9440f06520672e2d7409cd0b831d83c2": { + "describe": { + "columns": [ + { + "name": "repo_ref", + "ordinal": 0, + "type_info": "Text" + } + ], + "nullable": [ + false + ], + "parameters": { + "Right": 1 + } + }, + "query": "SELECT repo_ref\n FROM project_repos\n WHERE project_id = ?" + }, "9db35f3045790fbd63f1efc4a96e5a7234f09cc513323320fd145146b03cce2b": { "describe": { "columns": [ @@ -1378,24 +1420,6 @@ }, "query": "UPDATE studio_snapshots SET messages = ? WHERE id = ?" }, - "dcb7f9427283203bce10fe7d618057ef3eab5f6af2277e7a1ac8ba050609894d": { - "describe": { - "columns": [ - { - "name": "url", - "ordinal": 0, - "type_info": "Text" - } - ], - "nullable": [ - false - ], - "parameters": { - "Right": 1 - } - }, - "query": "SELECT url FROM docs WHERE id = ?" 
- }, "deae1c1c2619ec6e76e0b5fcc526bbabbc1d66642efc6158a793068221ebd019": { "describe": { "columns": [ diff --git a/server/bleep/src/query/execute.rs b/server/bleep/src/query/execute.rs index 5e4739db90..e1fcd715fe 100644 --- a/server/bleep/src/query/execute.rs +++ b/server/bleep/src/query/execute.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, collections::{HashMap, HashSet}, sync::Arc, }; @@ -12,6 +13,7 @@ use crate::{ }, repo::RepoRef, snippet::{HighlightedString, SnippedFile, Snipper}, + Application, }; use anyhow::{bail, Result}; @@ -231,11 +233,118 @@ pub trait ExecuteQuery { } impl ApiQuery { - pub async fn query(self: Arc, indexes: Arc) -> Result { - let query = self.q.clone(); - let compiled = parser::parse(&query)?; - tracing::debug!("compiled query as {compiled:?}"); - self.query_with(indexes, compiled).await + pub async fn query(self: Arc, app: &Application) -> Result { + let raw_query = self.q.clone(); + let queries = self + .restrict_queries(parser::parse(&raw_query)?, app) + .await?; + tracing::debug!("compiled query as {queries:?}"); + self.query_with(Arc::clone(&app.indexes), queries).await + } + + /// This restricts a set of input parser queries. + /// + /// We trim down the input by: + /// + /// 1. Discarding all queries that reference repos not in the queried project + /// 2. Regenerating more specific queries for those without repo restrictions, such that there + /// is a new query generated per repo that exists in the project. + /// + /// The idea here is to allow us to restrict the possible input space of queried documents to + /// be more specific as required by the project state. + /// + /// Repo names are matched whole-string, and a query is dropped if its branch differs from the + /// branch the project has loaded that repo with. For substring matching of partially typed + /// repo names (autocomplete scenarios), use `restrict_repo_queries` instead. 
+ pub async fn restrict_queries<'a>( + &self, + queries: impl IntoIterator>, + app: &Application, + ) -> Result>> { + let repo_branches = sqlx::query! { + "SELECT repo_ref, branch + FROM project_repos + WHERE project_id = ?", + self.project_id, + } + .fetch_all(&*app.sql) + .await? + .into_iter() + .map(|row| { + ( + row.repo_ref.parse::().unwrap().indexed_name(), + row.branch, + ) + }) + .collect::>(); + + let mut out = Vec::new(); + + for q in queries { + if let Some(r) = q.repo_str() { + // The branch that this project has loaded this repo with. The outer `Option` is + // `None` when the repo is not part of the project at all. + let project_branch = repo_branches.get(&r); + + // Drop the query if the repo is outside the project or the branch doesn't match. + if project_branch.map_or(false, |b| q.branch_str() == *b) { + out.push(q); + } + } else { + for (r, b) in &repo_branches { + out.push(parser::Query { + repo: Some(parser::Literal::from(r)), + branch: b.as_ref().map(|b| parser::Literal::from(b)), + ..q.clone() + }); + } + } + } + + Ok(out) + } + + /// This restricts a set of input repo-only queries. + /// + /// This is useful for autocomplete queries, which are effectively just `repo:foo`, where the + /// repo name may be partially written. + pub async fn restrict_repo_queries<'a>( + &self, + queries: impl IntoIterator>, + app: &Application, + ) -> Result>> { + let repo_refs = sqlx::query! { + "SELECT repo_ref + FROM project_repos + WHERE project_id = ?", + self.project_id, + } + .fetch_all(&*app.sql) + .await? 
+ .into_iter() + .map(|row| row.repo_ref.parse::().unwrap().indexed_name()) + .collect::>(); + + let mut out = Vec::new(); + + for q in queries { + if let Some(r) = q.repo_str() { + for m in repo_refs.iter().filter(|r2| r2.contains(&r)) { + let restricted = parser::Query { + repo: Some(parser::Literal::from(m)), + ..Default::default() + }; + + // Several input queries can match the same repo, so check the whole list; + // `Vec::dedup` would only remove adjacent duplicates. + if !out.contains(&restricted) { + out.push(restricted); + } + } + } + } + + Ok(out) } pub async fn query_with( diff --git a/server/bleep/src/query/parser.rs b/server/bleep/src/query/parser.rs index 0504fa2952..4f82efc00a 100644 --- a/server/bleep/src/query/parser.rs +++ b/server/bleep/src/query/parser.rs @@ -168,6 +168,24 @@ impl<'a> Query<'a> { } } +impl<'a> Query<'a> { + /// Get the `repo` value for this query as a plain string. + pub fn repo_str(&self) -> Option { + self.repo + .as_ref() + .and_then(Literal::as_plain) + .map(Cow::into_owned) + } + + /// Get the `branch` value for this query as a plain string. + pub fn branch_str(&self) -> Option { + self.branch + .as_ref() + .and_then(Literal::as_plain) + .map(Cow::into_owned) + } +} + impl<'a> Target<'a> { /// Get the inner literal for this target, regardless of the variant. 
pub fn literal_mut(&'a mut self) -> &mut Literal<'a> { diff --git a/server/bleep/src/webserver.rs b/server/bleep/src/webserver.rs index 1126391f6b..369a70a092 100644 --- a/server/bleep/src/webserver.rs +++ b/server/bleep/src/webserver.rs @@ -50,8 +50,6 @@ pub async fn start(app: Application) -> anyhow::Result<()> { let mut api = Router::new() .route("/config", get(config::get).put(config::put)) - // autocomplete - .route("/autocomplete", get(autocomplete::handle)) // indexing .route("/index", get(index::handle)) // repo management @@ -119,6 +117,10 @@ pub async fn start(app: Application) -> anyhow::Result<()> { get(conversation::get).delete(conversation::delete), ) .route("/projects/:project_id/q", get(query::handle)) + .route( + "/projects/:project_id/autocomplete", + get(autocomplete::handle), + ) .route("/projects/:project_id/search/path", get(search::fuzzy_path)) .route("/projects/:project_id/answer/vote", post(answer::vote)) .route("/projects/:project_id/answer", get(answer::answer)) diff --git a/server/bleep/src/webserver/autocomplete.rs b/server/bleep/src/webserver/autocomplete.rs index b6afd07085..523429d351 100644 --- a/server/bleep/src/webserver/autocomplete.rs +++ b/server/bleep/src/webserver/autocomplete.rs @@ -1,19 +1,21 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use super::prelude::*; use crate::{ - indexes::{ - reader::{ContentReader, FileReader, RepoReader}, - Indexes, - }, + indexes::reader::{ContentReader, FileReader, RepoReader}, query::{ execute::{ApiQuery, ExecuteQuery, QueryResult}, languages, parser, parser::{Literal, Target}, }, + Application, }; -use axum::{extract::Query, response::IntoResponse as IntoAxumResponse, Extension}; +use axum::{ + extract::{Path, Query}, + response::IntoResponse as IntoAxumResponse, + Extension, +}; use futures::{stream, StreamExt, TryStreamExt}; use serde::Serialize; @@ -36,12 +38,15 @@ pub struct AutocompleteParams { pub(super) async fn handle( Query(mut api_params): Query, 
Query(ac_params): Query, - Extension(indexes): Extension>, + Path(project_id): Path, + Extension(app): Extension, ) -> Result { // Override page_size and set to low value api_params.page = 0; api_params.page_size = 8; + api_params.project_id = project_id; + let mut partial_lang = None; let mut has_target = false; @@ -114,17 +119,23 @@ pub(super) async fn handle( ); } + // NB: This restricts queries in a repo-specific way. This might need to be generalized if + // we still use the other autocomplete fields. + let repo_queries = api_params + .restrict_repo_queries(queries.clone(), &app) + .await?; + + let mut engines = vec![]; if ac_params.content { - engines.push(ContentReader.execute(&indexes.file, &queries, &api_params)); + engines.push(ContentReader.execute(&app.indexes.file, &queries, &api_params)); } if ac_params.repo { - engines.push(RepoReader.execute(&indexes.repo, &queries, &api_params)); + engines.push(RepoReader.execute(&app.indexes.repo, &repo_queries, &api_params)); } if ac_params.file { - engines.push(FileReader.execute(&indexes.file, &queries, &api_params)); + engines.push(FileReader.execute(&app.indexes.file, &queries, &api_params)); } let (langs, list) = stream::iter(engines) diff --git a/server/bleep/src/webserver/query.rs b/server/bleep/src/webserver/query.rs index ebd0105626..498e6a80d2 100644 --- a/server/bleep/src/webserver/query.rs +++ b/server/bleep/src/webserver/query.rs @@ -1,4 +1,4 @@ -use axum::extract::{Path, State}; +use axum::extract::Path; use super::prelude::*; use crate::{db::QueryLog, query::execute::ApiQuery, Application}; @@ -6,15 +6,14 @@ use crate::{db::QueryLog, query::execute::ApiQuery, Application}; pub(super) async fn handle( Path(project_id): Path, Query(mut api_params): Query, - Extension(indexes): Extension>, - State(app): State, + Extension(app): Extension, ) -> impl IntoResponse { QueryLog::new(&app.sql).insert(&api_params.q).await?; api_params.project_id = project_id; 
Arc::new(api_params) - .query(indexes) + .query(&app) .await .map(json) .map_err(super::Error::from)