2025-10-23 17:58:09 +03:00

382 lines
7.2 KiB
Go

package kittenipc
import (
"bufio"
"flag"
"fmt"
"net"
"os"
"os/exec"
"path/filepath"
"reflect"
"slices"
"sync"
"time"
"github.com/go-json-experiment/json"
"github.com/samber/mo"
)
const ipcSocketArg = "--ipc-socket"
type StdioMode int
type ipcMode int
const (
modeParent ipcMode = 1
modeChild ipcMode = 2
)
type KittenIPC struct {
mode ipcMode
cmd *exec.Cmd
localApi any
socketPath string
listener net.Listener
conn net.Conn
errCh chan error
nextId int64
pendingCalls map[int64]chan callResult
mu sync.Mutex
}
func NewParent(cmd *exec.Cmd, localApi any) (*KittenIPC, error) {
k := KittenIPC{
mode: modeParent,
cmd: cmd,
localApi: localApi,
pendingCalls: make(map[int64]chan callResult),
errCh: make(chan error, 1),
}
k.socketPath = filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d.sock", os.Getpid()))
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if slices.Contains(cmd.Args, ipcSocketArg) {
return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg)
}
cmd.Args = append(cmd.Args, ipcSocketArg, k.socketPath)
k.errCh = make(chan error, 1)
return &k, nil
}
func NewChild(localApi any) (*KittenIPC, error) {
k := KittenIPC{
mode: modeChild,
localApi: localApi,
pendingCalls: make(map[int64]chan callResult),
errCh: make(chan error, 1),
}
socketPath := flag.String("ipc-socket", "", "Path to IPC socket")
flag.Parse()
if *socketPath == "" {
return nil, fmt.Errorf("ipc socket path is missing")
}
k.socketPath = *socketPath
return &k, nil
}
func (k *KittenIPC) Start() error {
if k.mode == modeParent {
return k.startParent()
}
return k.startChild()
}
func (k *KittenIPC) startParent() error {
_ = os.Remove(k.socketPath)
listener, err := net.Listen("unix", k.socketPath)
if err != nil {
return fmt.Errorf("listen unix socket: %w", err)
}
k.listener = listener
defer k.closeSock()
err = k.cmd.Start()
if err != nil {
return fmt.Errorf("cmd start: %w", err)
}
if err := k.acceptConn(); err != nil {
return fmt.Errorf("accept connection: %w", err)
}
return nil
}
func (k *KittenIPC) startChild() error {
conn, err := net.Dial("unix", k.socketPath)
if err != nil {
return fmt.Errorf("connect to parent socket: %w", err)
}
k.conn = conn
k.startRcvData()
return nil
}
func (k *KittenIPC) acceptConn() error {
const acceptTimeout = time.Second * 10
res := make(chan mo.Result[net.Conn], 1)
go func() {
conn, err := k.listener.Accept()
if err != nil {
res <- mo.Err[net.Conn](err)
} else {
res <- mo.Ok[net.Conn](conn)
}
close(res)
}()
select {
case <-time.After(acceptTimeout):
_ = k.cmd.Process.Kill()
return fmt.Errorf("accept timeout")
case res := <-res:
if res.IsError() {
_ = k.cmd.Process.Kill()
return fmt.Errorf("accept: %w", res.Error())
}
k.conn = res.MustGet()
k.startRcvData()
}
return nil
}
type MsgType int
const (
MsgCall MsgType = 1
MsgResponse MsgType = 2
)
type Message struct {
Type MsgType `json:"type"`
Id int64 `json:"id"`
Method string `json:"method"`
Params []any `json:"params"`
Result []any `json:"result"`
Error string `json:"error"`
}
type callResult struct {
result []any
err error
}
func (k *KittenIPC) startRcvData() {
scn := bufio.NewScanner(k.conn)
for scn.Scan() {
var msg Message
if err := json.Unmarshal(scn.Bytes(), &msg); err != nil {
k.raiseErr(fmt.Errorf("unmarshal message: %w", err))
break
}
k.processMsg(msg)
}
if err := scn.Err(); err != nil {
k.raiseErr(err)
}
}
func (k *KittenIPC) processMsg(msg Message) {
switch msg.Type {
case MsgCall:
k.handleCall(msg)
case MsgResponse:
k.handleResponse(msg)
}
}
func (k *KittenIPC) handleCall(msg Message) {
if k.localApi == nil {
k.sendResponse(msg.Id, nil, fmt.Errorf("remote side does not accept ipc calls"))
}
localApi := reflect.ValueOf(k.localApi)
method := localApi.MethodByName(msg.Method)
if !method.IsValid() {
k.sendResponse(msg.Id, nil, fmt.Errorf("method not found: %s", msg.Method))
return
}
methodType := method.Type()
argsCount := methodType.NumIn()
if len(msg.Params) != argsCount {
k.sendResponse(msg.Id, nil, fmt.Errorf("argument count mismatch: expected %d, got %d", argsCount, len(msg.Params)))
return
}
var args []reflect.Value
for _, param := range msg.Params {
args = append(args, reflect.ValueOf(param))
}
results := method.Call(args)
resVals := results[0 : len(results)-1]
resErr := results[len(results)-1]
var res []any
for _, resVal := range resVals {
res = append(res, resVal)
}
k.sendResponse(msg.Id, res, resErr.Interface().(error))
}
func (k *KittenIPC) handleResponse(msg Message) {
k.mu.Lock()
ch, ok := k.pendingCalls[msg.Id]
if ok {
delete(k.pendingCalls, msg.Id)
}
k.mu.Unlock()
if !ok {
k.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id))
return
}
var err error
if msg.Error != "" {
err = fmt.Errorf("remote error: %s", msg.Error)
}
ch <- callResult{result: msg.Result, err: err}
close(ch)
}
func (k *KittenIPC) sendResponse(id int64, result []any, err error) {
msg := Message{
Type: MsgResponse,
Id: id,
Result: result,
}
if err != nil {
msg.Error = err.Error()
}
if err := k.sendMsg(msg); err != nil {
k.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err))
}
}
func (k *KittenIPC) sendMsg(msg Message) error {
data, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("marshal message: %w", err)
}
data = append(data, '\n')
if _, err := k.conn.Write(data); err != nil {
return fmt.Errorf("write message: %w", err)
}
return nil
}
func (k *KittenIPC) Call(method string, params ...any) ([]any, error) {
k.mu.Lock()
id := k.nextId
k.nextId++
resChan := make(chan callResult, 1)
k.pendingCalls[id] = resChan
k.mu.Unlock()
msg := Message{
Type: MsgCall,
Id: id,
Method: method,
Params: params,
}
if err := k.sendMsg(msg); err != nil {
k.mu.Lock()
delete(k.pendingCalls, id)
k.mu.Unlock()
return nil, fmt.Errorf("send call: %w", err)
}
result := <-resChan
return result.result, result.err
}
func (k *KittenIPC) raiseErr(err error) {
select {
case k.errCh <- err:
default:
}
}
func (k *KittenIPC) closeSock() error {
if err := k.listener.Close(); err != nil {
return fmt.Errorf("close socket listener: %w", err)
}
return nil
}
func (k *KittenIPC) Wait() error {
if k.mode == modeParent {
return k.waitParent()
}
return k.waitChild()
}
func (k *KittenIPC) waitParent() error {
waitErrCh := make(chan error, 1)
go func() {
waitErrCh <- k.cmd.Wait()
}()
select {
case err := <-k.errCh:
runtimeErr := fmt.Errorf("runtime error: %w", err)
killErr := k.cmd.Process.Kill()
return mergeErr(runtimeErr, killErr)
case err := <-waitErrCh:
if err != nil {
return fmt.Errorf("cmd wait: %w", err)
}
}
return nil
}
func (k *KittenIPC) waitChild() error {
err := <-k.errCh
if err != nil {
return fmt.Errorf("ipc error: %w", err)
}
return nil
}
func mergeErr(errs ...error) (ret error) {
for _, err := range errs {
if err != nil {
if ret == nil {
ret = err
} else {
ret = fmt.Errorf("%w; %w", ret, err)
}
}
}
return
}