diff --git a/unix/modjni.c b/unix/modjni.c
index 5fc4dca3561a129434fcad1d4c262521e7a946a3..001fa097ab23ab183d9a5a22e00e2732f9a8ae39 100644
--- a/unix/modjni.c
+++ b/unix/modjni.c
@@ -40,6 +40,7 @@
 
 static JavaVM *jvm;
 static JNIEnv *env;
+static jclass Class_class;
 static jclass String_class;
 static jmethodID Class_getField_mid;
 static jmethodID Class_getMethods_mid;
@@ -51,6 +52,7 @@ STATIC const mp_obj_type_t jobject_type;
 STATIC const mp_obj_type_t jmethod_type;
 
 STATIC mp_obj_t new_jobject(jobject jo);
+STATIC mp_obj_t new_jclass(jclass jc);
 STATIC mp_obj_t call_method(jobject obj, const char *name, jarray methods, bool is_constr, mp_uint_t n_args, const mp_obj_t *args);
 
 typedef struct _mp_obj_jclass_t {
@@ -134,6 +136,12 @@ STATIC const mp_obj_type_t jclass_type = {
     .locals_dict = (mp_obj_t)&jclass_locals_dict,
 };
 
+STATIC mp_obj_t new_jclass(jclass jc) {
+    mp_obj_jclass_t *o = m_new_obj(mp_obj_jclass_t);
+    o->base.type = &jclass_type;
+    o->cls = jc;
+    return o;
+}
 
 // jobject
 
@@ -243,6 +251,8 @@ ret_string:;
                 // Non-primitive, object type
                 if (JJ(IsInstanceOf, arg, String_class)) {
                     goto ret_string;
+                } else if (JJ(IsInstanceOf, arg, Class_class)) {
+                    return new_jclass(arg);
                 } else {
                     return new_jobject(arg);
                 }
@@ -379,15 +389,15 @@ STATIC void create_jvm() {
         nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_OSError, "unable to create JVM"));
     }
 
-    jclass class_class = JJ(FindClass, "java/lang/Class");
+    Class_class = JJ(FindClass, "java/lang/Class");
     jclass method_class = JJ(FindClass, "java/lang/reflect/Method");
     String_class = JJ(FindClass, "java/lang/String");
 
-    Class_getField_mid = (*env)->GetMethodID(env, class_class, "getField",
+    Class_getField_mid = (*env)->GetMethodID(env, Class_class, "getField",
                                      "(Ljava/lang/String;)Ljava/lang/reflect/Field;");
-    Class_getMethods_mid = (*env)->GetMethodID(env, class_class, "getMethods",
+    Class_getMethods_mid = (*env)->GetMethodID(env, Class_class, "getMethods",
                                      "()[Ljava/lang/reflect/Method;");
-    Class_getConstructors_mid = (*env)->GetMethodID(env, class_class, "getConstructors",
+    Class_getConstructors_mid = (*env)->GetMethodID(env, Class_class, "getConstructors",
                                      "()[Ljava/lang/reflect/Constructor;");
     Method_getName_mid = (*env)->GetMethodID(env, method_class, "getName",
                                      "()Ljava/lang/String;");