diff --git a/tts.go b/tts.go index 63f5ca4..02fca3a 100644 --- a/tts.go +++ b/tts.go @@ -13,11 +13,11 @@ import ( "net/url" "path" "strings" - "sync" "github.com/carlmjohnson/requests" "go.goblog.app/app/pkgs/bufferpool" "go.goblog.app/app/pkgs/mp3merge" + "golang.org/x/sync/errgroup" ) const ttsParameter = "tts" @@ -72,52 +72,41 @@ func (a *goBlog) createPostTTSAudio(p *post) error { parts = append(parts, strings.Split(htmlText(a.postHtml(&postHtmlOptions{p: p})), "\n\n")...) // Create TTS audio for each part - partWriters := make([]io.Writer, len(parts)) - partReaders := make([]io.Reader, len(parts)) - for i := range parts { - buf := bufferpool.Get() - defer bufferpool.Put(buf) - partWriters[i] = buf - partReaders[i] = buf - } - errs := make([]error, len(parts)) - var wg sync.WaitGroup - for i, part := range parts { - // Increase wait group - wg.Add(1) - go func(i int, part string) { - defer wg.Done() + partReaders := []io.Reader{} + var g errgroup.Group + for _, part := range parts { + part := part + pr, pw := io.Pipe() + defer func() { + pw.Close() + }() + partReaders = append(partReaders, pr) + g.Go(func() error { // Build SSML ssml := "" + html.EscapeString(part) + "" // Create TTS audio - err := a.createTTSAudio(lang, ssml, partWriters[i]) - if err != nil { - errs[i] = err - return - } - }(i, part) - } - - // Wait for all parts to be created - wg.Wait() - - // Check if any errors occurred - for _, err := range errs { - if err != nil { + err := a.createTTSAudio(lang, ssml, pw) + _ = pw.CloseWithError(err) return err - } + }) } - // Merge partsBuffers into final buffer - final := bufferpool.Get() - defer bufferpool.Put(final) + // Merge parts together (needs buffer because the hash is needed before the file can be uploaded) + buf := bufferpool.Get() + defer bufferpool.Put(buf) hash := sha256.New() - if err := mp3merge.MergeMP3(io.MultiWriter(final, hash), partReaders...); err != nil { + err := mp3merge.MergeMP3(io.MultiWriter(buf, hash), partReaders...) + if err != nil { + return err + } + + // Check if other errors appeared + if err = g.Wait(); err != nil { return err } // Save audio - loc, err := a.saveMediaFile(fmt.Sprintf("%x.mp3", hash.Sum(nil)), final) + loc, err := a.saveMediaFile(fmt.Sprintf("%x.mp3", hash.Sum(nil)), buf) if err != nil { return err } @@ -125,6 +114,7 @@ func (a *goBlog) createPostTTSAudio(p *post) error { return errors.New("no media location for tts audio") } + // Check existing tts parameter if old := p.firstParameter(ttsParameter); old != "" && old != loc { // Already has tts audio, but with different location // Try to delete the old audio file