diff --git a/zml/tensor.zig b/zml/tensor.zig index b248ea2..79c2f2c 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -206,8 +206,6 @@ pub const Tensor = struct { const ctx = self.getContext(); const mlir_ctx = ctx.mlirCtx(); if (ctx.target() == .cpu) return self; - var res = self; - res._output_memory_kind = kind; const memory_kind = @tagName(kind.toPjrtMemory()); @@ -223,7 +221,9 @@ pub const Tensor = struct { .api_version = .original, }, &.{self.value().getType()}, mlir_ctx.location(@src())); - return _result(res._shape, op.result(0)); + var res = _result(self._shape, op.result(0)); + res._output_memory_kind = kind; + return res; }, .buffer_id => { var res = self;