parent/child mode
This commit is contained in:
parent
ebb37ce5b8
commit
2b67cd1d01
@ -21,138 +21,16 @@ const ipcSocketArg = "--ipc-socket"
|
|||||||
|
|
||||||
type StdioMode int
|
type StdioMode int
|
||||||
|
|
||||||
type ipcMode int
|
type ipcCommon struct {
|
||||||
|
|
||||||
const (
|
|
||||||
modeParent ipcMode = 1
|
|
||||||
modeChild ipcMode = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
type KittenIPC struct {
|
|
||||||
mode ipcMode
|
|
||||||
cmd *exec.Cmd
|
|
||||||
localApi any
|
localApi any
|
||||||
|
|
||||||
socketPath string
|
socketPath string
|
||||||
listener net.Listener
|
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
errCh chan error
|
errCh chan error
|
||||||
|
|
||||||
nextId int64
|
nextId int64
|
||||||
pendingCalls map[int64]chan callResult
|
pendingCalls map[int64]chan callResult
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewParent(cmd *exec.Cmd, localApi any) (*KittenIPC, error) {
|
|
||||||
k := KittenIPC{
|
|
||||||
mode: modeParent,
|
|
||||||
cmd: cmd,
|
|
||||||
localApi: localApi,
|
|
||||||
pendingCalls: make(map[int64]chan callResult),
|
|
||||||
errCh: make(chan error, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
k.socketPath = filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d.sock", os.Getpid()))
|
|
||||||
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
|
|
||||||
if slices.Contains(cmd.Args, ipcSocketArg) {
|
|
||||||
return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg)
|
|
||||||
}
|
|
||||||
cmd.Args = append(cmd.Args, ipcSocketArg, k.socketPath)
|
|
||||||
|
|
||||||
k.errCh = make(chan error, 1)
|
|
||||||
|
|
||||||
return &k, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewChild(localApi any) (*KittenIPC, error) {
|
|
||||||
k := KittenIPC{
|
|
||||||
mode: modeChild,
|
|
||||||
localApi: localApi,
|
|
||||||
pendingCalls: make(map[int64]chan callResult),
|
|
||||||
errCh: make(chan error, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
socketPath := flag.String("ipc-socket", "", "Path to IPC socket")
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
if *socketPath == "" {
|
|
||||||
return nil, fmt.Errorf("ipc socket path is missing")
|
|
||||||
}
|
|
||||||
k.socketPath = *socketPath
|
|
||||||
|
|
||||||
return &k, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *KittenIPC) Start() error {
|
|
||||||
if k.mode == modeParent {
|
|
||||||
return k.startParent()
|
|
||||||
}
|
|
||||||
return k.startChild()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *KittenIPC) startParent() error {
|
|
||||||
_ = os.Remove(k.socketPath)
|
|
||||||
listener, err := net.Listen("unix", k.socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("listen unix socket: %w", err)
|
|
||||||
}
|
|
||||||
k.listener = listener
|
|
||||||
defer k.closeSock()
|
|
||||||
|
|
||||||
err = k.cmd.Start()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cmd start: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := k.acceptConn(); err != nil {
|
|
||||||
return fmt.Errorf("accept connection: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *KittenIPC) startChild() error {
|
|
||||||
conn, err := net.Dial("unix", k.socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("connect to parent socket: %w", err)
|
|
||||||
}
|
|
||||||
k.conn = conn
|
|
||||||
k.startRcvData()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *KittenIPC) acceptConn() error {
|
|
||||||
const acceptTimeout = time.Second * 10
|
|
||||||
|
|
||||||
res := make(chan mo.Result[net.Conn], 1)
|
|
||||||
go func() {
|
|
||||||
conn, err := k.listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
res <- mo.Err[net.Conn](err)
|
|
||||||
} else {
|
|
||||||
res <- mo.Ok[net.Conn](conn)
|
|
||||||
}
|
|
||||||
close(res)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-time.After(acceptTimeout):
|
|
||||||
_ = k.cmd.Process.Kill()
|
|
||||||
return fmt.Errorf("accept timeout")
|
|
||||||
case res := <-res:
|
|
||||||
if res.IsError() {
|
|
||||||
_ = k.cmd.Process.Kill()
|
|
||||||
return fmt.Errorf("accept: %w", res.Error())
|
|
||||||
}
|
|
||||||
k.conn = res.MustGet()
|
|
||||||
k.startRcvData()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type MsgType int
|
type MsgType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -174,40 +52,40 @@ type callResult struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KittenIPC) startRcvData() {
|
func (ipc *ipcCommon) startRcvData() {
|
||||||
scn := bufio.NewScanner(k.conn)
|
scn := bufio.NewScanner(ipc.conn)
|
||||||
for scn.Scan() {
|
for scn.Scan() {
|
||||||
var msg Message
|
var msg Message
|
||||||
if err := json.Unmarshal(scn.Bytes(), &msg); err != nil {
|
if err := json.Unmarshal(scn.Bytes(), &msg); err != nil {
|
||||||
k.raiseErr(fmt.Errorf("unmarshal message: %w", err))
|
ipc.raiseErr(fmt.Errorf("unmarshal message: %w", err))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
k.processMsg(msg)
|
ipc.processMsg(msg)
|
||||||
}
|
}
|
||||||
if err := scn.Err(); err != nil {
|
if err := scn.Err(); err != nil {
|
||||||
k.raiseErr(err)
|
ipc.raiseErr(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KittenIPC) processMsg(msg Message) {
|
func (ipc *ipcCommon) processMsg(msg Message) {
|
||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
case MsgCall:
|
case MsgCall:
|
||||||
k.handleCall(msg)
|
ipc.handleCall(msg)
|
||||||
case MsgResponse:
|
case MsgResponse:
|
||||||
k.handleResponse(msg)
|
ipc.handleResponse(msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KittenIPC) handleCall(msg Message) {
|
func (ipc *ipcCommon) handleCall(msg Message) {
|
||||||
|
|
||||||
if k.localApi == nil {
|
if ipc.localApi == nil {
|
||||||
k.sendResponse(msg.Id, nil, fmt.Errorf("remote side does not accept ipc calls"))
|
ipc.sendResponse(msg.Id, nil, fmt.Errorf("remote side does not accept ipc calls"))
|
||||||
}
|
}
|
||||||
localApi := reflect.ValueOf(k.localApi)
|
localApi := reflect.ValueOf(ipc.localApi)
|
||||||
|
|
||||||
method := localApi.MethodByName(msg.Method)
|
method := localApi.MethodByName(msg.Method)
|
||||||
if !method.IsValid() {
|
if !method.IsValid() {
|
||||||
k.sendResponse(msg.Id, nil, fmt.Errorf("method not found: %s", msg.Method))
|
ipc.sendResponse(msg.Id, nil, fmt.Errorf("method not found: %s", msg.Method))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,7 +93,7 @@ func (k *KittenIPC) handleCall(msg Message) {
|
|||||||
argsCount := methodType.NumIn()
|
argsCount := methodType.NumIn()
|
||||||
|
|
||||||
if len(msg.Params) != argsCount {
|
if len(msg.Params) != argsCount {
|
||||||
k.sendResponse(msg.Id, nil, fmt.Errorf("argument count mismatch: expected %d, got %d", argsCount, len(msg.Params)))
|
ipc.sendResponse(msg.Id, nil, fmt.Errorf("argument count mismatch: expected %d, got %d", argsCount, len(msg.Params)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -233,20 +111,20 @@ func (k *KittenIPC) handleCall(msg Message) {
|
|||||||
res = append(res, resVal)
|
res = append(res, resVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
k.sendResponse(msg.Id, res, resErr.Interface().(error))
|
ipc.sendResponse(msg.Id, res, resErr.Interface().(error))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KittenIPC) handleResponse(msg Message) {
|
func (ipc *ipcCommon) handleResponse(msg Message) {
|
||||||
|
|
||||||
k.mu.Lock()
|
ipc.mu.Lock()
|
||||||
ch, ok := k.pendingCalls[msg.Id]
|
ch, ok := ipc.pendingCalls[msg.Id]
|
||||||
if ok {
|
if ok {
|
||||||
delete(k.pendingCalls, msg.Id)
|
delete(ipc.pendingCalls, msg.Id)
|
||||||
}
|
}
|
||||||
k.mu.Unlock()
|
ipc.mu.Unlock()
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
k.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id))
|
ipc.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -259,7 +137,7 @@ func (k *KittenIPC) handleResponse(msg Message) {
|
|||||||
close(ch)
|
close(ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KittenIPC) sendResponse(id int64, result []any, err error) {
|
func (ipc *ipcCommon) sendResponse(id int64, result []any, err error) {
|
||||||
msg := Message{
|
msg := Message{
|
||||||
Type: MsgResponse,
|
Type: MsgResponse,
|
||||||
Id: id,
|
Id: id,
|
||||||
@ -270,12 +148,12 @@ func (k *KittenIPC) sendResponse(id int64, result []any, err error) {
|
|||||||
msg.Error = err.Error()
|
msg.Error = err.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := k.sendMsg(msg); err != nil {
|
if err := ipc.sendMsg(msg); err != nil {
|
||||||
k.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err))
|
ipc.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KittenIPC) sendMsg(msg Message) error {
|
func (ipc *ipcCommon) sendMsg(msg Message) error {
|
||||||
|
|
||||||
data, err := json.Marshal(msg)
|
data, err := json.Marshal(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -284,20 +162,20 @@ func (k *KittenIPC) sendMsg(msg Message) error {
|
|||||||
|
|
||||||
data = append(data, '\n')
|
data = append(data, '\n')
|
||||||
|
|
||||||
if _, err := k.conn.Write(data); err != nil {
|
if _, err := ipc.conn.Write(data); err != nil {
|
||||||
return fmt.Errorf("write message: %w", err)
|
return fmt.Errorf("write message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KittenIPC) Call(method string, params ...any) ([]any, error) {
|
func (ipc *ipcCommon) Call(method string, params ...any) ([]any, error) {
|
||||||
k.mu.Lock()
|
ipc.mu.Lock()
|
||||||
id := k.nextId
|
id := ipc.nextId
|
||||||
k.nextId++
|
ipc.nextId++
|
||||||
resChan := make(chan callResult, 1)
|
resChan := make(chan callResult, 1)
|
||||||
k.pendingCalls[id] = resChan
|
ipc.pendingCalls[id] = resChan
|
||||||
k.mu.Unlock()
|
ipc.mu.Unlock()
|
||||||
|
|
||||||
msg := Message{
|
msg := Message{
|
||||||
Type: MsgCall,
|
Type: MsgCall,
|
||||||
@ -306,10 +184,10 @@ func (k *KittenIPC) Call(method string, params ...any) ([]any, error) {
|
|||||||
Params: params,
|
Params: params,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := k.sendMsg(msg); err != nil {
|
if err := ipc.sendMsg(msg); err != nil {
|
||||||
k.mu.Lock()
|
ipc.mu.Lock()
|
||||||
delete(k.pendingCalls, id)
|
delete(ipc.pendingCalls, id)
|
||||||
k.mu.Unlock()
|
ipc.mu.Unlock()
|
||||||
return nil, fmt.Errorf("send call: %w", err)
|
return nil, fmt.Errorf("send call: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -317,38 +195,111 @@ func (k *KittenIPC) Call(method string, params ...any) ([]any, error) {
|
|||||||
return result.result, result.err
|
return result.result, result.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KittenIPC) raiseErr(err error) {
|
func (ipc *ipcCommon) raiseErr(err error) {
|
||||||
select {
|
select {
|
||||||
case k.errCh <- err:
|
case ipc.errCh <- err:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KittenIPC) closeSock() error {
|
type ParentIPC struct {
|
||||||
if err := k.listener.Close(); err != nil {
|
*ipcCommon
|
||||||
|
cmd *exec.Cmd
|
||||||
|
listener net.Listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewParent(cmd *exec.Cmd, localApi any) (*ParentIPC, error) {
|
||||||
|
p := ParentIPC{
|
||||||
|
ipcCommon: &ipcCommon{
|
||||||
|
localApi: localApi,
|
||||||
|
pendingCalls: make(map[int64]chan callResult),
|
||||||
|
errCh: make(chan error, 1),
|
||||||
|
socketPath: filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d.sock", os.Getpid())),
|
||||||
|
},
|
||||||
|
cmd: cmd,
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
if slices.Contains(cmd.Args, ipcSocketArg) {
|
||||||
|
return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg)
|
||||||
|
}
|
||||||
|
cmd.Args = append(cmd.Args, ipcSocketArg, p.socketPath)
|
||||||
|
|
||||||
|
p.errCh = make(chan error, 1)
|
||||||
|
|
||||||
|
return &p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ParentIPC) Start() error {
|
||||||
|
_ = os.Remove(p.socketPath)
|
||||||
|
listener, err := net.Listen("unix", p.socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("listen unix socket: %w", err)
|
||||||
|
}
|
||||||
|
p.listener = listener
|
||||||
|
defer p.closeSock()
|
||||||
|
|
||||||
|
err = p.cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cmd start: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.acceptConn(); err != nil {
|
||||||
|
return fmt.Errorf("accept connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ParentIPC) acceptConn() error {
|
||||||
|
const acceptTimeout = time.Second * 10
|
||||||
|
|
||||||
|
res := make(chan mo.Result[net.Conn], 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := p.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
res <- mo.Err[net.Conn](err)
|
||||||
|
} else {
|
||||||
|
res <- mo.Ok[net.Conn](conn)
|
||||||
|
}
|
||||||
|
close(res)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(acceptTimeout):
|
||||||
|
_ = p.cmd.Process.Kill()
|
||||||
|
return fmt.Errorf("accept timeout")
|
||||||
|
case res := <-res:
|
||||||
|
if res.IsError() {
|
||||||
|
_ = p.cmd.Process.Kill()
|
||||||
|
return fmt.Errorf("accept: %w", res.Error())
|
||||||
|
}
|
||||||
|
p.conn = res.MustGet()
|
||||||
|
p.startRcvData()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ParentIPC) closeSock() error {
|
||||||
|
if err := p.listener.Close(); err != nil {
|
||||||
return fmt.Errorf("close socket listener: %w", err)
|
return fmt.Errorf("close socket listener: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KittenIPC) Wait() error {
|
func (p *ParentIPC) Wait() error {
|
||||||
if k.mode == modeParent {
|
|
||||||
return k.waitParent()
|
|
||||||
}
|
|
||||||
return k.waitChild()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *KittenIPC) waitParent() error {
|
|
||||||
waitErrCh := make(chan error, 1)
|
waitErrCh := make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
waitErrCh <- k.cmd.Wait()
|
waitErrCh <- p.cmd.Wait()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case err := <-k.errCh:
|
case err := <-p.errCh:
|
||||||
runtimeErr := fmt.Errorf("runtime error: %w", err)
|
runtimeErr := fmt.Errorf("runtime error: %w", err)
|
||||||
killErr := k.cmd.Process.Kill()
|
killErr := p.cmd.Process.Kill()
|
||||||
return mergeErr(runtimeErr, killErr)
|
return mergeErr(runtimeErr, killErr)
|
||||||
case err := <-waitErrCh:
|
case err := <-waitErrCh:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -359,8 +310,42 @@ func (k *KittenIPC) waitParent() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KittenIPC) waitChild() error {
|
type ChildIPC struct {
|
||||||
err := <-k.errCh
|
*ipcCommon
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChild(localApi any) (*ChildIPC, error) {
|
||||||
|
c := ChildIPC{
|
||||||
|
ipcCommon: &ipcCommon{
|
||||||
|
localApi: localApi,
|
||||||
|
pendingCalls: make(map[int64]chan callResult),
|
||||||
|
errCh: make(chan error, 1),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
socketPath := flag.String("ipc-socket", "", "Path to IPC socket")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if *socketPath == "" {
|
||||||
|
return nil, fmt.Errorf("ipc socket path is missing")
|
||||||
|
}
|
||||||
|
c.socketPath = *socketPath
|
||||||
|
|
||||||
|
return &c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChildIPC) Start() error {
|
||||||
|
conn, err := net.Dial("unix", c.socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connect to parent socket: %w", err)
|
||||||
|
}
|
||||||
|
c.conn = conn
|
||||||
|
c.startRcvData()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChildIPC) Wait() error {
|
||||||
|
err := <-c.errCh
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("ipc error: %w", err)
|
return fmt.Errorf("ipc error: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user