Skip to content

Commit e8f2c69

Browse files
authored
Merge pull request #12 from poyrazK/feature/fix-joins-and-nulls
feat: Implement LEFT JOIN and fix NULL handling in PostgreSQL protocol
2 parents 0f756c3 + 0b8d82b commit e8f2c69

13 files changed

Lines changed: 636 additions & 61 deletions

File tree

include/executor/operator.hpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,28 +275,42 @@ class AggregateOperator : public Operator {
275275
* @brief Hash join operator
276276
*/
277277
class HashJoinOperator : public Operator {
278+
public:
279+
using JoinType = cloudsql::executor::JoinType;
280+
278281
private:
282+
struct BuildTuple {
283+
Tuple tuple;
284+
bool matched = false;
285+
};
286+
279287
std::unique_ptr<Operator> left_;
280288
std::unique_ptr<Operator> right_;
281289
std::unique_ptr<parser::Expression> left_key_;
282290
std::unique_ptr<parser::Expression> right_key_;
291+
JoinType join_type_;
283292
Schema schema_;
284293

285294
/* In-memory hash table for the right side */
286-
std::unordered_multimap<std::string, Tuple> hash_table_;
295+
std::unordered_multimap<std::string, BuildTuple> hash_table_;
287296

288297
/* Probe phase state */
289298
std::optional<Tuple> left_tuple_;
299+
bool left_had_match_ = false;
290300
struct MatchIterator {
291-
std::unordered_multimap<std::string, Tuple>::iterator current;
292-
std::unordered_multimap<std::string, Tuple>::iterator end;
301+
std::unordered_multimap<std::string, BuildTuple>::iterator current;
302+
std::unordered_multimap<std::string, BuildTuple>::iterator end;
293303
};
294304
std::optional<MatchIterator> match_iter_;
295305

306+
/* Final phase for RIGHT/FULL joins */
307+
std::optional<std::unordered_multimap<std::string, BuildTuple>::iterator> right_idx_iter_;
308+
296309
public:
297310
HashJoinOperator(std::unique_ptr<Operator> left, std::unique_ptr<Operator> right,
298311
std::unique_ptr<parser::Expression> left_key,
299-
std::unique_ptr<parser::Expression> right_key);
312+
std::unique_ptr<parser::Expression> right_key,
313+
JoinType join_type = JoinType::Inner);
300314

301315
bool init() override;
302316
bool open() override;

include/executor/types.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ namespace cloudsql::executor {
2525
*/
2626
enum class ExecState : uint8_t {
    Init = 0,
    Open = 1, /* set by an operator once open() has succeeded */
    Executing = 2,
    Done = 3, /* set by an operator once next() has nothing more to produce */
    Error = 4
};
2727

28+
/**
29+
* @brief Supported join types for relation merging.
30+
*/
31+
enum class JoinType : uint8_t {
    Inner = 0, /* only rows whose join keys match on both sides */
    Left = 1,  /* every left row; right columns are NULL when nothing matches */
    Right = 2, /* every right row; left columns are NULL when nothing matches */
    Full = 3   /* every row of both sides; the missing side is NULL-padded */
};
32+
2833
/**
2934
* @brief Supported aggregation functions for analytical queries.
3035
*/

src/executor/operator.cpp

Lines changed: 82 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -565,19 +565,29 @@ Schema& AggregateOperator::output_schema() {
565565

566566
/**
 * @brief Construct a hash join over two child operators.
 *
 * Takes ownership of both children and both key expressions, then derives the
 * output schema: all left columns followed by all right columns. Columns from
 * a side that the join type may pad with NULLs are marked nullable.
 *
 * @param left      Probe-side child; also supplies the transaction and lock manager.
 * @param right     Build-side child.
 * @param left_key  Key expression evaluated against left tuples.
 * @param right_key Key expression evaluated against right tuples.
 * @param join_type Inner, Left, Right or Full.
 */
HashJoinOperator::HashJoinOperator(std::unique_ptr<Operator> left, std::unique_ptr<Operator> right,
                                   std::unique_ptr<parser::Expression> left_key,
                                   std::unique_ptr<parser::Expression> right_key,
                                   executor::JoinType join_type)
    : Operator(OperatorType::HashJoin, left->get_txn(), left->get_lock_manager()),
      left_(std::move(left)),
      right_(std::move(right)),
      left_key_(std::move(left_key)),
      right_key_(std::move(right_key)),
      join_type_(join_type) {
    /* Without both children there is no schema to derive */
    if (!left_ || !right_) {
        return;
    }

    const bool is_full = join_type_ == executor::JoinType::Full;
    /* The left side can be NULL-padded by RIGHT/FULL joins, the right side by LEFT/FULL */
    const bool left_side_nullable = is_full || join_type_ == executor::JoinType::Right;
    const bool right_side_nullable = is_full || join_type_ == executor::JoinType::Left;

    /* Append every column of `child` to the output schema, optionally forcing nullability */
    const auto append_columns = [this](Operator& child, bool force_nullable) {
        for (const auto& source_col : child.output_schema().columns()) {
            auto out_col = source_col;
            if (force_nullable) {
                out_col.set_nullable(true);
            }
            schema_.add_column(out_col);
        }
    };

    append_columns(*left_, left_side_nullable);
    append_columns(*right_, right_side_nullable);
}
@@ -597,62 +607,107 @@ bool HashJoinOperator::open() {
597607
auto right_schema = right_->output_schema();
598608
while (right_->next(right_tuple)) {
599609
const common::Value key = right_key_->evaluate(&right_tuple, &right_schema);
600-
hash_table_.emplace(key.to_string(), std::move(right_tuple));
610+
hash_table_.emplace(key.to_string(), BuildTuple{std::move(right_tuple), false});
601611
}
602612

603613
left_tuple_ = std::nullopt;
604614
match_iter_ = std::nullopt;
615+
left_had_match_ = false;
616+
right_idx_iter_ = std::nullopt;
605617
set_state(ExecState::Open);
606618
return true;
607619
}
608620

609621
/**
 * @brief Produce the next joined tuple.
 *
 * Each call resumes from the state stored in the operator and works through
 * three phases:
 *   1. Emit the remaining build-side matches of the current left tuple.
 *   2. Pull the next left tuple and probe the hash table. For LEFT/FULL joins
 *      a left tuple without a match is emitted with NULLs for the right columns.
 *   3. Once the left child is exhausted, RIGHT/FULL joins emit every build
 *      tuple that never matched, with NULLs for the left columns.
 *
 * @param out_tuple Receives the joined tuple when the call returns true.
 * @return true if a tuple was produced; false once the join is exhausted
 *         (the operator state is then set to ExecState::Done).
 */
bool HashJoinOperator::next(Tuple& out_tuple) {
    /* Bind by reference: copying both schemas on every call is pure overhead */
    auto& left_schema = left_->output_schema();
    const auto& right_schema = right_->output_schema();

    const bool keep_unmatched_left =
        join_type_ == JoinType::Left || join_type_ == JoinType::Full;
    const bool keep_unmatched_right =
        join_type_ == JoinType::Right || join_type_ == JoinType::Full;

    /* Left values followed by one NULL per right-side column */
    const auto pad_with_right_nulls = [&right_schema](const Tuple& left) {
        std::vector<common::Value> values = left.values();
        values.insert(values.end(), right_schema.column_count(), common::Value::make_null());
        return Tuple(std::move(values));
    };

    while (true) {
        /* Phase 1: drain the pending matches of the current left tuple */
        if (match_iter_.has_value()) {
            auto& pending = match_iter_.value();
            if (pending.current != pending.end) {
                auto& build = pending.current->second;
                const auto& right_values = build.tuple.values();

                std::vector<common::Value> values = left_tuple_->values();
                values.insert(values.end(), right_values.begin(), right_values.end());
                out_tuple = Tuple(std::move(values));

                build.matched = true; /* remembered for the RIGHT/FULL drain phase */
                left_had_match_ = true;
                ++pending.current;
                return true;
            }

            /* Range exhausted. The probe below only installs non-empty ranges, so
             * left_had_match_ is normally true here; the check is kept as a guard. */
            match_iter_ = std::nullopt;
            if (keep_unmatched_left && !left_had_match_) {
                out_tuple = pad_with_right_nulls(*left_tuple_);
                left_tuple_ = std::nullopt;
                return true;
            }
            left_tuple_ = std::nullopt;
        }

        /* Phase 2: probe with the next left tuple. Skipped once the drain phase
         * has started, so an exhausted left child is not polled again. */
        if (!right_idx_iter_.has_value()) {
            Tuple next_left;
            if (left_->next(next_left)) {
                left_tuple_ = std::move(next_left);
                left_had_match_ = false;
                const common::Value key =
                    left_key_->evaluate(&(left_tuple_.value()), &left_schema);

                /* SQL equality never holds for NULL, so a NULL key matches nothing.
                 * NOTE(review): the build side must treat NULL keys the same way. */
                if (!key.is_null()) {
                    auto range = hash_table_.equal_range(key.to_string());
                    if (range.first != range.second) {
                        match_iter_ = MatchIterator{range.first, range.second};
                        continue; /* loop back to emit the first match */
                    }
                }

                if (keep_unmatched_left) {
                    out_tuple = pad_with_right_nulls(*left_tuple_);
                    left_tuple_ = std::nullopt;
                    return true;
                }

                /* Inner/Right join without a match: drop this left tuple */
                left_tuple_ = std::nullopt;
                continue;
            }
        }

        /* Phase 3: left side exhausted; emit build tuples that never matched */
        if (keep_unmatched_right) {
            if (!right_idx_iter_.has_value()) {
                right_idx_iter_ = hash_table_.begin();
            }

            auto& cursor = right_idx_iter_.value();
            for (; cursor != hash_table_.end(); ++cursor) {
                auto& build = cursor->second;
                if (build.matched) {
                    continue;
                }

                const auto& right_values = build.tuple.values();
                std::vector<common::Value> values(left_schema.column_count(),
                                                  common::Value::make_null());
                values.insert(values.end(), right_values.begin(), right_values.end());
                out_tuple = Tuple(std::move(values));

                build.matched = true; /* mark as emitted */
                ++cursor;             /* resume after this entry on the next call */
                return true;
            }
        }

        set_state(ExecState::Done);
        return false;
    }
}
658713

src/executor/query_executor.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -649,9 +649,18 @@ std::unique_ptr<Operator> QueryExecutor::build_plan(const parser::SelectStatemen
649649
}
650650

651651
if (use_hash_join) {
652-
current_root =
653-
std::make_unique<HashJoinOperator>(std::move(current_root), std::move(join_scan),
654-
std::move(left_key), std::move(right_key));
652+
executor::JoinType exec_join_type = executor::JoinType::Inner;
653+
if (join.type == parser::SelectStatement::JoinType::Left) {
654+
exec_join_type = executor::JoinType::Left;
655+
} else if (join.type == parser::SelectStatement::JoinType::Right) {
656+
exec_join_type = executor::JoinType::Right;
657+
} else if (join.type == parser::SelectStatement::JoinType::Full) {
658+
exec_join_type = executor::JoinType::Full;
659+
}
660+
661+
current_root = std::make_unique<HashJoinOperator>(
662+
std::move(current_root), std::move(join_scan), std::move(left_key),
663+
std::move(right_key), exec_join_type);
655664
} else {
656665
/* TODO: Implement NestedLoopJoin for non-equality or missing conditions */
657666
return nullptr;

src/network/server.cpp

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -414,25 +414,42 @@ void Server::handle_connection(int client_fd) {
414414
for (const auto& row : res.rows()) {
415415
const char d_type = 'D';
416416
uint32_t d_len = 4 + 2; // len + num_cols
417-
std::vector<std::string> str_vals;
417+
418+
struct ColValue {
419+
bool is_null;
420+
std::string val;
421+
};
422+
std::vector<ColValue> col_vals;
423+
col_vals.reserve(num_cols);
424+
418425
for (uint32_t i = 0; i < num_cols; ++i) {
419-
const std::string s_val = row.get(i).to_string();
420-
str_vals.push_back(s_val);
421-
d_len +=
422-
4 + static_cast<uint32_t>(s_val.size()); // len + value
426+
const auto& v = row.get(i);
427+
if (v.is_null()) {
428+
col_vals.push_back({true, ""});
429+
d_len += 4;
430+
} else {
431+
std::string s_val = v.to_string();
432+
d_len += 4 + static_cast<uint32_t>(s_val.size());
433+
col_vals.push_back({false, std::move(s_val)});
434+
}
423435
}
424436

425437
const uint32_t net_d_len = htonl(d_len);
426438
static_cast<void>(send(client_fd, &d_type, 1, 0));
427439
static_cast<void>(send(client_fd, &net_d_len, 4, 0));
428440
static_cast<void>(send(client_fd, &net_num_cols, 2, 0));
429441

430-
for (const auto& s_val : str_vals) {
431-
const uint32_t val_len =
432-
htonl(static_cast<uint32_t>(s_val.size()));
433-
static_cast<void>(send(client_fd, &val_len, 4, 0));
434-
static_cast<void>(
435-
send(client_fd, s_val.c_str(), s_val.size(), 0));
442+
for (const auto& cv : col_vals) {
443+
if (cv.is_null) {
444+
const uint32_t null_len = 0xFFFFFFFF;
445+
static_cast<void>(send(client_fd, &null_len, 4, 0));
446+
} else {
447+
const uint32_t val_len =
448+
htonl(static_cast<uint32_t>(cv.val.size()));
449+
static_cast<void>(send(client_fd, &val_len, 4, 0));
450+
static_cast<void>(
451+
send(client_fd, cv.val.c_str(), cv.val.size(), 0));
452+
}
436453
}
437454
}
438455
}

src/parser/expression.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,16 @@ common::Value ColumnExpr::evaluate(const executor::Tuple* tuple,
234234
return common::Value::make_null();
235235
}
236236

237-
const size_t index = schema->find_column(name_);
237+
size_t index = static_cast<size_t>(-1);
238+
239+
/* 1. Try exact match (either fully qualified or just name) */
240+
index = schema->find_column(this->to_string());
241+
242+
/* 2. If not found and it's qualified, try just the column name */
243+
if (index == static_cast<size_t>(-1) && has_table()) {
244+
index = schema->find_column(name_);
245+
}
246+
238247
if (index == static_cast<size_t>(-1)) {
239248
return common::Value::make_null();
240249
}
@@ -245,7 +254,16 @@ common::Value ColumnExpr::evaluate(const executor::Tuple* tuple,
245254
void ColumnExpr::evaluate_vectorized(const executor::VectorBatch& batch,
246255
const executor::Schema& schema,
247256
executor::ColumnVector& result) const {
248-
const size_t index = schema.find_column(name_);
257+
size_t index = static_cast<size_t>(-1);
258+
259+
/* 1. Try exact match (either fully qualified or just name) */
260+
index = schema.find_column(this->to_string());
261+
262+
/* 2. If not found and it's qualified, try just the column name */
263+
if (index == static_cast<size_t>(-1) && has_table()) {
264+
index = schema.find_column(name_);
265+
}
266+
249267
result.clear();
250268
if (index == static_cast<size_t>(-1)) {
251269
for (size_t i = 0; i < batch.row_count(); ++i) {

tests/cloudSQL_tests.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -784,8 +784,9 @@ TEST(ParserAdvanced, JoinAndComplexSelect) {
784784
/* 1. Left Join and multiple joins */
785785
{
786786
auto lexer = std::make_unique<Lexer>(
787-
"SELECT a.id, b.val FROM t1 LEFT JOIN t2 ON a.id = b.id JOIN t3 ON b.x = t3.x WHERE "
788-
"a.id > 10");
787+
"SELECT t1.id, t2.val FROM t1 LEFT JOIN t2 ON t1.id = t2.id JOIN t3 ON t2.x = t3.x "
788+
"WHERE "
789+
"t1.id > 10");
789790
Parser parser(std::move(lexer));
790791
auto stmt = parser.parse_statement();
791792
ASSERT_NE(stmt, nullptr);

tests/logic/aggregates.slt

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Aggregate and Group By Tests
2+
3+
statement ok
4+
CREATE TABLE agg_test (grp TEXT, val INT);
5+
6+
statement ok
7+
INSERT INTO agg_test VALUES ('A', 10), ('A', 20), ('B', 5), ('B', 15), ('C', 100);
8+
9+
# Basic Aggregates
10+
query IIII
11+
SELECT SUM(val), COUNT(val), MIN(val), MAX(val) FROM agg_test;
12+
----
13+
150 5 5 100
14+
15+
# Group By
16+
query TI
17+
SELECT grp, SUM(val) FROM agg_test GROUP BY grp ORDER BY grp;
18+
----
19+
A 30
20+
B 20
21+
C 100
22+
23+
# Group By with Filter
24+
query TI
25+
SELECT grp, COUNT(val) FROM agg_test WHERE val > 10 GROUP BY grp ORDER BY grp;
26+
----
27+
A 1
28+
B 1
29+
C 1
30+
31+
# Having Clause
32+
query TI
33+
SELECT grp, SUM(val) FROM agg_test GROUP BY grp HAVING SUM(val) > 25 ORDER BY grp;
34+
----
35+
A 30
36+
C 100
37+
38+
# Average (Real)
39+
query R
40+
SELECT AVG(val) FROM agg_test WHERE grp = 'A';
41+
----
42+
15.0
43+
44+
statement ok
45+
DROP TABLE agg_test;

0 commit comments

Comments
 (0)