zml: small cleanup
- Add more scatterSlices test cases.
- Replace helpers.mapTensors with zml.meta.map.
- Fix shape handling when a for loop is fully unrolled.
- Allow zml.Tensor.pad to accept i64 for dimension compatibility.
- Enable arrays of tensors inside model structs.
- Split Buffer.asViewOf into asViewOfHostBuffer and asViewOfDeviceBuffer.
This commit is contained in:
parent
f00538667e
commit
c30aa018dc
24
zml/aio.zig
24
zml/aio.zig
@ -383,6 +383,18 @@ fn _populateStruct(
|
||||
return false;
|
||||
}
|
||||
},
|
||||
.Array => |arr_info| {
|
||||
for (obj, 0..) |*value, i| {
|
||||
try prefix_builder.pushDigit(allocator, i);
|
||||
defer prefix_builder.pop();
|
||||
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required);
|
||||
if (!found) {
|
||||
log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(arr_info.child) });
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
},
|
||||
.Struct => |struct_info| {
|
||||
var partial_struct = false;
|
||||
inline for (struct_info.fields) |field| {
|
||||
@ -594,7 +606,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
||||
} else {
|
||||
return error.BufferNotFound;
|
||||
};
|
||||
}
|
||||
} else if (T == zml.Shape) return;
|
||||
|
||||
switch (type_info) {
|
||||
.Pointer => |ptr_info| {
|
||||
@ -605,8 +617,16 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
||||
|
||||
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform);
|
||||
}
|
||||
} else return error.TypeNotSupported;
|
||||
} else zml.meta.compileError("type not supported by visitStructAndLoadBuffer: {}", .{T});
|
||||
},
|
||||
.Array => {
|
||||
for (obj, 0..) |*value, i| {
|
||||
try prefix_builder.pushDigit(allocator, i);
|
||||
defer prefix_builder.pop();
|
||||
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform);
|
||||
}
|
||||
},
|
||||
|
||||
.Struct => |struct_info| {
|
||||
inline for (struct_info.fields) |field| {
|
||||
if (field.is_comptime or @sizeOf(field.type) == 0) continue;
|
||||
|
||||
@ -140,14 +140,20 @@ pub const Buffer = struct {
|
||||
try std.testing.expectEqual([_]u16{42} ** (4 * 3 * 2), y);
|
||||
}
|
||||
|
||||
/// Creates a Buffer as a view of memory visible from the device,
|
||||
/// Creates a Buffer as a view of host memory visible from the device,
|
||||
/// thus avoiding a copy.
|
||||
///
|
||||
/// On CUDA, it also allows you to specify a host allocated slice as they seem to be accessible.
|
||||
/// Be careful though, as it requires a specific alignment.
|
||||
/// Also note that it might not work on all platforms,
|
||||
/// could lead to crashes and is considerably slower.
|
||||
pub fn asViewOf(platform: Platform, buf: HostBuffer) !Buffer {
|
||||
/// Be careful though, as it requires a specific alignment
|
||||
/// and it might not work on all platforms,
|
||||
/// could lead to crashes and operations on the buffer will be slower.
|
||||
/// Tested on Cuda 12.4.
|
||||
pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) !Buffer {
|
||||
return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(@ptrCast(buf.data.ptr)));
|
||||
}
|
||||
|
||||
/// Creates a Buffer from a pointer into device memory.
|
||||
/// This allows interfacing with other libraries that produce buffers.
|
||||
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const anyopaque, device_data: *anyopaque) !Buffer {
|
||||
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
|
||||
var res: [Shape.MAX_RANK]i64 = undefined;
|
||||
for (0..Shape.MAX_RANK) |i| {
|
||||
@ -156,26 +162,28 @@ pub const Buffer = struct {
|
||||
break :blk res;
|
||||
};
|
||||
|
||||
const device_bytes: [*]u8 = @ptrCast(device_data);
|
||||
const pjrt_buffer = try platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
|
||||
.data = buf.data,
|
||||
.element_type = bufferTypeFromDtype(buf.shape().dtype()),
|
||||
.dims = buf.shape().dims(),
|
||||
// TODO: split in shards
|
||||
.data = device_bytes[0..shape_.byteSize()],
|
||||
.element_type = bufferTypeFromDtype(shape_.dtype()),
|
||||
.dims = shape_.dims(),
|
||||
// TODO: expose sharding in the API.
|
||||
.device = platform.getDevices()[0],
|
||||
.layout = .{
|
||||
.Tiled = .{
|
||||
.minor_to_major = minor_to_major[Shape.MAX_RANK - buf.shape().rank() ..],
|
||||
.minor_to_major = minor_to_major[Shape.MAX_RANK - shape_.rank() ..],
|
||||
.tile_dims = &.{},
|
||||
.tile_dims_sizes = &.{},
|
||||
},
|
||||
},
|
||||
.stream = @bitCast(@as(usize, @intFromPtr(stream))),
|
||||
});
|
||||
|
||||
var shards: Shards = .{};
|
||||
shards.appendAssumeCapacity(pjrt_buffer);
|
||||
return .{
|
||||
._api = platform.pjrt_api,
|
||||
._shape = buf.shape(),
|
||||
._shape = shape_,
|
||||
._shards = shards,
|
||||
};
|
||||
}
|
||||
|
||||
@ -213,6 +213,7 @@ test BFloat16 {
|
||||
try std.testing.expectEqual(BFloat16.fromF32(3.02344107628), BFloat16{ .sign = 0, .exponent = 127 + 1, .mantissa = 65 });
|
||||
try std.testing.expectEqual(BFloat16.fromF32(1.0 / 128.0), BFloat16{ .sign = 0, .exponent = 127 - 7, .mantissa = 0 });
|
||||
try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf().neg()), [_]u8{ 0x80, 0xff });
|
||||
try std.testing.expectEqual(BFloat16.inf(), BFloat16.fromF32(std.math.inf(f32)));
|
||||
|
||||
const lossless = [_]f32{ 0, -2, 1.0 / 128.0, -1e64, std.math.inf(f32) };
|
||||
for (&lossless) |v| {
|
||||
|
||||
@ -139,34 +139,3 @@ fn ShapeStruct(comptime dims: anytype) type {
|
||||
.is_tuple = false,
|
||||
} });
|
||||
}
|
||||
|
||||
/// Return a new struct with all tensors replaced by the output of the given function.
|
||||
pub fn mapTensors(func: anytype, v: anytype, args: anytype) @TypeOf(v) {
|
||||
const T = @TypeOf(v);
|
||||
const type_info = @typeInfo(T);
|
||||
if (T == Tensor) return @call(.auto, func, .{v} ++ args);
|
||||
|
||||
return switch (type_info) {
|
||||
.Pointer => @compileError("mapTensors only accept by value arguments. Received: " ++ @typeName(T)),
|
||||
.Struct => |struct_info| {
|
||||
var copy: T = v;
|
||||
inline for (struct_info.fields) |feeld| {
|
||||
if (feeld.is_comptime) continue;
|
||||
if (@typeInfo(feeld.type) == .Pointer) {
|
||||
@compileError("mapTensors doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(T));
|
||||
}
|
||||
@field(copy, feeld.name) = mapTensors(func, @field(v, feeld.name), args);
|
||||
}
|
||||
return copy;
|
||||
},
|
||||
.Array => {
|
||||
var res: T = undefined;
|
||||
for (v, &res) |item, *r| {
|
||||
r.* = mapTensors(func, item, args);
|
||||
}
|
||||
return res;
|
||||
},
|
||||
.Union, .Optional => @compileError("mapTensors doesn't yet support " ++ @typeName(T)),
|
||||
else => v,
|
||||
};
|
||||
}
|
||||
|
||||
12
zml/ops.zig
12
zml/ops.zig
@ -303,7 +303,7 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
|
||||
return Tensor.constant(shape, x.dtype().zero());
|
||||
}
|
||||
|
||||
fn wrapFirstStep(x: Tensor, tag_: @TypeOf(step_tag)) Tensor {
|
||||
fn wrapFirstStep(tag_: @TypeOf(step_tag), x: Tensor) Tensor {
|
||||
var shape = x.shape();
|
||||
shape._dims.insert(0, 1) catch unreachable;
|
||||
shape._tags.insert(0, tag_) catch unreachable;
|
||||
@ -315,13 +315,14 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
|
||||
// it's only used to infer the output shapes.
|
||||
const first_step = @call(.auto, func, .{ blk_ctx, Tensor.scalar(0, .i32) });
|
||||
log.debug("for_ first_step: {}", .{first_step});
|
||||
const allocator = CompilationContext.current()._allocator;
|
||||
// Optimize for small num reps
|
||||
if (num_steps == 1) {
|
||||
// return helpers.mapTensors(ForBlk.wrapFirstStep, first_step, .{ step_tag });
|
||||
return first_step;
|
||||
var res = first_step;
|
||||
meta.mapAlloc(ForBlk.wrapFirstStep, allocator, step_tag, first_step, &res) catch unreachable;
|
||||
return res;
|
||||
}
|
||||
|
||||
const allocator = CompilationContext.current()._allocator;
|
||||
if (num_steps <= 4) {
|
||||
var steps: [4]S.Return = undefined;
|
||||
steps[0] = first_step;
|
||||
@ -368,16 +369,19 @@ test for_ {
|
||||
// Just one baby step
|
||||
{
|
||||
const squares = try zml.testing.compileAndCall(platform, Squares.forward, .{1});
|
||||
try zml.testing.expectEqualShapes(Shape.init(.{1}, .f32), squares.shape());
|
||||
try std.testing.expectEqual(0, squares.getValue(f32));
|
||||
}
|
||||
// Wow 4 in rows !
|
||||
{
|
||||
const squares = try zml.testing.compileAndCall(platform, Squares.forward, .{4});
|
||||
try zml.testing.expectEqualShapes(Shape.init(.{4}, .f32), squares.shape());
|
||||
try std.testing.expectEqual([_]f32{ 0, 1, 4, 9 }, try squares.getValue([4]f32));
|
||||
}
|
||||
// AGI is coming, computing 10 squares as if it's nothing.
|
||||
{
|
||||
const squares = try zml.testing.compileAndCall(platform, Squares.forward, .{10});
|
||||
try zml.testing.expectEqualShapes(Shape.init(.{10}, .f32), squares.shape());
|
||||
try std.testing.expectEqual(
|
||||
[_]f32{ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81 },
|
||||
try squares.getValue([10]f32),
|
||||
|
||||
@ -1908,9 +1908,9 @@ pub const Tensor = struct {
|
||||
}
|
||||
|
||||
pub const Pad = struct {
|
||||
low: i32 = 0,
|
||||
high: i32 = 0,
|
||||
interior: i32 = 0,
|
||||
low: i64 = 0,
|
||||
high: i64 = 0,
|
||||
interior: i64 = 0,
|
||||
};
|
||||
|
||||
/// Pads the input Tensor with the given values.
|
||||
@ -2542,6 +2542,26 @@ pub const Tensor = struct {
|
||||
try std.testing.expect(a.shape().eql(result.shape()));
|
||||
try std.testing.expectEqual(expected, result.getValue(@TypeOf(expected)));
|
||||
}
|
||||
// Test with setting individual values (no batching)
|
||||
{
|
||||
const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32);
|
||||
const a = try zml.Buffer.from(platform, a_host);
|
||||
defer a.deinit();
|
||||
a_host.deinit(std.testing.allocator);
|
||||
|
||||
const scatter_indices = try zml.Buffer.fromArray(platform, [2][1]i32{ .{2}, .{7} });
|
||||
const updates = try zml.Buffer.fromArray(platform, [2]i32{ 20, 70 });
|
||||
|
||||
const expected = [9]i32{ 0, 1, 22, 3, 4, 5, 6, 77, 8 };
|
||||
const result = try zml.testing.compileAndCall(platform, Local.scatter, .{
|
||||
a,
|
||||
a.shape().axes(.{0}),
|
||||
scatter_indices.withTags(.{ .n, .coord }),
|
||||
updates.withTags(.{.n}),
|
||||
});
|
||||
try std.testing.expect(a.shape().eql(result.shape()));
|
||||
try std.testing.expectEqual(expected, result.getValue(@TypeOf(expected)));
|
||||
}
|
||||
{
|
||||
// Test with actual values and batching along axis .a
|
||||
const operand = try zml.Buffer.constant(platform, Shape.init(.{ .a = 2, .b = 3, .c = 4, .d = 2 }, .u16), 0);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user