Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 71 additions & 1 deletion Framework/Built_In_Automation/Database/BuiltInFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
DB_ODBC_DRIVER = "odbc_driver"
DB_SESSION = "session"
DB_ODBC_UTF8 = "odbc: enable utf-8 encoding"
DB_WAREHOUSE = "warehouse"
DB_SCHEMA = "schema"
DB_ACCOUNT = "account"


# [NON ACTION]
Expand Down Expand Up @@ -102,6 +105,16 @@ def find_odbc_driver(db_type="postgresql"):

def handle_db_exception(sModuleInfo, e):
import pyodbc

# Handle Snowflake exceptions
try:
import snowflake.connector.errors as snowflake_errors
if isinstance(e, snowflake_errors.Error):
traceback.print_exc()
CommonUtil.ExecLog(sModuleInfo, f"Snowflake Error: {e}", 3)
return CommonUtil.Exception_Handler(e)
except ImportError:
pass # Snowflake connector not installed

if isinstance(e, pyodbc.DataError):
traceback.print_exc()
Expand Down Expand Up @@ -243,6 +256,26 @@ def db_get_connection(session_name):
host=db_host,
port=db_port
)
elif "snowflake" in db_type:
import snowflake.connector

# Get Snowflake-specific parameters
account = db_params.get(DB_ACCOUNT)
if not account:
account = db_host.replace('.snowflakecomputing.com', '') if '.snowflakecomputing.com' in db_host else db_host
Comment thread Dismissed
warehouse = db_params.get(DB_WAREHOUSE) or 'COMPUTE_WH'
schema = db_params.get(DB_SCHEMA) or 'PUBLIC'

# Connect to Snowflake
db_con = snowflake.connector.connect(
user=db_user_id,
password=db_password,
account=account,
database=db_name,
warehouse=warehouse,
schema=schema
)
CommonUtil.ExecLog(sModuleInfo, "Connected to Snowflake.", 1)
elif "oracle" in db_type:
import cx_Oracle

Expand Down Expand Up @@ -303,14 +336,17 @@ def connect_to_db(data_set):
This action just stores the different database specific configs into shared variables for use by other actions.
NOTE: The actual db connection does not happen here, connection to db is made inside the actions which require it.

db_type input parameter <type of db, ex: postgres, mysql>
db_type input parameter <type of db, ex: postgres, mysql, snowflake>
db_name input parameter <name of db, ex: zeuz_db>
db_user_id input parameter <user id of the os who have access to the db, ex: postgres>
db_password input parameter <password of db, ex: mydbpass-mY1-t23z>
db_host input parameter <host of db, ex: localhost, 127.0.0.1>
db_port input parameter <port of db, ex: 5432 for postgres by default>
sid optional parameter <sid of db, ex: 15321 for oracle by default>
service_name optional parameter <service_name of db, ex: 'somename' for oracle by default>
warehouse optional parameter <warehouse for Snowflake, ex: COMPUTE_WH>
schema optional parameter <schema for Snowflake, ex: PUBLIC>
account optional parameter <account identifier for Snowflake>
odbc_driver optional parameter <specify the odbc driver, optional, can be found from pyodbc.drivers()>
odbc: enable utf-8 encoding optional parameter true/false - optionally enable utf-8 encoding
connect to db database action Connect to a database
Expand All @@ -324,6 +360,7 @@ def connect_to_db(data_set):
try:
# Default values
db_type = db_name = db_user_id = db_password = db_host = db_port = db_sid = db_service_name = db_odbc_driver = db_params = None
db_warehouse = db_schema = db_account = None
db_enable_odbc_utf8 = True
session_name = "default"

Expand All @@ -349,6 +386,12 @@ def connect_to_db(data_set):
sr.Set_Shared_Variables(DB_ODBC_DRIVER,right.strip())
if left == DB_ODBC_UTF8:
db_enable_odbc_utf8 = CommonUtil.parse_value_into_object(right.strip()) == True
if left == DB_WAREHOUSE or left == "warehouse":
db_warehouse = right.strip()
if left == DB_SCHEMA:
db_schema = right.strip()
if left == DB_ACCOUNT:
db_account = right.strip()
if DB_SESSION in left:
session_name = right.strip()

Expand All @@ -363,6 +406,9 @@ def connect_to_db(data_set):
DB_SERVICE_NAME: db_service_name,
DB_ODBC_DRIVER: db_odbc_driver,
DB_ODBC_UTF8: db_enable_odbc_utf8,
DB_WAREHOUSE: db_warehouse,
DB_SCHEMA: db_schema,
DB_ACCOUNT: db_account,
}

if sr.Test_Shared_Variables('db_sessions'):
Expand Down Expand Up @@ -426,6 +472,10 @@ def db_select(data_set):

# Get db_cursor and execute
db_con = db_get_connection(session_name)
if db_con == "zeuz_failed":
CommonUtil.ExecLog(sModuleInfo, "Failed to get database connection", 3)
return "zeuz_failed"

with db_con:
with db_con.cursor() as db_cursor:
db_cursor.execute(query)
Expand Down Expand Up @@ -539,6 +589,10 @@ def select_from_db(data_set):

# Get db_cursor and execute
db_con = db_get_connection(session_name)
if db_con == "zeuz_failed":
CommonUtil.ExecLog(sModuleInfo, "Failed to get database connection", 3)
return "zeuz_failed"

with db_con:
with db_con.cursor() as db_cursor:
db_cursor.execute(query)
Expand Down Expand Up @@ -627,6 +681,10 @@ def insert_into_db(data_set):
CommonUtil.ExecLog(sModuleInfo, "Executing query:\n%s." % query, 1)

db_con = db_get_connection(session_name)
if db_con == "zeuz_failed":
CommonUtil.ExecLog(sModuleInfo, "Failed to get database connection", 3)
return "zeuz_failed"

with db_con:
with db_con.cursor() as db_cursor:
db_cursor.execute(query)
Expand Down Expand Up @@ -699,6 +757,10 @@ def delete_from_db(data_set):

# Get db_cursor and execute
db_con = db_get_connection(session_name)
if db_con == "zeuz_failed":
CommonUtil.ExecLog(sModuleInfo, "Failed to get database connection", 3)
return "zeuz_failed"

with db_con:
with db_con.cursor() as db_cursor:
db_cursor.execute(query)
Expand Down Expand Up @@ -784,6 +846,10 @@ def update_into_db(data_set):

# Get db_cursor and execute
db_con = db_get_connection(session_name)
if db_con == "zeuz_failed":
CommonUtil.ExecLog(sModuleInfo, "Failed to get database connection", 3)
return "zeuz_failed"

with db_con:
with db_con.cursor() as db_cursor:
db_cursor.execute(query)
Expand Down Expand Up @@ -849,6 +915,10 @@ def db_non_query(data_set):

# Get db_cursor and execute
db_con = db_get_connection(session_name)
if db_con == "zeuz_failed":
CommonUtil.ExecLog(sModuleInfo, "Failed to get database connection", 3)
return "zeuz_failed"

with db_con:
with db_con.cursor() as db_cursor:
db_cursor.execute(query)
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ dependencies = [
"xvfbwrapper>=0.2.9 ; sys_platform == 'linux'",
"pyodbc>=5.2.0",
"psycopg2-binary>=2.9.10",
"cryptography==42.0.8",
"cryptography>=42.0.8",
"snowflake-connector-python>=3.12.0",
"pyopenssl>=23.0.0",
"pipdeptree>=2.26.1",
"axe-selenium-python>=2.1.6",
"filelock>=3.20.0",
Expand Down
Loading
Loading