Skip to content

Commit a79449d

Browse files
committed
Implemented support for dialects
1 parent f4e1c6c commit a79449d

File tree

8 files changed

+352
-107
lines changed

8 files changed

+352
-107
lines changed

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "sql_docs"
3-
version = "1.1.0"
3+
version = "1.2.0"
44
edition = "2024"
55
description = "A crate for parsing comments from sql files and using them for documentation generation"
66
documentation = "https://docs.rs/sql_docs"
@@ -22,7 +22,7 @@ default = []
2222
fuzzing = []
2323

2424
[dependencies]
25-
sqlparser = { version = "0.60.0", git = "https://github.com/apache/datafusion-sqlparser-rs", branch = "main" }
25+
sqlparser = { git = "https://github.com/LucaCappelletti94/sqlparser-rs", branch = "main" }
2626

2727
[lints.rust]
2828
missing_docs = "forbid"

fuzz/Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/ast.rs

Lines changed: 112 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@ use std::path::{Path, PathBuf};
66

77
use sqlparser::{
88
ast::Statement,
9-
dialect::GenericDialect,
109
parser::{Parser, ParserError},
1110
};
1211

13-
use crate::source::SqlSource;
12+
use crate::{dialects::Dialects, source::SqlSource};
1413

1514
/// A single SQL file plus all [`Statement`].
1615
#[derive(Debug)]
@@ -30,9 +29,8 @@ impl ParsedSqlFile {
3029
///
3130
/// # Errors
3231
/// - Returns [`ParserError`] if parsing fails
33-
pub fn parse(file: SqlSource) -> Result<Self, ParserError> {
34-
let dialect = GenericDialect {};
35-
let statements = Parser::parse_sql(&dialect, file.content())?;
32+
pub fn parse(file: SqlSource, dialect: &Dialects) -> Result<Self, ParserError> {
33+
let statements = Parser::parse_sql(dialect.dialect(), file.content())?;
3634
Ok(Self { file, statements })
3735
}
3836

@@ -82,7 +80,10 @@ impl ParsedSqlFileSet {
8280
/// # Errors
8381
/// - [`ParserError`] is returned for any errors parsing
8482
pub fn parse_all(set: Vec<SqlSource>) -> Result<Self, ParserError> {
85-
let files = set.into_iter().map(ParsedSqlFile::parse).collect::<Result<Vec<_>, _>>()?;
83+
let files = set
84+
.into_iter()
85+
.map(|s| ParsedSqlFile::parse(s, &Dialects::default()))
86+
.collect::<Result<Vec<_>, _>>()?;
8687

8788
Ok(Self { files })
8889
}
@@ -110,7 +111,7 @@ mod tests {
110111
let sql = "CREATE TABLE users (id INTEGER PRIMARY KEY);";
111112
fs::write(&file_path, sql)?;
112113
let sql_file = SqlSource::from_path(&file_path)?;
113-
let parsed = ParsedSqlFile::parse(sql_file)?;
114+
let parsed = ParsedSqlFile::parse(sql_file, &Dialects::Generic)?;
114115
assert_eq!(parsed.path(), Some(file_path.as_path()));
115116
assert_eq!(parsed.content(), sql);
116117
assert_eq!(parsed.statements().len(), 1);
@@ -147,4 +148,108 @@ mod tests {
147148
let _ = fs::remove_dir_all(&base);
148149
Ok(())
149150
}
151+
152+
#[test]
153+
fn parsed_sql_file_path_into_path_buf_round_trips() -> Result<(), Box<dyn std::error::Error>> {
154+
let base = env::temp_dir().join("parsed_sql_file_path_into_path_buf_round_trips");
155+
let _ = fs::remove_dir_all(&base);
156+
fs::create_dir_all(&base)?;
157+
let file_path = base.join("one.sql");
158+
let sql = "CREATE TABLE t (id INTEGER PRIMARY KEY);";
159+
fs::write(&file_path, sql)?;
160+
let sql_file = SqlSource::from_path(&file_path)?;
161+
let parsed = ParsedSqlFile::parse(sql_file, &Dialects::Generic)?;
162+
assert_eq!(parsed.path_into_path_buf(), Some(file_path));
163+
let _ = fs::remove_dir_all(&base);
164+
Ok(())
165+
}
166+
167+
#[test]
168+
fn parsed_sql_file_parse_postgres_handles_pg_function_syntax() -> Result<(), Box<dyn std::error::Error>> {
169+
let sql = r"
170+
CREATE OR REPLACE FUNCTION f()
171+
RETURNS SMALLINT
172+
LANGUAGE plpgsql
173+
SECURITY DEFINER
174+
STABLE
175+
AS $$
176+
BEGIN
177+
RETURN 4;
178+
END;
179+
$$;
180+
181+
CREATE TABLE t (id INTEGER PRIMARY KEY);
182+
";
183+
184+
let src = SqlSource::from_str(sql.to_owned(), None);
185+
let parsed = ParsedSqlFile::parse(src, &Dialects::PostgreSql)?;
186+
assert!(
187+
parsed.statements().len() >= 2,
188+
"expected at least 2 statements (function + table)"
189+
);
190+
assert!(
191+
parsed
192+
.statements()
193+
.iter()
194+
.any(|s| matches!(s, Statement::CreateTable { .. })),
195+
"expected at least one CreateTable statement"
196+
);
197+
Ok(())
198+
}
199+
200+
#[test]
201+
fn parsed_sql_file_set_parse_all_uses_default_dialect() -> Result<(), Box<dyn std::error::Error>> {
202+
let base = env::temp_dir().join("parsed_sql_file_set_parse_all_default_dialect");
203+
let _ = fs::remove_dir_all(&base);
204+
fs::create_dir_all(&base)?;
205+
206+
let file1 = base.join("one.sql");
207+
let file2 = base.join("two.sql");
208+
209+
fs::write(&file1, "CREATE TABLE t1 (id INTEGER PRIMARY KEY);")?;
210+
211+
let pg_sql = r"
212+
CREATE OR REPLACE FUNCTION f()
213+
RETURNS SMALLINT
214+
LANGUAGE plpgsql
215+
SECURITY DEFINER
216+
STABLE
217+
AS $$
218+
BEGIN
219+
RETURN 4;
220+
END;
221+
$$;
222+
223+
CREATE TABLE t2 (id INTEGER PRIMARY KEY);
224+
";
225+
fs::write(&file2, pg_sql)?;
226+
227+
let set = SqlSource::sql_sources(&base, &[])?;
228+
let parsed_set = ParsedSqlFileSet::parse_all(set)?;
229+
230+
assert_eq!(parsed_set.files().len(), 2);
231+
232+
for parsed in parsed_set.files() {
233+
assert!(
234+
parsed
235+
.statements()
236+
.iter()
237+
.any(|s| matches!(s, Statement::CreateTable { .. })),
238+
"expected CreateTable in parsed file; got statements: {:?}",
239+
parsed.statements()
240+
);
241+
}
242+
243+
let _ = fs::remove_dir_all(&base);
244+
Ok(())
245+
}
246+
247+
#[test]
248+
fn parsed_sql_file_parse_invalid_sql_returns_error() {
249+
let sql = "CREATE TABLE";
250+
let src = SqlSource::from_str(sql.to_owned(), None);
251+
let res = ParsedSqlFile::parse(src, &Dialects::Generic);
252+
assert!(res.is_err(), "expected parse to fail for invalid SQL");
253+
}
254+
150255
}

src/dialect.rs

Lines changed: 0 additions & 75 deletions
This file was deleted.

src/dialects.rs

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
//! Handles the selection of the `SQL` dialect to use for parsing
2+
3+
use sqlparser::dialect::{
4+
AnsiDialect, BigQueryDialect, ClickHouseDialect, DatabricksDialect, Dialect, DuckDbDialect,
5+
GenericDialect, HiveDialect, MsSqlDialect, MySqlDialect, OracleDialect, PostgreSqlDialect,
6+
RedshiftSqlDialect, SQLiteDialect, SnowflakeDialect,
7+
};
8+
9+
/// Dialects supported by this crate.
10+
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
11+
pub enum Dialects {
12+
/// ANSI SQL dialect
13+
Ansi,
14+
/// Google `BigQuery` SQL dialect
15+
BigQuery,
16+
/// `ClickHouse` SQL dialect
17+
ClickHouse,
18+
/// Databricks SQL dialect
19+
Databricks,
20+
/// `DuckDB` SQL dialect
21+
DuckDb,
22+
/// Generic SQL dialect
23+
Generic,
24+
/// Apache Hive SQL dialect
25+
Hive,
26+
/// Microsoft SQL Server (T-SQL) dialect
27+
MsSql,
28+
/// `MySQL` SQL dialect
29+
MySql,
30+
/// Oracle SQL dialect
31+
Oracle,
32+
/// `PostgreSQL` SQL dialect
33+
#[default]
34+
PostgreSql,
35+
/// Amazon Redshift SQL dialect
36+
RedshiftSql,
37+
/// `SQLite` SQL dialect
38+
SQLite,
39+
/// Snowflake SQL dialect
40+
Snowflake,
41+
}
42+
43+
impl Dialects {
44+
/// Returns the dialect struct associated with the enum
45+
#[must_use]
46+
pub fn dialect(&self) -> &'static dyn Dialect {
47+
match self {
48+
Self::Ansi => &AnsiDialect {},
49+
Self::BigQuery => &BigQueryDialect {},
50+
Self::ClickHouse => &ClickHouseDialect {},
51+
Self::Databricks => &DatabricksDialect {},
52+
Self::DuckDb => &DuckDbDialect {},
53+
Self::Generic => &GenericDialect {},
54+
Self::Hive => &HiveDialect {},
55+
Self::MsSql => &MsSqlDialect {},
56+
Self::MySql => &MySqlDialect {},
57+
Self::Oracle => &OracleDialect {},
58+
Self::PostgreSql => &PostgreSqlDialect {},
59+
Self::RedshiftSql => &RedshiftSqlDialect {},
60+
Self::SQLite => &SQLiteDialect {},
61+
Self::Snowflake => &SnowflakeDialect {},
62+
}
63+
}
64+
}
65+
66+
#[cfg(test)]
67+
mod tests {
68+
use super::*;
69+
use std::any::Any;
70+
71+
#[test]
72+
fn default_is_postgres() {
73+
assert_eq!(Dialects::default(), Dialects::PostgreSql);
74+
}
75+
76+
fn assert_is<T: 'static>(d: Dialects) {
77+
let any = d.dialect() as &dyn Any;
78+
assert!(
79+
any.is::<T>(),
80+
"expected {:?} to map to {}, got a different dialect type",
81+
d,
82+
std::any::type_name::<T>(),
83+
);
84+
}
85+
86+
#[test]
87+
fn dialect_maps_to_correct_concrete_type() {
88+
assert_is::<AnsiDialect>(Dialects::Ansi);
89+
assert_is::<BigQueryDialect>(Dialects::BigQuery);
90+
assert_is::<ClickHouseDialect>(Dialects::ClickHouse);
91+
assert_is::<DatabricksDialect>(Dialects::Databricks);
92+
assert_is::<DuckDbDialect>(Dialects::DuckDb);
93+
assert_is::<GenericDialect>(Dialects::Generic);
94+
assert_is::<HiveDialect>(Dialects::Hive);
95+
assert_is::<MsSqlDialect>(Dialects::MsSql);
96+
assert_is::<MySqlDialect>(Dialects::MySql);
97+
assert_is::<OracleDialect>(Dialects::Oracle);
98+
assert_is::<PostgreSqlDialect>(Dialects::PostgreSql);
99+
assert_is::<RedshiftSqlDialect>(Dialects::RedshiftSql);
100+
assert_is::<SQLiteDialect>(Dialects::SQLite);
101+
assert_is::<SnowflakeDialect>(Dialects::Snowflake);
102+
}
103+
104+
#[test]
105+
fn dialect_reference_is_stable_for_same_variant() {
106+
let d = Dialects::PostgreSql;
107+
let p1 = std::ptr::from_ref::<dyn Dialect>(d.dialect());
108+
let p2 = std::ptr::from_ref::<dyn Dialect>(d.dialect());
109+
assert_eq!(p1, p2);
110+
}
111+
}

0 commit comments

Comments
 (0)