mirror of https://github.com/jlelse/GoBlog
Rework TTS algorithm
This commit is contained in:
parent
720fc62919
commit
c8229ab28d
58
tts.go
58
tts.go
|
@ -13,11 +13,11 @@ import (
|
|||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/carlmjohnson/requests"
|
||||
"go.goblog.app/app/pkgs/bufferpool"
|
||||
"go.goblog.app/app/pkgs/mp3merge"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const ttsParameter = "tts"
|
||||
|
@ -72,52 +72,41 @@ func (a *goBlog) createPostTTSAudio(p *post) error {
|
|||
parts = append(parts, strings.Split(htmlText(a.postHtml(&postHtmlOptions{p: p})), "\n\n")...)
|
||||
|
||||
// Create TTS audio for each part
|
||||
partWriters := make([]io.Writer, len(parts))
|
||||
partReaders := make([]io.Reader, len(parts))
|
||||
for i := range parts {
|
||||
buf := bufferpool.Get()
|
||||
defer bufferpool.Put(buf)
|
||||
partWriters[i] = buf
|
||||
partReaders[i] = buf
|
||||
}
|
||||
errs := make([]error, len(parts))
|
||||
var wg sync.WaitGroup
|
||||
for i, part := range parts {
|
||||
// Increase wait group
|
||||
wg.Add(1)
|
||||
go func(i int, part string) {
|
||||
defer wg.Done()
|
||||
partReaders := []io.Reader{}
|
||||
var g errgroup.Group
|
||||
for _, part := range parts {
|
||||
part := part
|
||||
pr, pw := io.Pipe()
|
||||
defer func() {
|
||||
pw.Close()
|
||||
}()
|
||||
partReaders = append(partReaders, pr)
|
||||
g.Go(func() error {
|
||||
// Build SSML
|
||||
ssml := "<speak>" + html.EscapeString(part) + "<break time=\"500ms\"/></speak>"
|
||||
// Create TTS audio
|
||||
err := a.createTTSAudio(lang, ssml, partWriters[i])
|
||||
if err != nil {
|
||||
errs[i] = err
|
||||
return
|
||||
}
|
||||
}(i, part)
|
||||
err := a.createTTSAudio(lang, ssml, pw)
|
||||
_ = pw.CloseWithError(err)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// Wait for all parts to be created
|
||||
wg.Wait()
|
||||
|
||||
// Check if any errors occurred
|
||||
for _, err := range errs {
|
||||
// Merge parts together (needs buffer because the hash is needed before the file can be uploaded)
|
||||
buf := bufferpool.Get()
|
||||
defer bufferpool.Put(buf)
|
||||
hash := sha256.New()
|
||||
err := mp3merge.MergeMP3(io.MultiWriter(buf, hash), partReaders...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Merge partsBuffers into final buffer
|
||||
final := bufferpool.Get()
|
||||
defer bufferpool.Put(final)
|
||||
hash := sha256.New()
|
||||
if err := mp3merge.MergeMP3(io.MultiWriter(final, hash), partReaders...); err != nil {
|
||||
// Check if other errors appeared
|
||||
if err = g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save audio
|
||||
loc, err := a.saveMediaFile(fmt.Sprintf("%x.mp3", hash.Sum(nil)), final)
|
||||
loc, err := a.saveMediaFile(fmt.Sprintf("%x.mp3", hash.Sum(nil)), buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -125,6 +114,7 @@ func (a *goBlog) createPostTTSAudio(p *post) error {
|
|||
return errors.New("no media location for tts audio")
|
||||
}
|
||||
|
||||
// Check existing tts parameter
|
||||
if old := p.firstParameter(ttsParameter); old != "" && old != loc {
|
||||
// Already has tts audio, but with different location
|
||||
// Try to delete the old audio file
|
||||
|
|
Loading…
Reference in New Issue