166 lines
3.6 KiB
Go
166 lines
3.6 KiB
Go
package golang
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"math/rand"
|
|
"net"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"slices"
|
|
"syscall"
|
|
"time"
|
|
)
|
|
|
|
type ParentIPC struct {
|
|
*ipcCommon
|
|
cmd *exec.Cmd
|
|
listener net.Listener
|
|
cmdDone chan struct{}
|
|
cmdErr error
|
|
}
|
|
|
|
func NewParent(cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
|
|
return NewParentWithContext(context.Background(), cmd, localApis...)
|
|
}
|
|
|
|
func NewParentWithContext(ctx context.Context, cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
|
|
p := ParentIPC{
|
|
ipcCommon: &ipcCommon{
|
|
localApis: mapTypeNames(localApis),
|
|
pendingCalls: make(map[int64]*pendingCall),
|
|
errCh: make(chan error, 1),
|
|
socketPath: filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d-%d.sock", os.Getpid(), rand.Int63())),
|
|
ctx: ctx,
|
|
},
|
|
cmd: cmd,
|
|
}
|
|
|
|
cmd.Stdout = os.Stdout
|
|
cmd.Stderr = os.Stderr
|
|
|
|
if slices.Contains(cmd.Args, ipcSocketArg) {
|
|
return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg)
|
|
}
|
|
cmd.Args = append(cmd.Args, ipcSocketArg, p.socketPath)
|
|
|
|
p.errCh = make(chan error, 1)
|
|
p.cmdDone = make(chan struct{})
|
|
|
|
return &p, nil
|
|
}
|
|
|
|
func (p *ParentIPC) Start() error {
|
|
_ = os.Remove(p.socketPath)
|
|
listener, err := net.Listen("unix", p.socketPath)
|
|
if err != nil {
|
|
return fmt.Errorf("listen unix socket: %w", err)
|
|
}
|
|
p.listener = listener
|
|
defer p.listener.Close()
|
|
|
|
err = p.cmd.Start()
|
|
if err != nil {
|
|
return fmt.Errorf("cmd start: %w", err)
|
|
}
|
|
|
|
go func() {
|
|
p.cmdErr = p.cmd.Wait()
|
|
close(p.cmdDone)
|
|
}()
|
|
|
|
return p.acceptConn()
|
|
}
|
|
|
|
// connResult carries the outcome of one listener Accept call from the
// accepting goroutine back to the goroutine waiting on it.
type connResult struct {
	conn net.Conn // connection returned by Accept
	err  error    // error returned by Accept, if any
}
|
|
|
|
func (p *ParentIPC) acceptConn() error {
|
|
res := make(chan connResult, 1)
|
|
go func() {
|
|
conn, err := p.listener.Accept()
|
|
res <- connResult{conn: conn, err: err}
|
|
close(res)
|
|
}()
|
|
|
|
select {
|
|
case <-time.After(time.Duration(defaultAcceptTimeout) * time.Second):
|
|
_ = p.cmd.Process.Kill()
|
|
return fmt.Errorf("accept timeout")
|
|
case <-p.cmdDone:
|
|
return fmt.Errorf("cmd exited before accepting connection: %w", p.cmdErr)
|
|
case r := <-res:
|
|
if r.err != nil {
|
|
_ = p.cmd.Process.Kill()
|
|
return fmt.Errorf("accept: %w", r.err)
|
|
}
|
|
p.conn = r.conn
|
|
go p.readConn()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *ParentIPC) Stop() error {
|
|
p.mu.Lock()
|
|
hasPending := len(p.pendingCalls) > 0
|
|
p.mu.Unlock()
|
|
if hasPending {
|
|
return fmt.Errorf("there are calls pending")
|
|
}
|
|
if p.processingCalls.Load() > 0 {
|
|
return fmt.Errorf("there are calls processing")
|
|
}
|
|
p.stopRequested.Store(true)
|
|
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
|
|
return fmt.Errorf("send SIGINT: %w", err)
|
|
}
|
|
return p.Wait()
|
|
}
|
|
|
|
func (p *ParentIPC) Wait(timeout ...time.Duration) (retErr error) {
|
|
const maxDuration = time.Duration(1<<63 - 1)
|
|
_timeout := maxDuration
|
|
if len(timeout) > 0 {
|
|
_timeout = timeout[0]
|
|
}
|
|
|
|
loop:
|
|
for {
|
|
select {
|
|
case err := <-p.errCh:
|
|
retErr = mergeErr(retErr, fmt.Errorf("ipc internal error: %w", err))
|
|
break loop
|
|
case <-p.cmdDone:
|
|
err := p.cmdErr
|
|
if err != nil {
|
|
var exitErr *exec.ExitError
|
|
if ok := errors.As(err, &exitErr); ok {
|
|
if !exitErr.Success() {
|
|
ws, ok := exitErr.Sys().(syscall.WaitStatus)
|
|
if !(ok && ws.Signaled() && ws.Signal() == syscall.SIGINT && p.stopRequested.Load()) {
|
|
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
|
|
}
|
|
}
|
|
} else {
|
|
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
|
|
}
|
|
}
|
|
break loop
|
|
case <-time.After(_timeout):
|
|
p.stopRequested.Store(true)
|
|
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
|
|
retErr = mergeErr(retErr, fmt.Errorf("send SIGINT: %w", err))
|
|
}
|
|
}
|
|
}
|
|
|
|
p.closeConn()
|
|
_ = os.Remove(p.socketPath)
|
|
|
|
return retErr
|
|
}
|