diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go index 853b7974479d..f6e148f9f3f6 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go @@ -376,7 +376,11 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic eg.Go(func() error { s := stages[rb.StageID] wk := wks[s.envID] - if err := s.Execute(ctx, j, wk, comps, em, rb); err != nil { + // Pass egctx instead of the parent ctx so that when any bundle fails, + // the errgroup cancels egctx and all other concurrent bundle execution + // goroutines immediately detect cancellation and abort. This prevents + // eg.Wait() from blocking indefinitely and allows prompt error reporting. + if err := s.Execute(egctx, j, wk, comps, em, rb); err != nil { // Ensure we clean up on bundle failure j.Logger.Error("Bundle Failed.", slog.Any("error", err)) em.FailBundle(rb) diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go index 29fccaeb238e..2bb73f20e200 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go @@ -519,6 +519,28 @@ func TestFailure(t *testing.T) { } } +func TestFailureHang(t *testing.T) { + initRunner(t) + + p, s := beam.NewPipelineWithRoot() + imp := beam.Impulse(s) + col1 := beam.ParDo(s, doFnBlock, imp) + col2 := beam.ParDo(s, doFnFail, imp) + beam.ParDo(s, &int64Check{Name: "block", Want: []int{}}, col1) + beam.ParDo(s, &int64Check{Name: "fail", Want: []int{}}, col2) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, err := executeWithT(ctx, t, p) + if err == nil { + t.Fatalf("expected pipeline failure, but got a success") + } + if want := "doFnFail: failing as intended"; !strings.Contains(err.Error(), want) { + t.Fatalf("expected pipeline failure with %q, but was %v", want, err) + } +} + func TestRunner_Passert(t *testing.T) { initRunner(t) tests := []struct { diff --git a/sdks/go/pkg/beam/runners/prism/internal/testdofns_test.go b/sdks/go/pkg/beam/runners/prism/internal/testdofns_test.go index 334d74fcae1d..d21ccd53afd0 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/testdofns_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/testdofns_test.go @@ -59,6 +59,7 @@ func init() { register.Function3x0(dofn1Counter) register.Function2x0(dofnSink) register.Function3x1(doFnFail) + register.Function3x0(doFnBlock) register.Function2x1(combineIntSum) @@ -283,6 +284,10 @@ func doFnFail(ctx context.Context, _ []byte, emit func(int64)) error { return fmt.Errorf("doFnFail: failing as intended") } +func doFnBlock(ctx context.Context, _ []byte, emit func(int64)) { + <-ctx.Done() +} + func combineIntSum(a, b int64) int64 { return a + b }