diff --git a/lib/golang/lib.go b/lib/golang/lib.go index b29b4b5..8a448eb 100644 --- a/lib/golang/lib.go +++ b/lib/golang/lib.go @@ -21,138 +21,16 @@ const ipcSocketArg = "--ipc-socket" type StdioMode int -type ipcMode int - -const ( - modeParent ipcMode = 1 - modeChild ipcMode = 2 -) - -type KittenIPC struct { - mode ipcMode - cmd *exec.Cmd - localApi any - - socketPath string - listener net.Listener - conn net.Conn - errCh chan error - +type ipcCommon struct { + localApi any + socketPath string + conn net.Conn + errCh chan error nextId int64 pendingCalls map[int64]chan callResult mu sync.Mutex } -func NewParent(cmd *exec.Cmd, localApi any) (*KittenIPC, error) { - k := KittenIPC{ - mode: modeParent, - cmd: cmd, - localApi: localApi, - pendingCalls: make(map[int64]chan callResult), - errCh: make(chan error, 1), - } - - k.socketPath = filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d.sock", os.Getpid())) - - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - if slices.Contains(cmd.Args, ipcSocketArg) { - return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg) - } - cmd.Args = append(cmd.Args, ipcSocketArg, k.socketPath) - - k.errCh = make(chan error, 1) - - return &k, nil -} - -func NewChild(localApi any) (*KittenIPC, error) { - k := KittenIPC{ - mode: modeChild, - localApi: localApi, - pendingCalls: make(map[int64]chan callResult), - errCh: make(chan error, 1), - } - - socketPath := flag.String("ipc-socket", "", "Path to IPC socket") - flag.Parse() - - if *socketPath == "" { - return nil, fmt.Errorf("ipc socket path is missing") - } - k.socketPath = *socketPath - - return &k, nil -} - -func (k *KittenIPC) Start() error { - if k.mode == modeParent { - return k.startParent() - } - return k.startChild() -} - -func (k *KittenIPC) startParent() error { - _ = os.Remove(k.socketPath) - listener, err := net.Listen("unix", k.socketPath) - if err != nil { - return fmt.Errorf("listen unix socket: %w", err) - } - k.listener = listener - defer k.closeSock() - - err = k.cmd.Start() - if err != nil { - return fmt.Errorf("cmd start: %w", err) - } - - if err := k.acceptConn(); err != nil { - return fmt.Errorf("accept connection: %w", err) - } - - return nil -} - -func (k *KittenIPC) startChild() error { - conn, err := net.Dial("unix", k.socketPath) - if err != nil { - return fmt.Errorf("connect to parent socket: %w", err) - } - k.conn = conn - k.startRcvData() - return nil -} - -func (k *KittenIPC) acceptConn() error { - const acceptTimeout = time.Second * 10 - - res := make(chan mo.Result[net.Conn], 1) - go func() { - conn, err := k.listener.Accept() - if err != nil { - res <- mo.Err[net.Conn](err) - } else { - res <- mo.Ok[net.Conn](conn) - } - close(res) - }() - - select { - case <-time.After(acceptTimeout): - _ = k.cmd.Process.Kill() - return fmt.Errorf("accept timeout") - case res := <-res: - if res.IsError() { - _ = k.cmd.Process.Kill() - return fmt.Errorf("accept: %w", res.Error()) - } - k.conn = res.MustGet() - k.startRcvData() - } - return nil -} - type MsgType int const ( @@ -174,40 +52,40 @@ type callResult struct { err error } -func (k *KittenIPC) startRcvData() { - scn := bufio.NewScanner(k.conn) +func (ipc *ipcCommon) startRcvData() { + scn := bufio.NewScanner(ipc.conn) for scn.Scan() { var msg Message if err := json.Unmarshal(scn.Bytes(), &msg); err != nil { - k.raiseErr(fmt.Errorf("unmarshal message: %w", err)) + ipc.raiseErr(fmt.Errorf("unmarshal message: %w", err)) break } - k.processMsg(msg) + ipc.processMsg(msg) } if err := scn.Err(); err != nil { - k.raiseErr(err) + ipc.raiseErr(err) } } -func (k *KittenIPC) processMsg(msg Message) { +func (ipc *ipcCommon) processMsg(msg Message) { switch msg.Type { case MsgCall: - k.handleCall(msg) + ipc.handleCall(msg) case MsgResponse: - k.handleResponse(msg) + ipc.handleResponse(msg) } } -func (k *KittenIPC) handleCall(msg Message) { +func (ipc *ipcCommon) handleCall(msg Message) { - if k.localApi == nil { - k.sendResponse(msg.Id, nil, fmt.Errorf("remote side does not accept ipc calls")) + if ipc.localApi == nil { + ipc.sendResponse(msg.Id, nil, fmt.Errorf("remote side does not accept ipc calls")) } - localApi := reflect.ValueOf(k.localApi) + localApi := reflect.ValueOf(ipc.localApi) method := localApi.MethodByName(msg.Method) if !method.IsValid() { - k.sendResponse(msg.Id, nil, fmt.Errorf("method not found: %s", msg.Method)) + ipc.sendResponse(msg.Id, nil, fmt.Errorf("method not found: %s", msg.Method)) return } @@ -215,7 +93,7 @@ func (k *KittenIPC) handleCall(msg Message) { argsCount := methodType.NumIn() if len(msg.Params) != argsCount { - k.sendResponse(msg.Id, nil, fmt.Errorf("argument count mismatch: expected %d, got %d", argsCount, len(msg.Params))) + ipc.sendResponse(msg.Id, nil, fmt.Errorf("argument count mismatch: expected %d, got %d", argsCount, len(msg.Params))) return } @@ -233,20 +111,20 @@ func (k *KittenIPC) handleCall(msg Message) { res = append(res, resVal) } - k.sendResponse(msg.Id, res, resErr.Interface().(error)) + ipc.sendResponse(msg.Id, res, resErr.Interface().(error)) } -func (k *KittenIPC) handleResponse(msg Message) { +func (ipc *ipcCommon) handleResponse(msg Message) { - k.mu.Lock() - ch, ok := k.pendingCalls[msg.Id] + ipc.mu.Lock() + ch, ok := ipc.pendingCalls[msg.Id] if ok { - delete(k.pendingCalls, msg.Id) + delete(ipc.pendingCalls, msg.Id) } - k.mu.Unlock() + ipc.mu.Unlock() if !ok { - k.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id)) + ipc.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id)) return } @@ -259,7 +137,7 @@ func (k *KittenIPC) handleResponse(msg Message) { close(ch) } -func (k *KittenIPC) sendResponse(id int64, result []any, err error) { +func (ipc *ipcCommon) sendResponse(id int64, result []any, err error) { msg := Message{ Type: MsgResponse, Id: id, @@ -270,12 +148,12 @@ func (k *KittenIPC) sendResponse(id int64, result []any, err error) { msg.Error = err.Error() } - if err := k.sendMsg(msg); err != nil { - k.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err)) + if err := ipc.sendMsg(msg); err != nil { + ipc.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err)) } } -func (k *KittenIPC) sendMsg(msg Message) error { +func (ipc *ipcCommon) sendMsg(msg Message) error { data, err := json.Marshal(msg) if err != nil { @@ -284,20 +162,20 @@ func (k *KittenIPC) sendMsg(msg Message) error { data = append(data, '\n') - if _, err := k.conn.Write(data); err != nil { + if _, err := ipc.conn.Write(data); err != nil { return fmt.Errorf("write message: %w", err) } return nil } -func (k *KittenIPC) Call(method string, params ...any) ([]any, error) { - k.mu.Lock() - id := k.nextId - k.nextId++ +func (ipc *ipcCommon) Call(method string, params ...any) ([]any, error) { + ipc.mu.Lock() + id := ipc.nextId + ipc.nextId++ resChan := make(chan callResult, 1) - k.pendingCalls[id] = resChan - k.mu.Unlock() + ipc.pendingCalls[id] = resChan + ipc.mu.Unlock() msg := Message{ Type: MsgCall, @@ -306,10 +184,10 @@ func (k *KittenIPC) Call(method string, params ...any) ([]any, error) { Params: params, } - if err := k.sendMsg(msg); err != nil { - k.mu.Lock() - delete(k.pendingCalls, id) - k.mu.Unlock() + if err := ipc.sendMsg(msg); err != nil { + ipc.mu.Lock() + delete(ipc.pendingCalls, id) + ipc.mu.Unlock() return nil, fmt.Errorf("send call: %w", err) } @@ -317,38 +195,111 @@ func (k *KittenIPC) Call(method string, params ...any) ([]any, error) { return result.result, result.err } -func (k *KittenIPC) raiseErr(err error) { +func (ipc *ipcCommon) raiseErr(err error) { select { - case k.errCh <- err: + case ipc.errCh <- err: default: } } -func (k *KittenIPC) closeSock() error { - if err := k.listener.Close(); err != nil { +type ParentIPC struct { + *ipcCommon + cmd *exec.Cmd + listener net.Listener +} + +func NewParent(cmd *exec.Cmd, localApi any) (*ParentIPC, error) { + p := ParentIPC{ + ipcCommon: &ipcCommon{ + localApi: localApi, + pendingCalls: make(map[int64]chan callResult), + errCh: make(chan error, 1), + socketPath: filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d.sock", os.Getpid())), + }, + cmd: cmd, + } + + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if slices.Contains(cmd.Args, ipcSocketArg) { + return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg) + } + cmd.Args = append(cmd.Args, ipcSocketArg, p.socketPath) + + p.errCh = make(chan error, 1) + + return &p, nil +} + +func (p *ParentIPC) Start() error { + _ = os.Remove(p.socketPath) + listener, err := net.Listen("unix", p.socketPath) + if err != nil { + return fmt.Errorf("listen unix socket: %w", err) + } + p.listener = listener + defer p.closeSock() + + err = p.cmd.Start() + if err != nil { + return fmt.Errorf("cmd start: %w", err) + } + + if err := p.acceptConn(); err != nil { + return fmt.Errorf("accept connection: %w", err) + } + + return nil +} + +func (p *ParentIPC) acceptConn() error { + const acceptTimeout = time.Second * 10 + + res := make(chan mo.Result[net.Conn], 1) + go func() { + conn, err := p.listener.Accept() + if err != nil { + res <- mo.Err[net.Conn](err) + } else { + res <- mo.Ok[net.Conn](conn) + } + close(res) + }() + + select { + case <-time.After(acceptTimeout): + _ = p.cmd.Process.Kill() + return fmt.Errorf("accept timeout") + case res := <-res: + if res.IsError() { + _ = p.cmd.Process.Kill() + return fmt.Errorf("accept: %w", res.Error()) + } + p.conn = res.MustGet() + p.startRcvData() + } + return nil +} + +func (p *ParentIPC) closeSock() error { + if err := p.listener.Close(); err != nil { return fmt.Errorf("close socket listener: %w", err) } return nil } -func (k *KittenIPC) Wait() error { - if k.mode == modeParent { - return k.waitParent() - } - return k.waitChild() -} - -func (k *KittenIPC) waitParent() error { +func (p *ParentIPC) Wait() error { waitErrCh := make(chan error, 1) go func() { - waitErrCh <- k.cmd.Wait() + waitErrCh <- p.cmd.Wait() }() select { - case err := <-k.errCh: + case err := <-p.errCh: runtimeErr := fmt.Errorf("runtime error: %w", err) - killErr := k.cmd.Process.Kill() + killErr := p.cmd.Process.Kill() return mergeErr(runtimeErr, killErr) case err := <-waitErrCh: if err != nil { @@ -359,8 +310,42 @@ func (k *KittenIPC) waitParent() error { return nil } -func (k *KittenIPC) waitChild() error { - err := <-k.errCh +type ChildIPC struct { + *ipcCommon +} + +func NewChild(localApi any) (*ChildIPC, error) { + c := ChildIPC{ + ipcCommon: &ipcCommon{ + localApi: localApi, + pendingCalls: make(map[int64]chan callResult), + errCh: make(chan error, 1), + }, + } + + socketPath := flag.String("ipc-socket", "", "Path to IPC socket") + flag.Parse() + + if *socketPath == "" { + return nil, fmt.Errorf("ipc socket path is missing") + } + c.socketPath = *socketPath + + return &c, nil +} + +func (c *ChildIPC) Start() error { + conn, err := net.Dial("unix", c.socketPath) + if err != nil { + return fmt.Errorf("connect to parent socket: %w", err) + } + c.conn = conn + c.startRcvData() + return nil +} + +func (c *ChildIPC) Wait() error { + err := <-c.errCh if err != nil { return fmt.Errorf("ipc error: %w", err) }