diff --git a/lib/golang/lib.go b/lib/golang/lib.go index 25af20a..01dbe49 100644 --- a/lib/golang/lib.go +++ b/lib/golang/lib.go @@ -127,33 +127,36 @@ func (ipc *ipcCommon) handleCall(msg Message) { for i, arg := range msg.Args { paramType := method.Type().In(i) argType := reflect.TypeOf(arg) - - // JSON decodes any number to float64. If we need int, we should check and convert - if paramType.Kind() == reflect.Int && argType.Kind() == reflect.Float64 { - floatArg := arg.(float64) - if float64(int64(floatArg)) == floatArg && !paramType.OverflowInt(int64(floatArg)) { - arg = arg.(int) - } - } - - args = append(args, reflect.ValueOf(paramType)) + arg = ipc.convType(paramType, argType, arg) + args = append(args, reflect.ValueOf(arg)) } - results := method.Call(args) - resVals := results[0 : len(results)-1] - resErrVal := results[len(results)-1] + allResultVals := method.Call(args) + retResultVals := allResultVals[0 : len(allResultVals)-1] + errResultVals := allResultVals[len(allResultVals)-1] - var res []any - for _, resVal := range resVals { - res = append(res, resVal.Interface()) + var results []any + for _, resVal := range retResultVals { + results = append(results, resVal.Interface()) } var resErr error - if !resErrVal.IsNil() { - resErr = resErrVal.Interface().(error) + if !errResultVals.IsNil() { + resErr = errResultVals.Interface().(error) } - ipc.sendResponse(msg.Id, res, resErr) + ipc.sendResponse(msg.Id, results, resErr) +} + +func (ipc *ipcCommon) convType(needType reflect.Type, gotType reflect.Type, arg any) any { + // JSON decodes any number to float64. If we need int, we should check and convert + if needType.Kind() == reflect.Int && gotType.Kind() == reflect.Float64 { + floatArg := arg.(float64) + if float64(int64(floatArg)) == floatArg && !needType.OverflowInt(int64(floatArg)) { + arg = int(floatArg) + } + } + return arg } func (ipc *ipcCommon) findMethod(methodName string) (reflect.Value, error) {