diff --git a/unix/modjni.c b/unix/modjni.c
index eeed389ba2748876b87f73eb0bd43ec0a3f055d1..27470ae02bca458ec5081cb272bc832fe1026fd0 100644
--- a/unix/modjni.c
+++ b/unix/modjni.c
@@ -361,8 +361,8 @@ STATIC bool py2jvalue(const char **jtypesig, mp_obj_t arg, jvalue *out) {
             return false;
         }
     } else if (type == &jobject_type) {
-        printf("TODO: Check java arg type!!\n");
         bool is_object = false;
+        const char *expected_type = arg_type;
         while (1) {
             if (isalpha(*arg_type)) {
             } else if (*arg_type == '.') {
@@ -376,6 +376,14 @@ STATIC bool py2jvalue(const char **jtypesig, mp_obj_t arg, jvalue *out) {
             return false;
         }
         mp_obj_jobject_t *jo = arg;
+        if (!MATCH(expected_type, "java.lang.Object")) {
+            char class_name[64];
+            get_jclass_name(jo->obj, class_name);
+            //printf("Arg class: %s\n", class_name);
+            if (strcmp(class_name, expected_type) != 0) {
+                return false;
+            }
+        }
         out->l = jo->obj;
     } else if (type == &mp_type_bool) {
         if (IMATCH(arg_type, "boolean")) {