Skip to content
Merged
6 changes: 5 additions & 1 deletion sdks/go/pkg/beam/runners/prism/internal/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,11 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic
eg.Go(func() error {
s := stages[rb.StageID]
wk := wks[s.envID]
if err := s.Execute(ctx, j, wk, comps, em, rb); err != nil {
// Pass egctx instead of the parent ctx so that when any bundle fails,
// the errgroup cancels egctx and all other concurrent bundle execution
// goroutines immediately detect cancellation and abort. This prevents
// eg.Wait() from blocking indefinitely and allows prompt error reporting.
if err := s.Execute(egctx, j, wk, comps, em, rb); err != nil {
// Ensure we clean up on bundle failure
j.Logger.Error("Bundle Failed.", slog.Any("error", err))
em.FailBundle(rb)
Expand Down
22 changes: 22 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,28 @@ func TestFailure(t *testing.T) {
}
}

func TestFailureHang(t *testing.T) {
initRunner(t)

p, s := beam.NewPipelineWithRoot()
imp := beam.Impulse(s)
col1 := beam.ParDo(s, doFnBlock, imp)
col2 := beam.ParDo(s, doFnFail, imp)
beam.ParDo(s, &int64Check{Name: "block", Want: []int{}}, col1)
beam.ParDo(s, &int64Check{Name: "fail", Want: []int{}}, col2)

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

_, err := executeWithT(ctx, t, p)
if err == nil {
t.Fatalf("expected pipeline failure, but got a success")
}
if want := "doFnFail: failing as intended"; !strings.Contains(err.Error(), want) {
t.Fatalf("expected pipeline failure with %q, but was %v", want, err)
}
}

func TestRunner_Passert(t *testing.T) {
initRunner(t)
tests := []struct {
Expand Down
5 changes: 5 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/testdofns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func init() {
register.Function3x0(dofn1Counter)
register.Function2x0(dofnSink)
register.Function3x1(doFnFail)
register.Function3x0(doFnBlock)

register.Function2x1(combineIntSum)

Expand Down Expand Up @@ -283,6 +284,10 @@ func doFnFail(ctx context.Context, _ []byte, emit func(int64)) error {
return fmt.Errorf("doFnFail: failing as intended")
}

func doFnBlock(ctx context.Context, _ []byte, emit func(int64)) {
<-ctx.Done()
}

func combineIntSum(a, b int64) int64 {
return a + b
}
Expand Down
Loading