@@ -6,11 +6,10 @@ use std::path::{Path, PathBuf};
66
77use sqlparser:: {
88 ast:: Statement ,
9- dialect:: GenericDialect ,
109 parser:: { Parser , ParserError } ,
1110} ;
1211
13- use crate :: source:: SqlSource ;
12+ use crate :: { dialects :: Dialects , source:: SqlSource } ;
1413
1514/// A single SQL file plus all [`Statement`].
1615#[ derive( Debug ) ]
@@ -30,9 +29,8 @@ impl ParsedSqlFile {
3029 ///
3130 /// # Errors
3231 /// - Returns [`ParserError`] if parsing fails
33- pub fn parse ( file : SqlSource ) -> Result < Self , ParserError > {
34- let dialect = GenericDialect { } ;
35- let statements = Parser :: parse_sql ( & dialect, file. content ( ) ) ?;
32+ pub fn parse ( file : SqlSource , dialect : & Dialects ) -> Result < Self , ParserError > {
33+ let statements = Parser :: parse_sql ( dialect. dialect ( ) , file. content ( ) ) ?;
3634 Ok ( Self { file, statements } )
3735 }
3836
@@ -82,7 +80,10 @@ impl ParsedSqlFileSet {
8280 /// # Errors
8381 /// - [`ParserError`] is returned for any errors parsing
8482 pub fn parse_all ( set : Vec < SqlSource > ) -> Result < Self , ParserError > {
85- let files = set. into_iter ( ) . map ( ParsedSqlFile :: parse) . collect :: < Result < Vec < _ > , _ > > ( ) ?;
83+ let files = set
84+ . into_iter ( )
85+ . map ( |s| ParsedSqlFile :: parse ( s, & Dialects :: default ( ) ) )
86+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
8687
8788 Ok ( Self { files } )
8889 }
@@ -110,7 +111,7 @@ mod tests {
110111 let sql = "CREATE TABLE users (id INTEGER PRIMARY KEY);" ;
111112 fs:: write ( & file_path, sql) ?;
112113 let sql_file = SqlSource :: from_path ( & file_path) ?;
113- let parsed = ParsedSqlFile :: parse ( sql_file) ?;
114+ let parsed = ParsedSqlFile :: parse ( sql_file, & Dialects :: Generic ) ?;
114115 assert_eq ! ( parsed. path( ) , Some ( file_path. as_path( ) ) ) ;
115116 assert_eq ! ( parsed. content( ) , sql) ;
116117 assert_eq ! ( parsed. statements( ) . len( ) , 1 ) ;
@@ -147,4 +148,108 @@ mod tests {
147148 let _ = fs:: remove_dir_all ( & base) ;
148149 Ok ( ( ) )
149150 }
151+
152+ #[ test]
153+ fn parsed_sql_file_path_into_path_buf_round_trips ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
154+ let base = env:: temp_dir ( ) . join ( "parsed_sql_file_path_into_path_buf_round_trips" ) ;
155+ let _ = fs:: remove_dir_all ( & base) ;
156+ fs:: create_dir_all ( & base) ?;
157+ let file_path = base. join ( "one.sql" ) ;
158+ let sql = "CREATE TABLE t (id INTEGER PRIMARY KEY);" ;
159+ fs:: write ( & file_path, sql) ?;
160+ let sql_file = SqlSource :: from_path ( & file_path) ?;
161+ let parsed = ParsedSqlFile :: parse ( sql_file, & Dialects :: Generic ) ?;
162+ assert_eq ! ( parsed. path_into_path_buf( ) , Some ( file_path) ) ;
163+ let _ = fs:: remove_dir_all ( & base) ;
164+ Ok ( ( ) )
165+ }
166+
167+ #[ test]
168+ fn parsed_sql_file_parse_postgres_handles_pg_function_syntax ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
169+ let sql = r"
170+ CREATE OR REPLACE FUNCTION f()
171+ RETURNS SMALLINT
172+ LANGUAGE plpgsql
173+ SECURITY DEFINER
174+ STABLE
175+ AS $$
176+ BEGIN
177+ RETURN 4;
178+ END;
179+ $$;
180+
181+ CREATE TABLE t (id INTEGER PRIMARY KEY);
182+ " ;
183+
184+ let src = SqlSource :: from_str ( sql. to_owned ( ) , None ) ;
185+ let parsed = ParsedSqlFile :: parse ( src, & Dialects :: PostgreSql ) ?;
186+ assert ! (
187+ parsed. statements( ) . len( ) >= 2 ,
188+ "expected at least 2 statements (function + table)"
189+ ) ;
190+ assert ! (
191+ parsed
192+ . statements( )
193+ . iter( )
194+ . any( |s| matches!( s, Statement :: CreateTable { .. } ) ) ,
195+ "expected at least one CreateTable statement"
196+ ) ;
197+ Ok ( ( ) )
198+ }
199+
200+ #[ test]
201+ fn parsed_sql_file_set_parse_all_uses_default_dialect ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
202+ let base = env:: temp_dir ( ) . join ( "parsed_sql_file_set_parse_all_default_dialect" ) ;
203+ let _ = fs:: remove_dir_all ( & base) ;
204+ fs:: create_dir_all ( & base) ?;
205+
206+ let file1 = base. join ( "one.sql" ) ;
207+ let file2 = base. join ( "two.sql" ) ;
208+
209+ fs:: write ( & file1, "CREATE TABLE t1 (id INTEGER PRIMARY KEY);" ) ?;
210+
211+ let pg_sql = r"
212+ CREATE OR REPLACE FUNCTION f()
213+ RETURNS SMALLINT
214+ LANGUAGE plpgsql
215+ SECURITY DEFINER
216+ STABLE
217+ AS $$
218+ BEGIN
219+ RETURN 4;
220+ END;
221+ $$;
222+
223+ CREATE TABLE t2 (id INTEGER PRIMARY KEY);
224+ " ;
225+ fs:: write ( & file2, pg_sql) ?;
226+
227+ let set = SqlSource :: sql_sources ( & base, & [ ] ) ?;
228+ let parsed_set = ParsedSqlFileSet :: parse_all ( set) ?;
229+
230+ assert_eq ! ( parsed_set. files( ) . len( ) , 2 ) ;
231+
232+ for parsed in parsed_set. files ( ) {
233+ assert ! (
234+ parsed
235+ . statements( )
236+ . iter( )
237+ . any( |s| matches!( s, Statement :: CreateTable { .. } ) ) ,
238+ "expected CreateTable in parsed file; got statements: {:?}" ,
239+ parsed. statements( )
240+ ) ;
241+ }
242+
243+ let _ = fs:: remove_dir_all ( & base) ;
244+ Ok ( ( ) )
245+ }
246+
247+ #[ test]
248+ fn parsed_sql_file_parse_invalid_sql_returns_error ( ) {
249+ let sql = "CREATE TABLE" ;
250+ let src = SqlSource :: from_str ( sql. to_owned ( ) , None ) ;
251+ let res = ParsedSqlFile :: parse ( src, & Dialects :: Generic ) ;
252+ assert ! ( res. is_err( ) , "expected parse to fail for invalid SQL" ) ;
253+ }
254+
150255}
0 commit comments