This commit is contained in:
Egor Aristov 2025-10-23 13:14:43 +03:00
parent 474e667e9a
commit 0a582a2eb0
Signed by: egor3f
GPG Key ID: 40482A264AAEC85F

View File

@ -7,7 +7,9 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"reflect"
"slices" "slices"
"sync"
"time" "time"
"github.com/go-json-experiment/json" "github.com/go-json-experiment/json"
@ -22,17 +24,24 @@ type Config struct {
type KittenIPC struct { type KittenIPC struct {
cmd *exec.Cmd cmd *exec.Cmd
cfg Config cfg Config
localApi any
socketPath string socketPath string
listener net.Listener listener net.Listener
conn net.Conn conn net.Conn
errCh chan error errCh chan error
nextId int64
pendingCalls map[int64]chan callResult
mu sync.Mutex
} }
func New(cmd *exec.Cmd, api any, cfg Config) (*KittenIPC, error) { func New(cmd *exec.Cmd, localApi any, cfg Config) (*KittenIPC, error) {
k := KittenIPC{ k := KittenIPC{
cmd: cmd, cmd: cmd,
cfg: cfg, cfg: cfg,
localApi: localApi,
pendingCalls: make(map[int64]chan callResult),
} }
k.socketPath = filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d.sock", os.Getpid())) k.socketPath = filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d.sock", os.Getpid()))
@ -117,6 +126,11 @@ type Message struct {
Error string `json:"error"` Error string `json:"error"`
} }
type callResult struct {
result []any
err error
}
func (k *KittenIPC) startRcvData() { func (k *KittenIPC) startRcvData() {
scn := bufio.NewScanner(k.conn) scn := bufio.NewScanner(k.conn)
for scn.Scan() { for scn.Scan() {
@ -133,11 +147,131 @@ func (k *KittenIPC) startRcvData() {
} }
func (k *KittenIPC) processMsg(msg Message) { func (k *KittenIPC) processMsg(msg Message) {
switch msg.Type {
case MsgCall:
k.handleCall(msg)
case MsgResponse:
k.handleResponse(msg)
}
} }
func (k *KittenIPC) Call() { func (k *KittenIPC) handleCall(msg Message) {
if k.localApi == nil {
k.sendResponse(msg.Id, nil, fmt.Errorf("remote side does not accept ipc calls"))
}
localApi := reflect.ValueOf(k.localApi)
method := localApi.MethodByName(msg.Method)
if !method.IsValid() {
k.sendResponse(msg.Id, nil, fmt.Errorf("method not found: %s", msg.Method))
return
}
methodType := method.Type()
argsCount := methodType.NumIn()
if len(msg.Params) != argsCount {
k.sendResponse(msg.Id, nil, fmt.Errorf("argument count mismatch: expected %d, got %d", argsCount, len(msg.Params)))
return
}
var args []reflect.Value
for _, param := range msg.Params {
args = append(args, reflect.ValueOf(param))
}
results := method.Call(args)
resVals := results[0 : len(results)-1]
resErr := results[len(results)-1]
var res []any
for _, resVal := range resVals {
res = append(res, resVal)
}
k.sendResponse(msg.Id, res, resErr.Interface().(error))
}
func (k *KittenIPC) handleResponse(msg Message) {
k.mu.Lock()
ch, ok := k.pendingCalls[msg.Id]
if ok {
delete(k.pendingCalls, msg.Id)
}
k.mu.Unlock()
if !ok {
k.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id))
return
}
var err error
if msg.Error != "" {
err = fmt.Errorf("remote error: %s", msg.Error)
}
ch <- callResult{result: msg.Result, err: err}
close(ch)
}
func (k *KittenIPC) sendResponse(id int64, result []any, err error) {
msg := Message{
Type: MsgResponse,
Id: id,
Result: result,
}
if err != nil {
msg.Error = err.Error()
}
if err := k.sendMsg(msg); err != nil {
k.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err))
}
}
func (k *KittenIPC) sendMsg(msg Message) error {
data, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("marshal message: %w", err)
}
data = append(data, '\n')
if _, err := k.conn.Write(data); err != nil {
return fmt.Errorf("write message: %w", err)
}
return nil
}
func (k *KittenIPC) Call(method string, params ...any) ([]any, error) {
k.mu.Lock()
id := k.nextId
k.nextId++
resChan := make(chan callResult, 1)
k.pendingCalls[id] = resChan
k.mu.Unlock()
msg := Message{
Type: MsgCall,
Id: id,
Method: method,
Params: params,
}
if err := k.sendMsg(msg); err != nil {
k.mu.Lock()
delete(k.pendingCalls, id)
k.mu.Unlock()
return nil, fmt.Errorf("send call: %w", err)
}
result := <-resChan
return result.result, result.err
} }
func (k *KittenIPC) raiseErr(err error) { func (k *KittenIPC) raiseErr(err error) {