Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pkg-r/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ Imports:
whisker
Suggests:
bsicons,
dbplyr,
dplyr,
DT,
palmerpenguins,
RSQLite,
Expand Down
1 change: 1 addition & 0 deletions pkg-r/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export(DBISource)
export(DataFrameSource)
export(DataSource)
export(QueryChat)
export(TblLazySource)
export(querychat)
export(querychat_app)
export(querychat_data_source)
Expand Down
209 changes: 194 additions & 15 deletions pkg-r/R/DataSource.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,9 @@ DataSource <- R6::R6Class(
#'
#' @description
#' A DataSource implementation that wraps a data frame using DuckDB for SQL
#' query execution.
#'
#' @details
#' This class creates an in-memory DuckDB connection and registers the provided
#' data frame as a table. All SQL queries are executed against this DuckDB table.
#' query execution. This class creates an in-memory DuckDB connection and
#' registers the provided data frame as a table. All SQL queries are executed
#' against this DuckDB table.
#'
#' @export
#' @examples
Expand Down Expand Up @@ -222,11 +220,8 @@ DataFrameSource <- R6::R6Class(
#'
#' @description
#' A DataSource implementation for DBI database connections (SQLite, PostgreSQL,
#' MySQL, etc.).
#'
#' @details
#' This class wraps a DBI connection and provides SQL query execution against
#' a specified table in the database.
#' MySQL, etc.). This class wraps a DBI connection and provides SQL query
#' execution against a single table in the database.
#'
#' @export
#' @examples
Expand Down Expand Up @@ -378,6 +373,182 @@ DBISource <- R6::R6Class(
)


#' Data Source: Lazy Tibble
#'
#' @description
#' A DataSource implementation for lazy tibbles connected to databases via
#' [dbplyr::tbl_sql()] or [dplyr::sql()]. Queries run against the connection
#' that backs the lazy tibble. When the tibble is not a plain reference to a
#' table named `table_name`, its SQL is wrapped in a common table expression
#' (CTE) with that name so that generated queries can refer to it.
#'
#' @examplesIf rlang::is_interactive() && rlang::is_installed("dbplyr") && rlang::is_installed("dplyr") && rlang::is_installed("duckdb")
#' con <- DBI::dbConnect(duckdb::duckdb())
#' DBI::dbWriteTable(con, "mtcars", mtcars)
#'
#' mtcars_source <- TblLazySource$new(dplyr::tbl(con, "mtcars"))
#' mtcars_source$get_db_type() # "DuckDB"
#'
#' result <- mtcars_source$execute_query("SELECT * FROM mtcars WHERE cyl > 4")
#'
#' # Note, the result is not the *full* data frame, but a lazy SQL tibble
#' result
#'
#' # You can chain this result into a dplyr pipeline
#' dplyr::count(result, cyl, gear)
#'
#' # Or collect the entire data frame into local memory
#' dplyr::collect(result)
#'
#' # Finally, clean up when done with the database (closes the DB connection)
#' mtcars_source$cleanup()
#'
#' @export
TblLazySource <- R6::R6Class(
  "TblLazySource",
  inherit = DBISource,
  private = list(
    # The lazy tibble passed to `$new()`, returned as-is by `$get_data()`
    tbl = NULL,
    # SQL of the lazy tibble, set only when queries must be wrapped in a CTE
    tbl_cte = NULL
  ),
  public = list(
    #' @field table_name Name of the table to be used in SQL queries
    table_name = NULL,

    #' @description
    #' Create a new TblLazySource
    #'
    #' @param tbl A [dbplyr::tbl_sql()] (or lazy tibble via [dplyr::tbl()]).
    #' @param table_name Name of the table in the database. Can be a character
    #'   string, or will be inferred from the `tbl` argument, if possible.
    #' @return A new TblLazySource object
    #' @examplesIf rlang::is_interactive() && rlang::is_installed("dbplyr") && rlang::is_installed("dplyr") && rlang::is_installed("RSQLite")
    #' conn <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
    #' DBI::dbWriteTable(conn, "mtcars", mtcars)
    #' source <- TblLazySource$new(dplyr::tbl(conn, "mtcars"))
    initialize = function(tbl, table_name = missing_arg()) {
      check_installed("dbplyr")
      check_installed("dplyr")

      if (!inherits(tbl, "tbl_sql")) {
        cli::cli_abort(
          "{.arg tbl} must be a lazy tibble connected to a database, not {.obj_type_friendly {tbl}}"
        )
      }

      private$conn <- dbplyr::remote_con(tbl)
      private$tbl <- tbl

      # Collect various signals to infer the table name
      obj_name <- deparse1(substitute(tbl))

      # Get the exact table name, if tbl directly references a single table
      remote_name <- dbplyr::remote_name(private$tbl)

      use_cte <- FALSE

      if (!is_missing(table_name)) {
        check_sql_table_name(table_name)
        self$table_name <- table_name
        # The requested name can be queried directly only when it is exactly
        # the remote table behind `tbl`. Otherwise (a different name, or a
        # `tbl` that is a query so `remote_name` is NULL) the name has to be
        # defined through a CTE.
        use_cte <- !identical(table_name, remote_name)
      } else if (!is.null(remote_name)) {
        # Remote name is non-NULL when it points to a table, so we use that next
        self$table_name <- remote_name
        use_cte <- FALSE
      } else if (is_valid_sql_table_name(obj_name)) {
        # Fall back to the name of the R object that was passed as `tbl`
        self$table_name <- obj_name
        use_cte <- TRUE
      } else {
        # Last resort: a generated name with a random numeric suffix
        id <- as.integer(runif(1) * 1e6)
        self$table_name <- sprintf("querychat_cte_%d", id)
        use_cte <- TRUE
      }

      if (use_cte) {
        # We received a complicated tbl expression, we'll have to use a CTE
        private$tbl_cte <- dbplyr::remote_query(private$tbl)
      }
    },

    #' @description
    #' Get the database type
    #'
    #' @return A string describing the database type (e.g., "DuckDB", "SQLite")
    get_db_type = function() {
      super$get_db_type()
    },

    #' @description
    #' Get schema information about the table
    #'
    #' @param categorical_threshold Maximum number of unique values for a text
    #'   column to be considered categorical
    #' @return A string containing schema information formatted for LLM prompts
    get_schema = function(categorical_threshold = 20) {
      # Column names come from the lazy tibble itself, and every schema query
      # is passed through `prep_query()` so it also works in the CTE case.
      get_schema_impl(
        private$conn,
        self$table_name,
        categorical_threshold,
        columns = colnames(private$tbl),
        prep_query = self$prep_query
      )
    },

    #' @description
    #' Execute a SQL query and return the result as a lazy SQL tibble
    #'
    #' @param query SQL query string to execute
    #' @return A lazy SQL tibble (created with [dplyr::tbl()]) for the query.
    #'   Use [dplyr::collect()] to bring the results into a local data frame.
    execute_query = function(query) {
      sql_query <- self$prep_query(query)
      dplyr::tbl(private$conn, dplyr::sql(sql_query))
    },

    #' @description
    #' Test a SQL query by fetching only one row
    #'
    #' @param query SQL query string to test
    #' @return A data frame containing one row of results (or empty if no matches)
    test_query = function(query) {
      super$test_query(self$prep_query(query))
    },

    #' @description
    #' Prepare a generic `SELECT * FROM ____` query to work with the SQL tibble
    #'
    #' @param query SQL query as a string
    #' @return A complete SQL query string. `query` is returned unchanged when
    #'   no CTE is needed; otherwise it is prefixed with a `WITH` clause that
    #'   defines `table_name` as the SQL of the lazy tibble.
    prep_query = function(query) {
      check_string(query)

      if (is.null(private$tbl_cte)) {
        return(query)
      }

      # NOTE(review): a `query` that itself begins with `WITH` ends up with two
      # WITH clauses here -- confirm whether callers can send such queries.
      sprintf(
        "WITH %s AS (\n%s\n)\n%s",
        DBI::dbQuoteIdentifier(private$conn, self$table_name),
        private$tbl_cte,
        query
      )
    },

    #' @description
    #' Get the unfiltered data as a SQL tibble
    #'
    #' @return A [dbplyr::tbl_sql()] containing the original, unfiltered data
    get_data = function() {
      private$tbl
    },

    #' @description
    #' Clean up resources (close connections, etc.)
    #'
    #' @return NULL (invisibly)
    cleanup = function() {
      super$cleanup()
    }
  )
)


# Helper Functions -------------------------------------------------------------

#' Check if object is a DataSource
Expand All @@ -390,9 +561,17 @@ is_data_source <- function(x) {
}


get_schema_impl <- function(conn, table_name, categorical_threshold = 20) {
get_schema_impl <- function(
conn,
table_name,
categorical_threshold = 20,
columns = NULL,
prep_query = identity
) {
check_function(prep_query)

# Get column information
columns <- DBI::dbListFields(conn, table_name)
columns <- columns %||% DBI::dbListFields(conn, table_name)

schema_lines <- c(
paste("Table:", DBI::dbQuoteIdentifier(conn, table_name)),
Expand All @@ -410,7 +589,7 @@ get_schema_impl <- function(conn, table_name, categorical_threshold = 20) {
DBI::dbQuoteIdentifier(conn, table_name),
" LIMIT 1"
)
sample_data <- DBI::dbGetQuery(conn, sample_query)
sample_data <- DBI::dbGetQuery(conn, prep_query(sample_query))

for (col in columns) {
col_class <- class(sample_data[[col]])[1]
Expand Down Expand Up @@ -460,7 +639,7 @@ get_schema_impl <- function(conn, table_name, categorical_threshold = 20) {
" FROM ",
DBI::dbQuoteIdentifier(conn, table_name)
)
result <- DBI::dbGetQuery(conn, stats_query)
result <- DBI::dbGetQuery(conn, prep_query(stats_query))
if (nrow(result) > 0) {
column_stats <- as.list(result[1, ])
}
Expand Down Expand Up @@ -505,7 +684,7 @@ get_schema_impl <- function(conn, table_name, categorical_threshold = 20) {
" IS NOT NULL ORDER BY ",
DBI::dbQuoteIdentifier(conn, col_name)
)
result <- DBI::dbGetQuery(conn, cat_query)
result <- DBI::dbGetQuery(conn, prep_query(cat_query))
if (nrow(result) > 0) {
categorical_values[[col_name]] <- result[[1]]
}
Expand Down
25 changes: 20 additions & 5 deletions pkg-r/R/QueryChat.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,10 @@ QueryChat <- R6::R6Class(
check_string(prompt_template, allow_null = TRUE)
check_bool(cleanup, allow_na = TRUE)

if (is_missing(table_name) && is.data.frame(data_source)) {
table_name <- deparse1(substitute(data_source))
if (is_missing(table_name)) {
if (is.data.frame(data_source) || inherits(data_source, "tbl_sql")) {
table_name <- deparse1(substitute(data_source))
}
}

private$.data_source <- normalize_data_source(data_source, table_name)
Expand Down Expand Up @@ -338,8 +340,15 @@ QueryChat <- R6::R6Class(
})

output$dt <- DT::renderDT({
df <- qc_vals$df()
if (inherits(df, "tbl_sql")) {
# Materialize the query to get a data frame, {dplyr} guaranteed by
# TblLazySource interface
df <- dplyr::collect(df)
}

DT::datatable(
qc_vals$df(),
df,
fillContainer = TRUE,
options = list(pageLength = 25, scrollX = TRUE)
)
Expand Down Expand Up @@ -631,8 +640,10 @@ querychat <- function(
prompt_template = NULL,
cleanup = NA
) {
if (is_missing(table_name) && is.data.frame(data_source)) {
table_name <- deparse1(substitute(data_source))
if (is_missing(table_name)) {
if (is.data.frame(data_source) || inherits(data_source, "tbl_sql")) {
table_name <- deparse1(substitute(data_source))
}
}

QueryChat$new(
Expand Down Expand Up @@ -701,6 +712,10 @@ normalize_data_source <- function(data_source, table_name) {
return(DataFrameSource$new(data_source, table_name))
}

if (inherits(data_source, "tbl_lazy")) {
return(TblLazySource$new(data_source, table_name))
}

if (inherits(data_source, "DBIConnection")) {
return(DBISource$new(data_source, table_name))
}
Expand Down
6 changes: 5 additions & 1 deletion pkg-r/R/utils-check.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ check_sql_table_name <- function(
check_string(x, allow_null = allow_null, arg = arg, call = call)

# Then validate SQL table name pattern
if (!grepl("^[a-zA-Z][a-zA-Z0-9_]*$", x)) {
if (!is_valid_sql_table_name(x)) {
cli::cli_abort(
c(
"{.arg {arg}} must be a valid SQL table name",
Expand All @@ -45,3 +45,7 @@ check_sql_table_name <- function(

invisible(NULL)
}

# Test whether each element of `x` is a valid SQL table name: an ASCII letter
# followed by zero or more ASCII letters, digits or underscores.
# Returns a logical vector with one value per element of `x`.
is_valid_sql_table_name <- function(x) {
  valid_name_pattern <- "^[a-zA-Z][a-zA-Z0-9_]*$"
  grepl(pattern = valid_name_pattern, x = x)
}
7 changes: 2 additions & 5 deletions pkg-r/man/DBISource.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 3 additions & 5 deletions pkg-r/man/DataFrameSource.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading