Skip to content

Commit

Permalink
[tvm4j] support kNDArrayContainer (apache#1510)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu authored and sergei-mironov committed Aug 8, 2018
1 parent d3e8d69 commit 2ba20f9
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
7 changes: 5 additions & 2 deletions jvm/core/src/main/java/ml/dmlc/tvm/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ public Function pushArg(String arg) {
* @return this
*/
public Function pushArg(NDArrayBase arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.ARRAY_HANDLE.id);
int id = arg.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
Base._LIB.tvmFuncPushArgHandle(arg.handle, id);
return this;
}

Expand Down Expand Up @@ -247,7 +248,9 @@ private static void pushArgToStack(Object arg) {
} else if (arg instanceof byte[]) {
Base._LIB.tvmFuncPushArgBytes((byte[]) arg);
} else if (arg instanceof NDArrayBase) {
Base._LIB.tvmFuncPushArgHandle(((NDArrayBase) arg).handle, TypeCode.ARRAY_HANDLE.id);
NDArrayBase nd = (NDArrayBase) arg;
int id = nd.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
Base._LIB.tvmFuncPushArgHandle(nd.handle, id);
} else if (arg instanceof Module) {
Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id);
} else if (arg instanceof Function) {
Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/ml/dmlc/tvm/TypeCode.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
public enum TypeCode {
INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5),
TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9),
FUNC_HANDLE(10), STR(11), BYTES(12);
FUNC_HANDLE(10), STR(11), BYTES(12), NDARRAY_CONTAINER(13);

public final int id;

Expand Down
10 changes: 6 additions & 4 deletions jvm/native/src/main/native/jni_helper_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ jobject newFunction(JNIEnv *env, jlong value) {
return object;
}

jobject newNDArray(JNIEnv *env, jlong value) {
jobject newNDArray(JNIEnv *env, jlong handle, jboolean isview) {
jclass cls = env->FindClass("ml/dmlc/tvm/NDArrayBase");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
jmethodID constructor = env->GetMethodID(cls, "<init>", "(JZ)V");
jobject object = env->NewObject(cls, constructor, handle, isview);
env->DeleteLocalRef(cls);
return object;
}
Expand Down Expand Up @@ -181,7 +181,9 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
case kFuncHandle:
return newFunction(env, reinterpret_cast<jlong>(value.v_handle));
case kArrayHandle:
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle));
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), true);
case kNDArrayContainer:
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), false);
case kStr:
return newTVMValueString(env, value.v_str);
case kBytes:
Expand Down

0 comments on commit 2ba20f9

Please sign in to comment.