Skip to content

Commit 1349f86

Browse files
committed
Fix conditionals in multi statement functions
- when a multi statement block concludes (eg, `BEGIN`..`END`), that last `END` means we should *not* be expecting a statement delimiter
1 parent 17b88f4 commit 1349f86

File tree

2 files changed

+93
-3
lines changed

2 files changed

+93
-3
lines changed

Diff for: src/parser/mod.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,14 @@ impl<'a> Parser<'a> {
484484
}
485485

486486
let statement = self.parse_statement()?;
487+
expecting_statement_delimiter = match &statement {
488+
Statement::If(s) => match s.if_block.conditional_statements {
489+
ConditionalStatements::BeginEnd { .. } => false,
490+
_ => true,
491+
},
492+
_ => true
493+
};
487494
stmts.push(statement);
488-
expecting_statement_delimiter = true;
489495
}
490496
Ok(stmts)
491497
}

Diff for: tests/sqlparser_mssql.rs

+86-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
mod test_utils;
2424

2525
use helpers::attached_token::AttachedToken;
26-
use sqlparser::tokenizer::{Location, Span};
26+
use sqlparser::keywords::Keyword;
27+
use sqlparser::tokenizer::{Location, Span, TokenWithSpan};
2728
use test_utils::*;
2829

2930
use sqlparser::ast::DataType::{Int, Text, Varbinary};
@@ -326,6 +327,89 @@ fn parse_create_function() {
326327
remote_connection: None,
327328
}),
328329
);
330+
331+
let create_function_with_conditional = r#"
332+
CREATE FUNCTION some_scalar_udf()
333+
RETURNS INT
334+
AS
335+
BEGIN
336+
IF 1=2
337+
BEGIN
338+
RETURN 1;
339+
END
340+
341+
RETURN 0;
342+
END
343+
"#;
344+
let create_stmt = ms().one_statement_parses_to(create_function_with_conditional, "");
345+
assert_eq!(
346+
create_stmt,
347+
Statement::CreateFunction(CreateFunction {
348+
or_alter: false,
349+
or_replace: false,
350+
temporary: false,
351+
if_not_exists: false,
352+
name: ObjectName::from(vec![Ident {
353+
value: "some_scalar_udf".into(),
354+
quote_style: None,
355+
span: Span::empty(),
356+
}]),
357+
args: Some(vec![]),
358+
return_type: Some(DataType::Int(None)),
359+
function_body: Some(CreateFunctionBody::MultiStatement(vec![
360+
Statement::If(IfStatement {
361+
if_block: ConditionalStatementBlock {
362+
start_token: AttachedToken(TokenWithSpan::wrap(
363+
sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word {
364+
value: "IF".to_string(),
365+
quote_style: None,
366+
keyword: Keyword::IF
367+
})
368+
)),
369+
condition: Some(Expr::BinaryOp {
370+
left: Box::new(Expr::Value(
371+
Value::Number("1".to_string(), false).with_empty_span()
372+
)),
373+
op: sqlparser::ast::BinaryOperator::Eq,
374+
right: Box::new(Expr::Value(Value::Number("2".to_string(), false).with_empty_span())),
375+
}),
376+
then_token: None,
377+
conditional_statements: ConditionalStatements::BeginEnd {
378+
begin_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word {
379+
value: "BEGIN".to_string(),
380+
quote_style: None,
381+
keyword: Keyword::BEGIN
382+
}))),
383+
statements: vec![Statement::Return(ReturnStatement {
384+
value: Some(ReturnStatementValue::Expr(Expr::Value((number("1")).with_empty_span()))),
385+
})],
386+
end_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word {
387+
value: "END".to_string(),
388+
quote_style: None,
389+
keyword: Keyword::END
390+
}))),
391+
},
392+
},
393+
elseif_blocks: vec![],
394+
else_block: None,
395+
end_token: None,
396+
}),
397+
Statement::Return(ReturnStatement {
398+
value: Some(ReturnStatementValue::Expr(Expr::Value(
399+
(number("0")).with_empty_span()
400+
))),
401+
}),
402+
])),
403+
behavior: None,
404+
called_on_null: None,
405+
parallel: None,
406+
using: None,
407+
language: None,
408+
determinism_specifier: None,
409+
options: None,
410+
remote_connection: None,
411+
})
412+
);
329413
}
330414

331415
#[test]
@@ -394,7 +478,7 @@ fn parse_mssql_create_function() {
394478
)),
395479
}],
396480
}),
397-
Statement::Return(ReturnStatement{
481+
Statement::Return(ReturnStatement {
398482
value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident {
399483
value: "@foo".into(),
400484
quote_style: None,

0 commit comments

Comments
 (0)