Remove usingnamespace from MLIR.
This commit is contained in:
parent
f8ab0d7b2a
commit
0a2ab7c8cb
@ -73,7 +73,7 @@ pub fn cmpi(ctx: mlir.Context, predicate: CmpIPredicate, lhs: mlir.Value, rhs: m
|
||||
.operands = &.{ lhs, rhs },
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) },
|
||||
.{ "predicate", .int(ctx, .i64, @intFromEnum(predicate)) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -103,7 +103,7 @@ pub fn cmpf(ctx: mlir.Context, predicate: CmpFPredicate, lhs: mlir.Value, rhs: m
|
||||
.operands = &.{ lhs, rhs },
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) },
|
||||
.{ "predicate", .int(ctx, .i64, @intFromEnum(predicate)) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
|
||||
@ -127,10 +127,10 @@ pub const DotPrecision = union(enum) {
|
||||
// When we specify the dot algorithm, we should not specify the precision.
|
||||
.algorithm => .DEFAULT,
|
||||
});
|
||||
return precision.as(mlir.Attribute);
|
||||
return precision.asAttr();
|
||||
}
|
||||
|
||||
pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute {
|
||||
pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.RankedTensorType) ?mlir.Attribute {
|
||||
return switch (self) {
|
||||
.algorithm => |algo| algo.asAttr(ctx, operand_type),
|
||||
else => null,
|
||||
@ -156,15 +156,14 @@ pub const DotAlgorithm = struct {
|
||||
.allow_imprecise_accumulation = false,
|
||||
};
|
||||
|
||||
pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute {
|
||||
const tensor_type = operand_type.as(mlir.RankedTensorType);
|
||||
pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, tensor_type: mlir.RankedTensorType) mlir.Attribute {
|
||||
const elem_type = tensor_type.getElementType();
|
||||
|
||||
return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet(
|
||||
ctx.inner(),
|
||||
elem_type.inner(),
|
||||
elem_type.inner(),
|
||||
self.accumulation.asType(ctx).inner(),
|
||||
ctx._inner,
|
||||
elem_type._inner,
|
||||
elem_type._inner,
|
||||
self.accumulation.asType(ctx)._inner,
|
||||
self.component_count,
|
||||
self.component_count,
|
||||
self.num_primitive_operations,
|
||||
@ -197,11 +196,11 @@ pub fn dot_general(
|
||||
.rhs_batching_dimensions = opts.rhs_batching_dimensions,
|
||||
.lhs_contracting_dimensions = opts.lhs_contracting_dimensions,
|
||||
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
|
||||
}).as(mlir.Attribute),
|
||||
}).asAttr(),
|
||||
},
|
||||
.{ "precision_config", .array(ctx, &precisions) },
|
||||
// keep algorithm as the last attribute so we can omit it when it's not set.
|
||||
.{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType()) orelse undefined },
|
||||
.{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType().as(mlir.RankedTensorType).?) orelse undefined },
|
||||
};
|
||||
const n_attributes = if (opts.precision == .algorithm) attributes.len else attributes.len - 1;
|
||||
return mlir.Operation.make(ctx, "stablehlo.dot_general", .{
|
||||
@ -214,19 +213,15 @@ pub fn dot_general(
|
||||
|
||||
pub fn constant(
|
||||
ctx: mlir.Context,
|
||||
result_type: mlir.RankedTensorType,
|
||||
dims: []const i64,
|
||||
elem_type: mlir.DenseElementsAttributeTypes,
|
||||
raw_bytes: []const u8,
|
||||
location: mlir.Location,
|
||||
) mlir.Operation {
|
||||
const attribute = switch (elem_type) {
|
||||
inline else => |dt| mlir.DenseElementsAttribute(dt).init(result_type.as(mlir.Type), raw_bytes).as(mlir.Attribute),
|
||||
};
|
||||
|
||||
return mlir.Operation.make(ctx, "stablehlo.constant", .{
|
||||
.operands = &.{},
|
||||
.results = &.{result_type.as(mlir.Type)},
|
||||
.attributes = &.{.{ "value", attribute }},
|
||||
.results = &.{.tensor(dims, elem_type.mlirType(ctx))},
|
||||
.attributes = &.{.{ "value", .denseElementsFromBytes(ctx, dims, elem_type, raw_bytes) }},
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
@ -285,10 +280,10 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64
|
||||
});
|
||||
}
|
||||
|
||||
pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.RankedTensorType, location: mlir.Location) mlir.Operation {
|
||||
pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "stablehlo.reshape", .{
|
||||
.operands = &.{value},
|
||||
.results = &.{result_type.as(mlir.Type)},
|
||||
.results = &.{result_type},
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
@ -332,7 +327,7 @@ pub fn gather(
|
||||
args.start_indices_batching_dims,
|
||||
args.start_index_map,
|
||||
args.index_vector_dim,
|
||||
).as(mlir.Attribute) },
|
||||
).asAttr() },
|
||||
.{ "slice_sizes", .dense(ctx, .i64, slice_sizes) },
|
||||
.{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) },
|
||||
},
|
||||
@ -358,22 +353,20 @@ pub const ScatterArgs = struct {
|
||||
unique_indices: bool = false,
|
||||
|
||||
pub fn getScatterDimensionNumbers(self: ScatterArgs, ctx: mlir.Context) mlir.Attribute {
|
||||
return mlir.Attribute.wrap(
|
||||
c.stablehloScatterDimensionNumbersGet(
|
||||
ctx.inner(),
|
||||
@intCast(self.update_window_dims.len),
|
||||
self.update_window_dims.ptr,
|
||||
@intCast(self.inserted_window_dims.len),
|
||||
self.inserted_window_dims.ptr,
|
||||
@intCast(self.input_batching_dims.len),
|
||||
self.input_batching_dims.ptr,
|
||||
@intCast(self.scatter_indices_batching_dims.len),
|
||||
self.scatter_indices_batching_dims.ptr,
|
||||
@intCast(self.scatter_dims_to_operand_dims.len),
|
||||
self.scatter_dims_to_operand_dims.ptr,
|
||||
self.index_vector_dim,
|
||||
),
|
||||
);
|
||||
return .{ ._inner = c.stablehloScatterDimensionNumbersGet(
|
||||
ctx._inner,
|
||||
@intCast(self.update_window_dims.len),
|
||||
self.update_window_dims.ptr,
|
||||
@intCast(self.inserted_window_dims.len),
|
||||
self.inserted_window_dims.ptr,
|
||||
@intCast(self.input_batching_dims.len),
|
||||
self.input_batching_dims.ptr,
|
||||
@intCast(self.scatter_indices_batching_dims.len),
|
||||
self.scatter_indices_batching_dims.ptr,
|
||||
@intCast(self.scatter_dims_to_operand_dims.len),
|
||||
self.scatter_dims_to_operand_dims.ptr,
|
||||
self.index_vector_dim,
|
||||
) };
|
||||
}
|
||||
};
|
||||
|
||||
@ -431,8 +424,8 @@ pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_d
|
||||
.operands = &.{ lhs, rhs },
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "comparison_direction", comparison_direction.as(mlir.Attribute) },
|
||||
.{ "compare_type", compare_type.as(mlir.Attribute) },
|
||||
.{ "comparison_direction", comparison_direction.asAttr() },
|
||||
.{ "compare_type", compare_type.asAttr() },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -580,7 +573,7 @@ pub fn triangular_solve(ctx: mlir.Context, value: mlir.Value, other: mlir.Value,
|
||||
.{ "left_side", .i1FromBool(ctx, opts.left_side) },
|
||||
.{ "lower", .i1FromBool(ctx, opts.lower) },
|
||||
.{ "unit_diagonal", .i1FromBool(ctx, opts.unit_diagonal) },
|
||||
.{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute) },
|
||||
.{ "transpose_a", Transpose.init(ctx, opts.transpose_a).asAttr() },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -596,7 +589,7 @@ pub fn fft(ctx: mlir.Context, value: mlir.Value, location: mlir.Location, opts:
|
||||
.operands = &.{value},
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute) },
|
||||
.{ "fft_type", FftType.init(ctx, opts.kind).asAttr() },
|
||||
.{ "fft_length", .dense(ctx, .i64, opts.length) },
|
||||
},
|
||||
.location = location,
|
||||
@ -608,7 +601,7 @@ pub fn rng(ctx: mlir.Context, a: mlir.Value, b: mlir.Value, shape: mlir.Value, r
|
||||
.operands = &.{ a, b, shape },
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).as(mlir.Attribute) },
|
||||
.{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).asAttr() },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -619,7 +612,7 @@ pub fn rng_bit_generator(ctx: mlir.Context, rng_algorithm: RngAlgorithm.Type, in
|
||||
.operands = &.{initial_state},
|
||||
.results = &.{ res_state_type, res_type },
|
||||
.attributes = &.{
|
||||
.{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).as(mlir.Attribute) },
|
||||
.{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).asAttr() },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -695,7 +688,7 @@ pub fn convolution(
|
||||
) mlir.Operation {
|
||||
var max_precisions: [2]mlir.Attribute = undefined;
|
||||
for (opts.precision_config, 0..) |p, i| {
|
||||
max_precisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute);
|
||||
max_precisions[i] = PrecisionAttribute.init(ctx, p).asAttr();
|
||||
}
|
||||
var window_reversal: [3]i32 = undefined;
|
||||
for (opts.window_reversal, 0..) |w, i| {
|
||||
@ -721,7 +714,7 @@ pub fn convolution(
|
||||
.output_batch_dimension = opts.output_batch_dimension,
|
||||
.output_feature_dimension = opts.output_feature_dimension,
|
||||
.output_spatial_dimensions = opts.output_spatial_dimensions,
|
||||
}).as(mlir.Attribute),
|
||||
}).asAttr(),
|
||||
},
|
||||
.{ "feature_group_count", .int(ctx, .i64, opts.feature_group_count) },
|
||||
.{ "batch_group_count", .int(ctx, .i64, opts.batch_group_count) },
|
||||
@ -756,13 +749,13 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
||||
const backend_config = opts.backend_config orelse mlir.Attribute.string(ctx, "");
|
||||
if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) {
|
||||
stdx.debug.assert(
|
||||
backend_config.is_a(mlir.StringAttribute),
|
||||
backend_config.isA(mlir.StringAttribute),
|
||||
"API version < 4 requires a string as backend_config, got {}",
|
||||
.{backend_config},
|
||||
);
|
||||
} else {
|
||||
stdx.debug.assert(
|
||||
backend_config.is_a(mlir.DictionaryAttribute),
|
||||
backend_config.isA(mlir.DictionaryAttribute),
|
||||
"API version >= 4 requires a dictionary as backend_config, got {}",
|
||||
.{backend_config},
|
||||
);
|
||||
@ -780,7 +773,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
||||
var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||
for (opts.output_operand_aliases) |alias| {
|
||||
output_operand_aliases.appendAssumeCapacity(
|
||||
OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute),
|
||||
OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).asAttr(),
|
||||
);
|
||||
}
|
||||
attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) });
|
||||
@ -805,7 +798,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
||||
const operand_layouts = blk: {
|
||||
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
||||
for (inputs) |input| {
|
||||
const ranked_type = input.getType().as(mlir.RankedTensorType);
|
||||
const ranked_type = input.getType().as(mlir.RankedTensorType).?;
|
||||
const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..];
|
||||
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
|
||||
}
|
||||
@ -824,7 +817,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
||||
const result_layouts = blk: {
|
||||
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||
for (res_types) |t| {
|
||||
const ranked_t = t.as(mlir.RankedTensorType);
|
||||
const ranked_t = t.as(mlir.RankedTensorType).?;
|
||||
const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..];
|
||||
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
|
||||
}
|
||||
@ -846,13 +839,10 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
||||
pub const DotDimensionNumbersAttribute = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
pub usingnamespace mlir.MlirHelpers(DotDimensionNumbersAttribute, .{
|
||||
.is_a_fn = c.stablehloAttributeIsADotDimensionNumbers,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
pub const is_a_fn = c.stablehloAttributeIsADotDimensionNumbers;
|
||||
const Self = DotDimensionNumbersAttribute;
|
||||
pub const asAttr = mlir.Attribute.fromAny(Self);
|
||||
pub const eql = mlir.Attribute.eqlAny(Self);
|
||||
|
||||
pub fn init(ctx: mlir.Context, args: struct {
|
||||
lhs_batching_dimensions: []const i64,
|
||||
@ -860,9 +850,9 @@ pub const DotDimensionNumbersAttribute = struct {
|
||||
lhs_contracting_dimensions: []const i64,
|
||||
rhs_contracting_dimensions: []const i64,
|
||||
}) Self {
|
||||
return Self.wrap(
|
||||
c.stablehloDotDimensionNumbersGet(
|
||||
ctx.inner(),
|
||||
return .{
|
||||
._inner = c.stablehloDotDimensionNumbersGet(
|
||||
ctx._inner,
|
||||
@intCast(args.lhs_batching_dimensions.len),
|
||||
args.lhs_batching_dimensions.ptr,
|
||||
@intCast(args.rhs_batching_dimensions.len),
|
||||
@ -872,52 +862,49 @@ pub const DotDimensionNumbersAttribute = struct {
|
||||
@intCast(args.rhs_contracting_dimensions.len),
|
||||
args.rhs_contracting_dimensions.ptr,
|
||||
),
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
pub fn getLhsBatchingDimensionsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(self.inner()));
|
||||
return @intCast(c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(self._inner));
|
||||
}
|
||||
|
||||
pub fn getLhsBatchingDimensionsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(self.inner(), @intCast(pos));
|
||||
return c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(self._inner, @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getRhsBatchingDimensionsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(self.inner()));
|
||||
return @intCast(c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(self._inner));
|
||||
}
|
||||
|
||||
pub fn getRhsBatchingDimensionsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(self.inner(), @intCast(pos));
|
||||
return c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(self._inner, @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getLhsContractingDimensionsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(self.inner()));
|
||||
return @intCast(c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(self._inner));
|
||||
}
|
||||
|
||||
pub fn getLhsContractingDimensionsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(self.inner(), @intCast(pos));
|
||||
return c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(self._inner, @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getRhsContractingDimensionsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(self.inner()));
|
||||
return @intCast(c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(self._inner));
|
||||
}
|
||||
|
||||
pub fn getRhsContractingDimensionsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(self.inner(), @intCast(pos));
|
||||
return c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(self._inner, @intCast(pos));
|
||||
}
|
||||
};
|
||||
|
||||
pub const GatherDimensionNumbersAttribute = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
pub usingnamespace mlir.MlirHelpers(GatherDimensionNumbersAttribute, .{
|
||||
.is_a_fn = c.stablehloAttributeIsAGatherDimensionNumbers,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
pub const is_a_fn = c.stablehloAttributeIsAGatherDimensionNumbers;
|
||||
const Self = GatherDimensionNumbersAttribute;
|
||||
pub const asAttr = mlir.Attribute.fromAny(Self);
|
||||
pub const eql = mlir.Attribute.eqlAny(Self);
|
||||
|
||||
pub fn init(
|
||||
ctx: mlir.Context,
|
||||
@ -928,9 +915,9 @@ pub const GatherDimensionNumbersAttribute = struct {
|
||||
start_index_map: []const i64,
|
||||
index_vector_dim: i64,
|
||||
) Self {
|
||||
return Self.wrap(
|
||||
c.stablehloGatherDimensionNumbersGet(
|
||||
ctx.inner(),
|
||||
return .{
|
||||
._inner = c.stablehloGatherDimensionNumbersGet(
|
||||
ctx._inner,
|
||||
@intCast(offset_dims.len),
|
||||
offset_dims.ptr,
|
||||
@intCast(collapsed_slice_dims.len),
|
||||
@ -943,64 +930,61 @@ pub const GatherDimensionNumbersAttribute = struct {
|
||||
start_index_map.ptr,
|
||||
index_vector_dim,
|
||||
),
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
pub fn getOffsetDimsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloGatherDimensionNumbersGetOffsetDimsSize(self.inner()));
|
||||
return @intCast(c.stablehloGatherDimensionNumbersGetOffsetDimsSize(self._inner));
|
||||
}
|
||||
|
||||
pub fn getOffsetDimsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloGatherDimensionNumbersGetOffsetDimsElem(self.inner(), @intCast(pos));
|
||||
return c.stablehloGatherDimensionNumbersGetOffsetDimsElem(self._inner, @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getCollapsedSliceDimsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(self.inner()));
|
||||
return @intCast(c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(self._inner));
|
||||
}
|
||||
|
||||
pub fn getCollapsedSliceDimsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(self.inner(), @intCast(pos));
|
||||
return c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(self._inner, @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getStartIndexMapSize(self: Self) usize {
|
||||
return @intCast(c.stablehloGatherDimensionNumbersGetStartIndexMapSize(self.inner()));
|
||||
return @intCast(c.stablehloGatherDimensionNumbersGetStartIndexMapSize(self._inner));
|
||||
}
|
||||
|
||||
pub fn getOperandBatchingDimsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloGatherDimensionNumbersGetOperandBatchingDimsSize(self.inner()));
|
||||
return @intCast(c.stablehloGatherDimensionNumbersGetOperandBatchingDimsSize(self._inner));
|
||||
}
|
||||
|
||||
pub fn getOperandBatchingDimsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloGatherDimensionNumbersGetOperandBatchingDimsElem(self.inner(), @intCast(pos));
|
||||
return c.stablehloGatherDimensionNumbersGetOperandBatchingDimsElem(self._inner, @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getStartIndicesBatchingDimsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsSize(self.inner()));
|
||||
return @intCast(c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsSize(self._inner));
|
||||
}
|
||||
|
||||
pub fn getStartIndicesBatchingDimsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsElem(self.inner(), @intCast(pos));
|
||||
return c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsElem(self._inner, @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getStartIndexMapElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloGatherDimensionNumbersGetStartIndexMapElem(self.inner(), @intCast(pos));
|
||||
return c.stablehloGatherDimensionNumbersGetStartIndexMapElem(self._inner, @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getIndexVectorDim(self: Self) usize {
|
||||
return @intCast(c.stablehloGatherDimensionNumbersGetIndexVectorDim(self.inner()));
|
||||
return @intCast(c.stablehloGatherDimensionNumbersGetIndexVectorDim(self._inner));
|
||||
}
|
||||
};
|
||||
|
||||
pub const ConvDimensionNumbersAttribute = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
pub usingnamespace mlir.MlirHelpers(ConvDimensionNumbersAttribute, .{
|
||||
.is_a_fn = c.stablehloAttributeIsAConvDimensionNumbers,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
pub const is_a_fn = c.stablehloAttributeIsAConvDimensionNumbers;
|
||||
const Self = ConvDimensionNumbersAttribute;
|
||||
pub const asAttr = mlir.Attribute.fromAny(Self);
|
||||
pub const eql = mlir.Attribute.eqlAny(Self);
|
||||
|
||||
pub fn init(ctx: mlir.Context, args: struct {
|
||||
input_batch_dimension: i64,
|
||||
@ -1013,9 +997,9 @@ pub const ConvDimensionNumbersAttribute = struct {
|
||||
output_feature_dimension: i64,
|
||||
output_spatial_dimensions: []const i64,
|
||||
}) Self {
|
||||
return Self.wrap(
|
||||
c.stablehloConvDimensionNumbersGet(
|
||||
ctx.inner(),
|
||||
return .{
|
||||
._inner = c.stablehloConvDimensionNumbersGet(
|
||||
ctx._inner,
|
||||
args.input_batch_dimension,
|
||||
args.input_feature_dimension,
|
||||
@intCast(args.input_spatial_dimensions.len),
|
||||
@ -1029,67 +1013,64 @@ pub const ConvDimensionNumbersAttribute = struct {
|
||||
@intCast(args.output_spatial_dimensions.len),
|
||||
args.output_spatial_dimensions.ptr,
|
||||
),
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
pub fn getInputBatchDimension(self: Self) i64 {
|
||||
return c.stablehloConvDimensionNumbersGetInputBatchDimension(self.inner());
|
||||
return c.stablehloConvDimensionNumbersGetInputBatchDimension(self._inner);
|
||||
}
|
||||
|
||||
pub fn getInputFeatureDimension(self: Self) i64 {
|
||||
return c.stablehloConvDimensionNumbersGetInputFeatureDimension(self.inner());
|
||||
return c.stablehloConvDimensionNumbersGetInputFeatureDimension(self._inner);
|
||||
}
|
||||
|
||||
pub fn getInputSpatialDimensionsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(self.inner()));
|
||||
return @intCast(c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(self._inner));
|
||||
}
|
||||
|
||||
pub fn getInputSpatialDimensionsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(self.inner(), @intCast(pos));
|
||||
return c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(self._inner, @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getKernelInputFeatureDimension(self: Self) i64 {
|
||||
return c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension(self.inner());
|
||||
return c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension(self._inner);
|
||||
}
|
||||
|
||||
pub fn getKernelOutputFeatureDimension(self: Self) i64 {
|
||||
return c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(self.inner());
|
||||
return c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(self._inner);
|
||||
}
|
||||
|
||||
pub fn getKernelSpatialDimensionsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(self.inner()));
|
||||
return @intCast(c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(self._inner));
|
||||
}
|
||||
|
||||
pub fn getKernelSpatialDimensionsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(self.inner(), @intCast(pos));
|
||||
return c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(self._inner, @intCast(pos));
|
||||
}
|
||||
|
||||
pub fn getOutputBatchDimension(self: Self) i64 {
|
||||
return c.stablehloConvDimensionNumbersGetOutputBatchDimension(self.inner());
|
||||
return c.stablehloConvDimensionNumbersGetOutputBatchDimension(self._inner);
|
||||
}
|
||||
|
||||
pub fn getOutputFeatureDimension(self: Self) i64 {
|
||||
return c.stablehloConvDimensionNumbersGetOutputFeatureDimension(self.inner());
|
||||
return c.stablehloConvDimensionNumbersGetOutputFeatureDimension(self._inner);
|
||||
}
|
||||
|
||||
pub fn getOutputSpatialDimensionsSize(self: Self) usize {
|
||||
return @intCast(c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(self.inner()));
|
||||
return @intCast(c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(self._inner));
|
||||
}
|
||||
|
||||
pub fn getOutputSpatialDimensionsElem(self: Self, pos: usize) i64 {
|
||||
return c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(self.inner(), @intCast(pos));
|
||||
return c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(self._inner, @intCast(pos));
|
||||
}
|
||||
};
|
||||
|
||||
pub const OutputOperandAliasAttribute = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
pub usingnamespace mlir.MlirHelpers(OutputOperandAliasAttribute, .{
|
||||
.is_a_fn = c.stablehloAttributeIsAOutputOperandAlias,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
pub const is_a_fn = c.stablehloAttributeIsAOutputOperandAlias;
|
||||
pub const asAttr = mlir.Attribute.fromAny(OutputOperandAliasAttribute);
|
||||
pub const eql = mlir.Attribute.eqlAny(OutputOperandAliasAttribute);
|
||||
|
||||
pub fn init(
|
||||
ctx: mlir.Context,
|
||||
@ -1097,27 +1078,24 @@ pub const OutputOperandAliasAttribute = struct {
|
||||
operand_index: i64,
|
||||
operand_tuple_indices: []const i64,
|
||||
) OutputOperandAliasAttribute {
|
||||
return OutputOperandAliasAttribute.wrap(c.stablehloOutputOperandAliasGet(
|
||||
ctx.inner(),
|
||||
return .{ ._inner = c.stablehloOutputOperandAliasGet(
|
||||
ctx._inner,
|
||||
@intCast(output_tuple_indices.len),
|
||||
output_tuple_indices.ptr,
|
||||
@intCast(operand_index),
|
||||
@intCast(operand_tuple_indices.len),
|
||||
operand_tuple_indices.ptr,
|
||||
));
|
||||
) };
|
||||
}
|
||||
};
|
||||
|
||||
pub const PrecisionAttribute = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
pub usingnamespace mlir.MlirHelpers(PrecisionAttribute, .{
|
||||
.is_a_fn = c.stablehloAttributeIsAPrecisionAttr,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
pub const is_a_fn = c.stablehloAttributeIsAPrecisionAttr;
|
||||
const Self = PrecisionAttribute;
|
||||
pub const asAttr = mlir.Attribute.fromAny(Self);
|
||||
pub const eql = mlir.Attribute.eqlAny(Self);
|
||||
|
||||
pub const Precision = enum {
|
||||
DEFAULT,
|
||||
@ -1126,11 +1104,11 @@ pub const PrecisionAttribute = struct {
|
||||
};
|
||||
|
||||
pub fn init(ctx: mlir.Context, value: Precision) Self {
|
||||
return Self.wrap(c.stablehloPrecisionAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
|
||||
return .{ ._inner = c.stablehloPrecisionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
|
||||
}
|
||||
|
||||
pub fn getValue(self: Self) Precision {
|
||||
const value = mlir.fromStringRef(c.stablehloPrecisionAttrGetValue(self.inner()));
|
||||
const value = mlir.fromStringRef(c.stablehloPrecisionAttrGetValue(self._inner));
|
||||
return std.meta.stringToEnum(Precision, value) orelse unreachable;
|
||||
}
|
||||
};
|
||||
@ -1138,13 +1116,10 @@ pub const PrecisionAttribute = struct {
|
||||
pub const ComparisonDirection = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
pub usingnamespace mlir.MlirHelpers(ComparisonDirection, .{
|
||||
.is_a_fn = c.stablehloAttributeIsAComparisonDirectionAttr,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
pub const is_a_fn = c.stablehloAttributeIsAComparisonDirectionAttr;
|
||||
const Self = ComparisonDirection;
|
||||
pub const asAttr = mlir.Attribute.fromAny(Self);
|
||||
pub const eql = mlir.Attribute.eqlAny(Self);
|
||||
|
||||
pub const Direction = enum {
|
||||
EQ,
|
||||
@ -1156,11 +1131,11 @@ pub const ComparisonDirection = struct {
|
||||
};
|
||||
|
||||
pub fn init(ctx: mlir.Context, value: Direction) Self {
|
||||
return Self.wrap(c.stablehloComparisonDirectionAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
|
||||
return .{ ._inner = c.stablehloComparisonDirectionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
|
||||
}
|
||||
|
||||
pub fn getValue(self: Self) Direction {
|
||||
const value = mlir.fromStringRef(c.stablehloComparisonDirectionAttrGetValue(self.inner()));
|
||||
const value = mlir.fromStringRef(c.stablehloComparisonDirectionAttrGetValue(self._inner));
|
||||
return std.meta.stringToEnum(Direction, value) orelse unreachable;
|
||||
}
|
||||
};
|
||||
@ -1168,13 +1143,10 @@ pub const ComparisonDirection = struct {
|
||||
pub const CompareType = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
pub usingnamespace mlir.MlirHelpers(CompareType, .{
|
||||
.is_a_fn = c.stablehloAttributeIsAComparisonTypeAttr,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
pub const is_a_fn = c.stablehloAttributeIsAComparisonTypeAttr;
|
||||
const Self = CompareType;
|
||||
pub const asAttr = mlir.Attribute.fromAny(Self);
|
||||
pub const eql = mlir.Attribute.eqlAny(Self);
|
||||
|
||||
pub const Type = enum {
|
||||
SIGNED,
|
||||
@ -1184,11 +1156,11 @@ pub const CompareType = struct {
|
||||
};
|
||||
|
||||
pub fn init(ctx: mlir.Context, value: Type) Self {
|
||||
return Self.wrap(c.stablehloComparisonTypeAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
|
||||
return .{ ._inner = c.stablehloComparisonTypeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
|
||||
}
|
||||
|
||||
pub fn getValue(self: Self) Type {
|
||||
const value = mlir.fromStringRef(c.stablehloComparisonTypeAttrGetValue(self.inner()));
|
||||
const value = mlir.fromStringRef(c.stablehloComparisonTypeAttrGetValue(self._inner));
|
||||
return std.meta.stringToEnum(Type, value) orelse unreachable;
|
||||
}
|
||||
};
|
||||
@ -1196,13 +1168,10 @@ pub const CompareType = struct {
|
||||
pub const Transpose = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
pub usingnamespace mlir.MlirHelpers(Transpose, .{
|
||||
.is_a_fn = c.stablehloAttributeIsATransposeAttr,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
pub const is_a_fn = c.stablehloAttributeIsATransposeAttr;
|
||||
const Self = Transpose;
|
||||
pub const asAttr = mlir.Attribute.fromAny(Self);
|
||||
pub const eql = mlir.Attribute.eqlAny(Self);
|
||||
|
||||
pub const Type = enum {
|
||||
NO_TRANSPOSE,
|
||||
@ -1211,11 +1180,11 @@ pub const Transpose = struct {
|
||||
};
|
||||
|
||||
pub fn init(ctx: mlir.Context, value: Type) Self {
|
||||
return Self.wrap(c.stablehloTransposeAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
|
||||
return .{ ._inner = c.stablehloTransposeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
|
||||
}
|
||||
|
||||
pub fn getValue(self: Self) Type {
|
||||
const value = mlir.fromStringRef(c.stablehloTransposeAttrGetValue(self.inner()));
|
||||
const value = mlir.fromStringRef(c.stablehloTransposeAttrGetValue(self._inner));
|
||||
return std.meta.stringToEnum(Type, value) orelse unreachable;
|
||||
}
|
||||
};
|
||||
@ -1223,13 +1192,10 @@ pub const Transpose = struct {
|
||||
pub const FftType = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
pub usingnamespace mlir.MlirHelpers(FftType, .{
|
||||
.is_a_fn = c.stablehloAttributeIsAFftTypeAttr,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
pub const is_a_fn = c.stablehloAttributeIsAFftTypeAttr;
|
||||
const Self = FftType;
|
||||
pub const asAttr = mlir.Attribute.fromAny(Self);
|
||||
pub const eql = mlir.Attribute.eqlAny(Self);
|
||||
|
||||
pub const Type = enum {
|
||||
FFT,
|
||||
@ -1239,11 +1205,11 @@ pub const FftType = struct {
|
||||
};
|
||||
|
||||
pub fn init(ctx: mlir.Context, value: Type) Self {
|
||||
return Self.wrap(c.stablehloFftTypeAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
|
||||
return .{ ._inner = c.stablehloFftTypeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
|
||||
}
|
||||
|
||||
pub fn getValue(self: Self) Type {
|
||||
const value = mlir.fromStringRef(c.stablehloFftTypeAttrGetValue(self.inner()));
|
||||
const value = mlir.fromStringRef(c.stablehloFftTypeAttrGetValue(self._inner));
|
||||
return std.meta.stringToEnum(Type, value) orelse unreachable;
|
||||
}
|
||||
};
|
||||
@ -1251,13 +1217,10 @@ pub const FftType = struct {
|
||||
pub const RngDistribution = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
pub usingnamespace mlir.MlirHelpers(RngDistribution, .{
|
||||
.is_a_fn = c.stablehloAttributeIsARngDistributionAttr,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
pub const is_a_fn = c.stablehloAttributeIsARngDistributionAttr;
|
||||
const Self = RngDistribution;
|
||||
pub const asAttr = mlir.Attribute.fromAny(Self);
|
||||
pub const eql = mlir.Attribute.eqlAny(Self);
|
||||
|
||||
pub const Type = enum {
|
||||
UNIFORM,
|
||||
@ -1265,11 +1228,11 @@ pub const RngDistribution = struct {
|
||||
};
|
||||
|
||||
pub fn init(ctx: mlir.Context, value: Type) Self {
|
||||
return Self.wrap(c.stablehloRngDistributionAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
|
||||
return .{ ._inner = c.stablehloRngDistributionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
|
||||
}
|
||||
|
||||
pub fn getValue(self: Self) Type {
|
||||
const value = mlir.fromStringRef(c.stablehloRngDistributionAttrGetValue(self.inner()));
|
||||
const value = mlir.fromStringRef(c.stablehloRngDistributionAttrGetValue(self._inner));
|
||||
return std.meta.stringToEnum(Type, value) orelse unreachable;
|
||||
}
|
||||
};
|
||||
@ -1277,13 +1240,10 @@ pub const RngDistribution = struct {
|
||||
pub const RngAlgorithm = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
pub usingnamespace mlir.MlirHelpers(RngAlgorithm, .{
|
||||
.is_a_fn = c.stablehloAttributeIsARngAlgorithmAttr,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
pub const is_a_fn = c.stablehloAttributeIsARngAlgorithmAttr;
|
||||
const Self = RngAlgorithm;
|
||||
pub const asAttr = mlir.Attribute.fromAny(Self);
|
||||
pub const eql = mlir.Attribute.eqlAny(Self);
|
||||
|
||||
pub const Type = enum {
|
||||
DEFAULT,
|
||||
@ -1292,11 +1252,11 @@ pub const RngAlgorithm = struct {
|
||||
};
|
||||
|
||||
pub fn init(ctx: mlir.Context, value: Type) Self {
|
||||
return Self.wrap(c.stablehloRngAlgorithmAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
|
||||
return .{ ._inner = c.stablehloRngAlgorithmAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
|
||||
}
|
||||
|
||||
pub fn getValue(self: Self) Type {
|
||||
const value = mlir.fromStringRef(c.stablehloRngAlgorithmAttrGetValue(self.inner()));
|
||||
const value = mlir.fromStringRef(c.stablehloRngAlgorithmAttrGetValue(self._inner));
|
||||
return std.meta.stringToEnum(Type, value) orelse unreachable;
|
||||
}
|
||||
};
|
||||
|
||||
1045
mlir/mlir.zig
1045
mlir/mlir.zig
File diff suppressed because it is too large
Load Diff
176
zml/mlir.zig
176
zml/mlir.zig
@ -1,176 +0,0 @@
|
||||
const mlir = @This();
|
||||
|
||||
const builtin = @import("builtin");
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const dtype = @import("dtype.zig");
|
||||
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const Tensor = @import("tensor.zig").Tensor;
|
||||
|
||||
const log = std.log.scoped(.@"zml/mlir");
|
||||
|
||||
pub usingnamespace @import("mlir");
|
||||
|
||||
pub const ext = struct {
|
||||
pub fn mlirType(ctx: mlir.Context, sh: Shape) mlir.Type {
|
||||
return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type);
|
||||
}
|
||||
|
||||
pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes {
|
||||
return switch (dt) {
|
||||
.bool => .bool,
|
||||
.i8 => .i8,
|
||||
.i16 => .i16,
|
||||
.i32 => .i32,
|
||||
.i64 => .i64,
|
||||
.u8 => .u8,
|
||||
.u16 => .u16,
|
||||
.u32 => .u32,
|
||||
.u64 => .u64,
|
||||
.bf16 => .bf16,
|
||||
.f16 => .f16,
|
||||
.f32 => .f32,
|
||||
.f64 => .f64,
|
||||
else => null,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn denseElementsAttr(dt: dtype.DataType, _: usize, bytes: []const u8, ranked_type: mlir.RankedTensorType) mlir.Attribute {
|
||||
const ranked_type_ = ranked_type.as(mlir.Type);
|
||||
return switch (dt) {
|
||||
.bool => mlir.DenseElementsAttribute(.bool).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||
.i8 => mlir.DenseElementsAttribute(.i8).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||
.i16 => mlir.DenseElementsAttribute(.i16).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||
.i32 => mlir.DenseElementsAttribute(.i32).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||
.i64 => mlir.DenseElementsAttribute(.i64).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||
.u8 => mlir.DenseElementsAttribute(.u8).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||
.u16 => mlir.DenseElementsAttribute(.u16).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||
.u32 => mlir.DenseElementsAttribute(.u32).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||
.u64 => mlir.DenseElementsAttribute(.u64).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||
.bf16 => mlir.DenseElementsAttribute(.bf16).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||
.f16 => mlir.DenseElementsAttribute(.f16).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||
.f32 => mlir.DenseElementsAttribute(.f32).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||
.f64 => mlir.DenseElementsAttribute(.f64).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||
inline else => |tag| @panic("Unsupported data type: " ++ @tagName(tag)),
|
||||
};
|
||||
}
|
||||
|
||||
pub const RankedTensorType = struct {
|
||||
pub fn fromShape(ctx: mlir.Context, sh: Shape) mlir.RankedTensorType {
|
||||
return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype()));
|
||||
}
|
||||
};
|
||||
|
||||
pub const Type = struct {
|
||||
pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type {
|
||||
return switch (dt) {
|
||||
.bool => mlir.IntegerType(.i1).init(ctx).as(mlir.Type),
|
||||
.f8e4m3b11fnuz => mlir.FloatType(.f8e4m3b11fnuz).init(ctx).as(mlir.Type),
|
||||
.f8e4m3fn => mlir.FloatType(.f8e4m3fn).init(ctx).as(mlir.Type),
|
||||
.f8e4m3fnuz => mlir.FloatType(.f8e4m3fnuz).init(ctx).as(mlir.Type),
|
||||
.f8e5m2 => mlir.FloatType(.f8e5m2).init(ctx).as(mlir.Type),
|
||||
.f8e5m2fnuz => mlir.FloatType(.f8e5m2fnuz).init(ctx).as(mlir.Type),
|
||||
.bf16 => mlir.FloatType(.bf16).init(ctx).as(mlir.Type),
|
||||
.f16 => mlir.FloatType(.f16).init(ctx).as(mlir.Type),
|
||||
.f32 => mlir.FloatType(.f32).init(ctx).as(mlir.Type),
|
||||
.f64 => mlir.FloatType(.f64).init(ctx).as(mlir.Type),
|
||||
.i4 => mlir.IntegerType(.i4).init(ctx).as(mlir.Type),
|
||||
.i8 => mlir.IntegerType(.i8).init(ctx).as(mlir.Type),
|
||||
.i16 => mlir.IntegerType(.i16).init(ctx).as(mlir.Type),
|
||||
.i32 => mlir.IntegerType(.i32).init(ctx).as(mlir.Type),
|
||||
.i64 => mlir.IntegerType(.i64).init(ctx).as(mlir.Type),
|
||||
.u4 => mlir.IntegerType(.u4).init(ctx).as(mlir.Type),
|
||||
.u8 => mlir.IntegerType(.u8).init(ctx).as(mlir.Type),
|
||||
.u16 => mlir.IntegerType(.u16).init(ctx).as(mlir.Type),
|
||||
.u32 => mlir.IntegerType(.u32).init(ctx).as(mlir.Type),
|
||||
.u64 => mlir.IntegerType(.u64).init(ctx).as(mlir.Type),
|
||||
.c64 => mlir.ComplexType(.c64).init(ctx).as(mlir.Type),
|
||||
.c128 => mlir.ComplexType(.c128).init(ctx).as(mlir.Type),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn toDType(mlir_type: mlir.Type) dtype.DataType {
|
||||
const mapping = .{
|
||||
.{ .bool, mlir.IntegerType(.i1) },
|
||||
|
||||
.{ .f8e4m3b11fnuz, mlir.FloatType(.f8e4m3b11fnuz) },
|
||||
.{ .f8e4m3fn, mlir.FloatType(.f8e4m3fn) },
|
||||
.{ .f8e4m3fnuz, mlir.FloatType(.f8e4m3fnuz) },
|
||||
.{ .f8e5m2, mlir.FloatType(.f8e5m2) },
|
||||
.{ .f8e5m2fnuz, mlir.FloatType(.f8e5m2fnuz) },
|
||||
.{ .bf16, mlir.FloatType(.bf16) },
|
||||
.{ .f16, mlir.FloatType(.f16) },
|
||||
.{ .f32, mlir.FloatType(.f32) },
|
||||
.{ .f64, mlir.FloatType(.f64) },
|
||||
|
||||
.{ .i4, mlir.IntegerType(.i4) },
|
||||
.{ .i8, mlir.IntegerType(.i8) },
|
||||
.{ .i16, mlir.IntegerType(.i16) },
|
||||
.{ .i32, mlir.IntegerType(.i32) },
|
||||
.{ .i64, mlir.IntegerType(.i64) },
|
||||
|
||||
.{ .u4, mlir.IntegerType(.u4) },
|
||||
.{ .u8, mlir.IntegerType(.u8) },
|
||||
.{ .u16, mlir.IntegerType(.u16) },
|
||||
.{ .u32, mlir.IntegerType(.u32) },
|
||||
.{ .u64, mlir.IntegerType(.u64) },
|
||||
|
||||
.{ .c64, mlir.ComplexType(.c64) },
|
||||
.{ .c128, mlir.ComplexType(.c128) },
|
||||
};
|
||||
|
||||
inline for (mapping) |entry| {
|
||||
const dt, const mlirT = entry;
|
||||
if (mlir_type.is_a(mlirT)) {
|
||||
return dt;
|
||||
}
|
||||
}
|
||||
|
||||
stdx.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type});
|
||||
}
|
||||
};
|
||||
|
||||
pub const Attribute = struct {
|
||||
pub fn fromData(data: dtype.Data, ctx: mlir.Context) mlir.Attribute {
|
||||
switch (data) {
|
||||
.bool => |val| {
|
||||
return mlir.IntegerAttribute(.i1).init(ctx, @intFromBool(val)).as(mlir.Attribute);
|
||||
},
|
||||
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz => |val, tag| {
|
||||
const float_type = @field(mlir.FloatTypes, @tagName(tag));
|
||||
const float_attr = mlir.FloatAttribute(float_type).init(ctx, val.toF32());
|
||||
return float_attr.as(mlir.Attribute);
|
||||
},
|
||||
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |val, tag| {
|
||||
const int_type = @field(mlir.IntegerTypes, @tagName(tag));
|
||||
const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val));
|
||||
return int_attr.as(mlir.Attribute);
|
||||
},
|
||||
inline else => |_, tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
pub const DenseElementsAttribute = struct {
|
||||
pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute {
|
||||
return switch (data.dtype()) {
|
||||
.bool => mlir.DenseElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||
.i8 => mlir.DenseElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||
.i16 => mlir.DenseElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||
.i32 => mlir.DenseElementsAttribute(.i32).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||
.i64 => mlir.DenseElementsAttribute(.i64).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||
.u8 => mlir.DenseElementsAttribute(.u8).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||
.u16 => mlir.DenseElementsAttribute(.u16).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||
.u32 => mlir.DenseElementsAttribute(.u32).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||
.u64 => mlir.DenseElementsAttribute(.u64).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||
.bf16 => mlir.DenseElementsAttribute(.bf16).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||
.f16 => mlir.DenseElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||
.f32 => mlir.DenseElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||
.f64 => mlir.DenseElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||
inline else => |tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
|
||||
};
|
||||
}
|
||||
};
|
||||
};
|
||||
101
zml/mlirx.zig
Normal file
101
zml/mlirx.zig
Normal file
@ -0,0 +1,101 @@
|
||||
const std = @import("std");
|
||||
|
||||
const mlir = @import("mlir");
|
||||
|
||||
const dtype = @import("dtype.zig");
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
|
||||
const mlirx = @This();
|
||||
|
||||
/// Returns the mlir.Type corresponding to a given zml.Shape.
|
||||
pub fn tensorType(ctx: mlir.Context, sh: Shape) mlir.Type {
|
||||
return .tensor(sh.dims(), mlirx.Type.fromDType(ctx, sh.dtype()));
|
||||
}
|
||||
|
||||
pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes {
|
||||
return switch (dt) {
|
||||
.bool => .bool,
|
||||
.i8 => .i8,
|
||||
.i16 => .i16,
|
||||
.i32 => .i32,
|
||||
.i64 => .i64,
|
||||
.u8 => .u8,
|
||||
.u16 => .u16,
|
||||
.u32 => .u32,
|
||||
.u64 => .u64,
|
||||
.bf16 => .bf16,
|
||||
.f16 => .f16,
|
||||
.f32 => .f32,
|
||||
.f64 => .f64,
|
||||
else => null,
|
||||
};
|
||||
}
|
||||
|
||||
pub const Type = struct {
|
||||
pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type {
|
||||
return switch (dt) {
|
||||
.bool => .int(ctx, .i1),
|
||||
.f8e4m3b11fnuz => .float(ctx, .f8e4m3b11fnuz),
|
||||
.f8e4m3fn => .float(ctx, .f8e4m3fn),
|
||||
.f8e4m3fnuz => .float(ctx, .f8e4m3fnuz),
|
||||
.f8e5m2 => .float(ctx, .f8e5m2),
|
||||
.f8e5m2fnuz => .float(ctx, .f8e5m2fnuz),
|
||||
.bf16 => .float(ctx, .bf16),
|
||||
.f16 => .float(ctx, .f16),
|
||||
.f32 => .float(ctx, .f32),
|
||||
.f64 => .float(ctx, .f64),
|
||||
.i4 => .int(ctx, .i4),
|
||||
.i8 => .int(ctx, .i8),
|
||||
.i16 => .int(ctx, .i16),
|
||||
.i32 => .int(ctx, .i32),
|
||||
.i64 => .int(ctx, .i64),
|
||||
.u4 => .int(ctx, .u4),
|
||||
.u8 => .int(ctx, .u8),
|
||||
.u16 => .int(ctx, .u16),
|
||||
.u32 => .int(ctx, .u32),
|
||||
.u64 => .int(ctx, .u64),
|
||||
.c64 => .complex(ctx, .c64),
|
||||
.c128 => .complex(ctx, .c128),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn toDType(mlir_type: mlir.Type) dtype.DataType {
|
||||
const mapping = .{
|
||||
.{ .bool, mlir.IntegerType(.i1) },
|
||||
|
||||
.{ .f8e4m3b11fnuz, mlir.FloatType(.f8e4m3b11fnuz) },
|
||||
.{ .f8e4m3fn, mlir.FloatType(.f8e4m3fn) },
|
||||
.{ .f8e4m3fnuz, mlir.FloatType(.f8e4m3fnuz) },
|
||||
.{ .f8e5m2, mlir.FloatType(.f8e5m2) },
|
||||
.{ .f8e5m2fnuz, mlir.FloatType(.f8e5m2fnuz) },
|
||||
.{ .bf16, mlir.FloatType(.bf16) },
|
||||
.{ .f16, mlir.FloatType(.f16) },
|
||||
.{ .f32, mlir.FloatType(.f32) },
|
||||
.{ .f64, mlir.FloatType(.f64) },
|
||||
|
||||
.{ .i4, mlir.IntegerType(.i4) },
|
||||
.{ .i8, mlir.IntegerType(.i8) },
|
||||
.{ .i16, mlir.IntegerType(.i16) },
|
||||
.{ .i32, mlir.IntegerType(.i32) },
|
||||
.{ .i64, mlir.IntegerType(.i64) },
|
||||
|
||||
.{ .u4, mlir.IntegerType(.u4) },
|
||||
.{ .u8, mlir.IntegerType(.u8) },
|
||||
.{ .u16, mlir.IntegerType(.u16) },
|
||||
.{ .u32, mlir.IntegerType(.u32) },
|
||||
.{ .u64, mlir.IntegerType(.u64) },
|
||||
|
||||
.{ .c64, mlir.ComplexType(.c64) },
|
||||
.{ .c128, mlir.ComplexType(.c128) },
|
||||
};
|
||||
|
||||
inline for (mapping) |entry| {
|
||||
const dt, const mlirT = entry;
|
||||
if (mlirT.is_a_fn(mlir_type._inner)) {
|
||||
return dt;
|
||||
}
|
||||
}
|
||||
|
||||
std.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type});
|
||||
}
|
||||
};
|
||||
@ -2,21 +2,18 @@ const std = @import("std");
|
||||
|
||||
const asynk = @import("async");
|
||||
const dialect = @import("mlir/dialects");
|
||||
const runfiles = @import("runfiles");
|
||||
const mlir = @import("mlir");
|
||||
const stdx = @import("stdx");
|
||||
const xla_pb = @import("//xla:xla_proto");
|
||||
|
||||
const BaseExe = @import("exe.zig").BaseExe;
|
||||
const Buffer = @import("buffer.zig").Buffer;
|
||||
const Bufferized = @import("tensor.zig").Bufferized;
|
||||
const meta = @import("meta.zig");
|
||||
const mlir = @import("mlir.zig");
|
||||
const Location = mlir.Location;
|
||||
const mlirx = @import("mlirx.zig");
|
||||
const ops = @import("ops.zig");
|
||||
const pjrt = @import("pjrtx.zig");
|
||||
const Platform = @import("platform.zig").Platform;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const ShapeOf = @import("tensor.zig").ShapeOf;
|
||||
const Target = @import("platform.zig").Target;
|
||||
const Tensor = @import("tensor.zig").Tensor;
|
||||
const Tracer = @import("tools/tracer.zig").Tracer;
|
||||
@ -170,8 +167,8 @@ pub const CompilationContext = struct {
|
||||
|
||||
const sharding = self._platform.sharding();
|
||||
const mlir_ctx = self._mlir_ctx;
|
||||
module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr());
|
||||
module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr());
|
||||
module.op().setAttributeByName("mhlo.num_replicas", .int(mlir_ctx, .i32, sharding.num_replicas));
|
||||
module.op().setAttributeByName("mhlo.num_partitions", .int(mlir_ctx, .i32, sharding.num_partitions));
|
||||
|
||||
const module_hash = computeModuleHash(self._platform, module);
|
||||
var module_dir: ?[]const u8 = null;
|
||||
@ -346,7 +343,7 @@ pub const CompilationContext = struct {
|
||||
stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
|
||||
|
||||
const input_types = try arena.alloc(mlir.Type, tensor_count);
|
||||
for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh);
|
||||
for (input_types, input_shapes.items) |*t, sh| t.* = mlirx.tensorType(mlir_ctx, sh);
|
||||
|
||||
const og_block_args = self._block_args;
|
||||
defer {
|
||||
@ -947,7 +944,7 @@ pub fn fillMlirTypes(v: anytype, mlir_ctx: mlir.Context, types: []mlir.Type) voi
|
||||
var context = LocalContext{ .mlir_ctx = mlir_ctx, .types = types };
|
||||
meta.visit((struct {
|
||||
fn cb(inner_context: *LocalContext, tensor: *const Tensor) void {
|
||||
inner_context.types[inner_context.index] = mlir.ext.mlirType(inner_context.mlir_ctx, tensor.shape());
|
||||
inner_context.types[inner_context.index] = mlirx.tensorType(inner_context.mlir_ctx, tensor.shape());
|
||||
inner_context.index += 1;
|
||||
}
|
||||
}).cb, &context, v);
|
||||
|
||||
@ -5,7 +5,7 @@ const dialect = @import("mlir/dialects");
|
||||
const Context = @import("../context.zig").Context;
|
||||
const DataType = @import("../dtype.zig").DataType;
|
||||
const Data = @import("../dtype.zig").Data;
|
||||
const mlir = @import("../mlir.zig");
|
||||
const mlirx = @import("../mlirx.zig");
|
||||
const module = @import("../module.zig");
|
||||
const CompilationContext = module.CompilationContext;
|
||||
const SdpaOpts = @import("../nn.zig").SdpaOpts;
|
||||
@ -130,7 +130,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
||||
.api_version = .original,
|
||||
},
|
||||
&.{
|
||||
mlir.ext.mlirType(mlir_ctx, q.shape()),
|
||||
mlirx.tensorType(mlir_ctx, q.shape()),
|
||||
.tensor(&.{0}, .int(mlir_ctx, .u8)),
|
||||
},
|
||||
loc,
|
||||
|
||||
47
zml/ops.zig
47
zml/ops.zig
@ -1,24 +1,16 @@
|
||||
const std = @import("std");
|
||||
const assert = std.debug.assert;
|
||||
|
||||
const mlir = @import("mlir");
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const _collectAxes = @import("tensor.zig")._collectAxes;
|
||||
const buffer = @import("buffer.zig");
|
||||
const Buffer = buffer.Buffer;
|
||||
const Bufferized = @import("tensor.zig").Bufferized;
|
||||
const Buffer = @import("buffer.zig").Buffer;
|
||||
const CompilationContext = @import("module.zig").CompilationContext;
|
||||
const Context = @import("context.zig").Context;
|
||||
const Data = @import("dtype.zig").Data;
|
||||
const DataType = @import("dtype.zig").DataType;
|
||||
const helpers = @import("helpers.zig");
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
const meta = @import("meta.zig");
|
||||
const mlir = @import("mlir.zig");
|
||||
const module = @import("module.zig");
|
||||
const CompilationContext = module.CompilationContext;
|
||||
const mlirx = @import("mlirx.zig");
|
||||
const Platform = @import("platform.zig").Platform;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const ShapeOf = @import("tensor.zig").ShapeOf;
|
||||
const Tensor = @import("tensor.zig").Tensor;
|
||||
|
||||
const EnumLiteral = @TypeOf(.enum_literal);
|
||||
@ -200,14 +192,14 @@ pub fn reduce(
|
||||
mlir_ctx,
|
||||
val,
|
||||
inner_ctx.broadcasting_axes[0 .. tensor.rank() - inner_ctx.n_reduced],
|
||||
mlir.ext.RankedTensorType.fromShape(mlir_ctx, reduced_shape).as(mlir.Type),
|
||||
mlirx.tensorType(mlir_ctx, reduced_shape),
|
||||
inner_ctx.loc,
|
||||
);
|
||||
tensor.* = Tensor._result(reduced_shape, broad_val.result(0));
|
||||
inner_ctx.index += 1;
|
||||
}
|
||||
}).cb, &local_context, &res);
|
||||
assert(local_context.index == op.numResults());
|
||||
std.debug.assert(local_context.index == op.numResults());
|
||||
return res;
|
||||
}
|
||||
|
||||
@ -248,7 +240,8 @@ pub fn reduceWindow(
|
||||
.{ "window_strides", .dense(ctx.mlirCtx(), .i64, opts.window_strides) },
|
||||
.{ "base_dilations", .dense(ctx.mlirCtx(), .i64, opts.base_dilations) },
|
||||
.{ "window_dilations", .dense(ctx.mlirCtx(), .i64, opts.window_dilations) },
|
||||
.{ "padding", .denseElements(ctx.mlirCtx(), &.{ @intCast(opts.padding.len), 2 }, .i64, opts.padding) },
|
||||
// Cast the [][2]i64 to []i64 (safe)
|
||||
.{ "padding", .denseElements(ctx.mlirCtx(), &.{ @intCast(opts.padding.len), 2 }, .i64, @ptrCast(opts.padding)) },
|
||||
},
|
||||
.location = loc,
|
||||
});
|
||||
@ -609,8 +602,8 @@ pub fn sort(
|
||||
.result_type_inference = true,
|
||||
.blocks = &.{block},
|
||||
.attributes = &.{
|
||||
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx.mlirCtx(), dimension).as(mlir.Attribute) },
|
||||
.{ "is_stable", mlir.BoolAttribute.init(ctx.mlirCtx(), is_stable).as(mlir.Attribute) },
|
||||
.{ "dimension", .int(ctx.mlirCtx(), .i64, dimension) },
|
||||
.{ "is_stable", .boolean(ctx.mlirCtx(), is_stable) },
|
||||
},
|
||||
.location = loc,
|
||||
});
|
||||
@ -767,7 +760,7 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base
|
||||
inner_ctx.index += 1;
|
||||
}
|
||||
}).cb, &context, &res);
|
||||
assert(context.index == op.numResults());
|
||||
std.debug.assert(context.index == op.numResults());
|
||||
return res;
|
||||
}
|
||||
|
||||
@ -817,7 +810,7 @@ pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]T
|
||||
|
||||
var res_types: [outputs.len]mlir.Type = undefined;
|
||||
inline for (outputs, 0..) |output, i| {
|
||||
res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output);
|
||||
res_types[i] = mlirx.tensorType(ctx.mlirCtx(), output);
|
||||
}
|
||||
|
||||
const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &.{
|
||||
@ -1031,7 +1024,7 @@ pub fn scatter(
|
||||
inner_ctx.index += 1;
|
||||
}
|
||||
}).cb, &local_context, &res);
|
||||
assert(local_context.index == op.numResults());
|
||||
std.debug.assert(local_context.index == op.numResults());
|
||||
return res;
|
||||
}
|
||||
|
||||
@ -1327,30 +1320,30 @@ pub fn customCall(target_name: [:0]const u8, inputs: anytype, outputs: anytype,
|
||||
}
|
||||
|
||||
fn customCallInternal(target_name: [:0]const u8, inputs: []const Tensor, outputs: []const Shape, metadata: anytype, opts: CustomCallOptions) []Tensor {
|
||||
const ctx = module.CompilationContext.current();
|
||||
const ctx = CompilationContext.current();
|
||||
|
||||
const values = ctx.allocator().alloc(mlir.Value, inputs.len) catch unreachable;
|
||||
ctx.extractValues(inputs, values);
|
||||
|
||||
const res_types = ctx.allocator().alloc(mlir.Type, outputs.len) catch unreachable;
|
||||
for (outputs, 0..) |output, i| {
|
||||
res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output);
|
||||
res_types[i] = mlirx.tensorType(ctx.mlirCtx(), output);
|
||||
}
|
||||
|
||||
const metadata_type_info = @typeInfo(@TypeOf(metadata));
|
||||
var metadata_attributes_tuple: [metadata_type_info.@"struct".fields.len]mlir.AttrTuple = undefined;
|
||||
inline for (metadata_type_info.@"struct".fields, 0..) |field, i| {
|
||||
const attribute: mlir.Attribute = switch (@typeInfo(field.type)) {
|
||||
.int, .comptime_int => mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@field(metadata, field.name))),
|
||||
.int, .comptime_int => .int(ctx.mlirCtx(), .u64, @bitCast(@field(metadata, field.name))),
|
||||
else => @compileError("Unsupported metadata type: " ++ @typeName(field.type)),
|
||||
};
|
||||
metadata_attributes_tuple[i] = .{ field.name, attribute };
|
||||
}
|
||||
|
||||
const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &(.{
|
||||
.{ "pjrt_api", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_api))) },
|
||||
.{ "pjrt_client", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_client))) },
|
||||
} ++ metadata_attributes_tuple));
|
||||
const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &(metadata_attributes_tuple ++ [_]mlir.AttrTuple{
|
||||
.{ "pjrt_api", .int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_api))) },
|
||||
.{ "pjrt_client", .int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_client))) },
|
||||
}));
|
||||
|
||||
const operands_layouts = ctx.allocator().alloc([]const usize, inputs.len) catch unreachable;
|
||||
for (inputs, 0..) |input, i| {
|
||||
|
||||
104
zml/tensor.zig
104
zml/tensor.zig
@ -1,20 +1,17 @@
|
||||
const std = @import("std");
|
||||
const assert = std.debug.assert;
|
||||
const testing = std.testing;
|
||||
const builtin = @import("builtin");
|
||||
|
||||
const mlir = @import("mlir");
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const Buffer = @import("buffer.zig").Buffer;
|
||||
const CompilationContext = @import("module.zig").CompilationContext;
|
||||
const Data = @import("dtype.zig").Data;
|
||||
const DataType = @import("dtype.zig").DataType;
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
const Memory = @import("buffer.zig").Buffer.Memory;
|
||||
const meta = @import("meta.zig");
|
||||
const mlir = @import("mlir.zig");
|
||||
const Location = mlir.Location;
|
||||
const module = @import("module.zig");
|
||||
const CompilationContext = module.CompilationContext;
|
||||
const mlirx = @import("mlirx.zig");
|
||||
const ops = @import("ops.zig");
|
||||
const Platform = @import("platform.zig").Platform;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
@ -112,12 +109,12 @@ pub const Tensor = struct {
|
||||
///
|
||||
/// The shape is derived from the type of the mlir.Value.
|
||||
pub fn fromMlirValue(val: mlir.Value) Tensor {
|
||||
const ranked_tensor = val.getType().as(mlir.RankedTensorType);
|
||||
const ranked_tensor = val.getType().as(mlir.RankedTensorType).?;
|
||||
const n = ranked_tensor.getRank();
|
||||
|
||||
stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK });
|
||||
|
||||
var sh: Shape = .{ ._dtype = mlir.ext.Type.toDType(ranked_tensor.getElementType()) };
|
||||
var sh: Shape = .{ ._dtype = mlirx.Type.toDType(ranked_tensor.getElementType()) };
|
||||
for (0..n) |i| {
|
||||
sh._dims.appendAssumeCapacity(ranked_tensor.getDimension(i));
|
||||
}
|
||||
@ -322,7 +319,7 @@ pub const Tensor = struct {
|
||||
const op = dialect.stablehlo.bitcast_convert(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type),
|
||||
mlirx.tensorType(self.getContext().mlirCtx(), res_shape),
|
||||
loc,
|
||||
);
|
||||
|
||||
@ -559,8 +556,8 @@ pub const Tensor = struct {
|
||||
ctx.mlirCtx(),
|
||||
self.algorithm,
|
||||
self._state.value(),
|
||||
mlir.ext.mlirType(ctx.mlirCtx(), self._state._shape),
|
||||
mlir.ext.mlirType(ctx.mlirCtx(), sh),
|
||||
mlirx.tensorType(ctx.mlirCtx(), self._state._shape),
|
||||
mlirx.tensorType(ctx.mlirCtx(), sh),
|
||||
loc,
|
||||
);
|
||||
return .{ self.update(op.result(0)), _result(sh, op.result(1)) };
|
||||
@ -870,7 +867,7 @@ pub const Tensor = struct {
|
||||
self.value(),
|
||||
other.value(),
|
||||
used_opts,
|
||||
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type),
|
||||
mlirx.tensorType(self.getContext().mlirCtx(), new_shape),
|
||||
loc,
|
||||
);
|
||||
|
||||
@ -1052,7 +1049,7 @@ pub const Tensor = struct {
|
||||
const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
|
||||
|
||||
const mlir_ctx = self.getContext().mlirCtx();
|
||||
const res_type = mlir.ext.mlirType(mlir_ctx, self.shape().withDtype(to));
|
||||
const res_type = mlirx.tensorType(mlir_ctx, self.shape().withDtype(to));
|
||||
const op = dialect.stablehlo.convert(mlir_ctx, self.value(), res_type, loc);
|
||||
return _result(self._shape.withDtype(to), op.result(0));
|
||||
}
|
||||
@ -1217,7 +1214,7 @@ pub const Tensor = struct {
|
||||
mlir_ctx,
|
||||
lhs.value(),
|
||||
rhs.value(),
|
||||
mlir.ext.mlirType(mlir_ctx, res_shape),
|
||||
mlirx.tensorType(mlir_ctx, res_shape),
|
||||
loc,
|
||||
.{
|
||||
.lhs_batching_dimensions = lhs_batching_axes.constSlice(),
|
||||
@ -1392,7 +1389,7 @@ pub const Tensor = struct {
|
||||
[2][5]f32{ .{ 0, 1, 1, 0, 1 }, .{ 3, 1, 0, 2, 1 } },
|
||||
);
|
||||
const res = try zml.testing.compileAndCall(platform, Local._cumsum, .{x});
|
||||
try testing.expectEqual(
|
||||
try std.testing.expectEqual(
|
||||
[2][5]f32{ .{ 0, 1, 2, 2, 3 }, .{ 3, 4, 4, 6, 7 } },
|
||||
try res.getValue([2][5]f32),
|
||||
);
|
||||
@ -1424,7 +1421,7 @@ pub const Tensor = struct {
|
||||
const op = dialect.stablehlo.transpose(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
mlir.ext.mlirType(self.getContext().mlirCtx(), res_shape),
|
||||
mlirx.tensorType(self.getContext().mlirCtx(), res_shape),
|
||||
loc,
|
||||
.{ .permutation = toI64(permutation) },
|
||||
);
|
||||
@ -1457,7 +1454,7 @@ pub const Tensor = struct {
|
||||
const reshaped_val = dialect.stablehlo.reshape(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape),
|
||||
mlirx.tensorType(self.getContext().mlirCtx(), new_shape),
|
||||
loc,
|
||||
);
|
||||
return _result(new_shape, reshaped_val.result(0));
|
||||
@ -1474,7 +1471,7 @@ pub const Tensor = struct {
|
||||
const reshaped_val = dialect.stablehlo.reshape(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape),
|
||||
mlirx.tensorType(self.getContext().mlirCtx(), new_shape),
|
||||
loc,
|
||||
);
|
||||
return _result(new_shape, reshaped_val.result(0));
|
||||
@ -1512,7 +1509,7 @@ pub const Tensor = struct {
|
||||
const reshaped_val = dialect.stablehlo.reshape(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape),
|
||||
mlirx.tensorType(self.getContext().mlirCtx(), new_shape),
|
||||
loc,
|
||||
);
|
||||
// log.debug("flatten({d}, {d}) -> {d}", .{ self.dims(), axis_, new_shape[0 .. self.rank() - 1] });
|
||||
@ -1586,7 +1583,7 @@ pub const Tensor = struct {
|
||||
|
||||
const mlir_ctx = self.getContext().mlirCtx();
|
||||
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices});
|
||||
const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type);
|
||||
const result_type = mlirx.tensorType(mlir_ctx, res_shape);
|
||||
const slice_op = dialect.stablehlo.slice(
|
||||
mlir_ctx,
|
||||
self.value(),
|
||||
@ -1620,15 +1617,15 @@ pub const Tensor = struct {
|
||||
|
||||
{
|
||||
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 0, .{ .end = 1 } });
|
||||
try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32));
|
||||
try std.testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32));
|
||||
}
|
||||
{
|
||||
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 1, .{ .start = 1, .step = 2 } });
|
||||
try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32));
|
||||
try std.testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32));
|
||||
}
|
||||
{
|
||||
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, -1, .{ .start = -2 } });
|
||||
try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32));
|
||||
try std.testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1838,7 +1835,7 @@ pub const Tensor = struct {
|
||||
|
||||
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
|
||||
const sh = Shape.init(.{n_steps}, dt);
|
||||
var op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc);
|
||||
var op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlirx.tensorType(ctx.mlirCtx(), sh), loc);
|
||||
var res = _result(sh, op.result(0));
|
||||
|
||||
if (args.step != 1) {
|
||||
@ -1868,7 +1865,7 @@ pub const Tensor = struct {
|
||||
var op = dialect.stablehlo.iota(
|
||||
mlir_ctx,
|
||||
a,
|
||||
mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type),
|
||||
mlirx.tensorType(mlir_ctx, res_shape),
|
||||
loc,
|
||||
);
|
||||
return _result(res_shape, op.result(0));
|
||||
@ -1890,7 +1887,7 @@ pub const Tensor = struct {
|
||||
const loc = ctx.location(@src(), "linspace({}, dtype={})", .{ args, dt });
|
||||
|
||||
const sh = Shape.init(.{args.steps}, dt);
|
||||
var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc);
|
||||
var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlirx.tensorType(ctx.mlirCtx(), sh), loc);
|
||||
var res = _result(sh, iota_op.result(0));
|
||||
|
||||
if (args.steps != 1) {
|
||||
@ -1933,21 +1930,19 @@ pub const Tensor = struct {
|
||||
/// Returns a constant Tensor with the given value.
|
||||
pub fn constant(dimz: anytype, val: Data) Tensor {
|
||||
const sh = Shape.init(dimz, val.dtype());
|
||||
const singleton_sh = Shape.init(.{}, val.dtype());
|
||||
const ctx = CompilationContext.current().mlirCtx();
|
||||
const loc = CompilationContext.current().location(@src(), "dims={d}, value={}", .{ sh, val });
|
||||
const res_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh);
|
||||
|
||||
var constant_op = if (mlir.ext.denseElementAttrType(val.dtype())) |elem_type|
|
||||
dialect.stablehlo.constant(ctx, res_type, elem_type, val.constSlice(), loc)
|
||||
var constant_op = if (mlirx.denseElementAttrType(val.dtype())) |elem_type|
|
||||
dialect.stablehlo.constant(ctx, &.{}, elem_type, val.constSlice(), loc)
|
||||
else blk: {
|
||||
// Not all dtype can be serialized in the IR. If that's not possible, use f32.
|
||||
const val_f32 = val.as(f32);
|
||||
break :blk dialect.stablehlo.constant(ctx, res_type, .f32, std.mem.asBytes(&val_f32), loc);
|
||||
break :blk dialect.stablehlo.constant(ctx, &.{}, .f32, std.mem.asBytes(&val_f32), loc);
|
||||
};
|
||||
|
||||
if (sh.rank() > 0) {
|
||||
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type), loc);
|
||||
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlirx.tensorType(ctx, sh), loc);
|
||||
}
|
||||
return _result(sh, constant_op.result(0)).convert(val.dtype());
|
||||
}
|
||||
@ -1955,10 +1950,9 @@ pub const Tensor = struct {
|
||||
/// Embeds a buffer with concrete values into an Mlir program.
|
||||
pub fn constantTensor(val: HostBuffer) Tensor {
|
||||
const ctx = CompilationContext.current().mlirCtx();
|
||||
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape());
|
||||
const loc = ctx.location(@src());
|
||||
const elem_type = mlir.ext.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()});
|
||||
const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.bytes(), loc);
|
||||
const elem_type = mlirx.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()});
|
||||
const constant_op = dialect.stablehlo.constant(ctx, val.shape().dims(), elem_type, val.bytes(), loc);
|
||||
return _result(val.shape(), constant_op.result(0));
|
||||
}
|
||||
|
||||
@ -1994,7 +1988,7 @@ pub const Tensor = struct {
|
||||
return _result(res_shape, self.value());
|
||||
}
|
||||
const ctx = self.getContext();
|
||||
const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type);
|
||||
const result_type = mlirx.tensorType(ctx.mlirCtx(), res_shape);
|
||||
const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ });
|
||||
const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc);
|
||||
|
||||
@ -2052,7 +2046,7 @@ pub const Tensor = struct {
|
||||
/// Reshapes the input Tensor with the given shape.
|
||||
pub fn reshape(self: Tensor, output_shape_: anytype) Tensor {
|
||||
const output_shape = self._shape.reshape(output_shape_);
|
||||
const tensor_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), output_shape);
|
||||
const tensor_type = mlirx.tensorType(self.getContext().mlirCtx(), output_shape);
|
||||
const loc = self.getContext().location(@src(), "reshape({any})", .{output_shape});
|
||||
const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc);
|
||||
return _result(output_shape, reshape_value.result(0));
|
||||
@ -2846,9 +2840,9 @@ pub const Tensor = struct {
|
||||
const res = argmax.call(.{x});
|
||||
const max_ = res.values.getValue(f32);
|
||||
const max_idx = res.indices.getValue(i32);
|
||||
try testing.expectEqual(max_, 7.9);
|
||||
try std.testing.expectEqual(max_, 7.9);
|
||||
// We should always return the first max found.
|
||||
try testing.expectEqual(max_idx, 2);
|
||||
try std.testing.expectEqual(max_idx, 2);
|
||||
}
|
||||
|
||||
// Test with Nan
|
||||
@ -2857,8 +2851,8 @@ pub const Tensor = struct {
|
||||
const res = argmax.call(.{x});
|
||||
const max_ = try res.values.getValue(f32);
|
||||
const max_idx = try res.indices.getValue(i32);
|
||||
try testing.expect(std.math.isNan(max_));
|
||||
try testing.expectEqual(max_idx, 1);
|
||||
try std.testing.expect(std.math.isNan(max_));
|
||||
try std.testing.expectEqual(max_idx, 1);
|
||||
}
|
||||
}
|
||||
|
||||
@ -2907,7 +2901,7 @@ pub const Tensor = struct {
|
||||
const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ -0.9264, 0.7156, 1.0202, 0.3992, 1.2349, 1.0003, -0.1932, 1.3935, 0.7316, 0.0851 });
|
||||
const res = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{} });
|
||||
const res_cpu = try res.toHostAlloc(allocator);
|
||||
try testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32));
|
||||
try std.testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32));
|
||||
}
|
||||
// 3D Tensor, dim = 1, descending
|
||||
{
|
||||
@ -2920,7 +2914,7 @@ pub const Tensor = struct {
|
||||
});
|
||||
const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{ .descending = true } });
|
||||
const res = try res_dev.toHostAlloc(allocator);
|
||||
try testing.expectEqualSlices(i32, &.{
|
||||
try std.testing.expectEqualSlices(i32, &.{
|
||||
4, 1, 1, 2, 0, 2, 0, 0, 3, 4,
|
||||
2, 0, 4, 4, 1, 3, 4, 4, 1, 0,
|
||||
1, 4, 2, 0, 2, 4, 2, 2, 0, 3,
|
||||
@ -2942,7 +2936,7 @@ pub const Tensor = struct {
|
||||
});
|
||||
const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 3, .{} });
|
||||
const res = try res_dev.toHostAlloc(allocator);
|
||||
try testing.expectEqualSlices(i32, &.{
|
||||
try std.testing.expectEqualSlices(i32, &.{
|
||||
2, 1, 3, 0,
|
||||
2, 3, 1, 0,
|
||||
3, 2, 0, 1,
|
||||
@ -3262,7 +3256,7 @@ pub const Tensor = struct {
|
||||
const z = try zml.Buffer.scalar(platform, 4, .i32);
|
||||
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, .{ .len = 2, .start = z } });
|
||||
|
||||
try testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T));
|
||||
try std.testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T));
|
||||
}
|
||||
|
||||
{
|
||||
@ -3271,7 +3265,7 @@ pub const Tensor = struct {
|
||||
const z = try zml.Buffer.scalar(platform, 3, .i32);
|
||||
|
||||
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, .{ .len = 2, .start = z } });
|
||||
try testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T));
|
||||
try std.testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T));
|
||||
}
|
||||
}
|
||||
|
||||
@ -3389,7 +3383,7 @@ pub const Tensor = struct {
|
||||
}._fwd,
|
||||
.{ x.withTags(.{.a}), .{ .a = idx }, y.withTags(.{.a}) },
|
||||
);
|
||||
try testing.expectEqual([10]f32{ 0, 1, 2, 3, -1, -1, 6, 7, 8, 9 }, try res.getValue([10]f32));
|
||||
try std.testing.expectEqual([10]f32{ 0, 1, 2, 3, -1, -1, 6, 7, 8, 9 }, try res.getValue([10]f32));
|
||||
}
|
||||
|
||||
{
|
||||
@ -3407,7 +3401,7 @@ pub const Tensor = struct {
|
||||
}._fwd,
|
||||
.{ x.withTags(.{ .a, .b }), idx, y.withTags(.{.a}) },
|
||||
);
|
||||
try testing.expectEqualDeep(
|
||||
try std.testing.expectEqualDeep(
|
||||
[2][5]f32{ .{ 0, 1, 2, -1, 4 }, .{ 5, 6, 7, -1, 9 } },
|
||||
try res.getValue([2][5]f32),
|
||||
);
|
||||
@ -3427,7 +3421,7 @@ pub const Tensor = struct {
|
||||
}._fwd,
|
||||
.{ x, idx, y },
|
||||
);
|
||||
try testing.expectEqualDeep(
|
||||
try std.testing.expectEqualDeep(
|
||||
[2][5]f32{ .{ 0, 1, 2, -1, 4 }, .{ 5, 6, 7, -1, 9 } },
|
||||
res.getValue([2][5]f32),
|
||||
);
|
||||
@ -3448,7 +3442,7 @@ pub const Tensor = struct {
|
||||
}._fwd,
|
||||
.{ x.withTags(.{ .a, .b }), .{ .a = idx_a, .b = idx_b }, y.withTags(.{.a}) },
|
||||
);
|
||||
try testing.expectEqualDeep(
|
||||
try std.testing.expectEqualDeep(
|
||||
[2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } },
|
||||
res.getValue([2][5]f32),
|
||||
);
|
||||
@ -3466,7 +3460,7 @@ pub const Tensor = struct {
|
||||
}
|
||||
};
|
||||
const res = try zml.testing.compileAndCall(platform, A._fwd, .{ x, .{ idx_a, idx_b }, y });
|
||||
try testing.expectEqualDeep(
|
||||
try std.testing.expectEqualDeep(
|
||||
[2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } },
|
||||
res.getValue([2][5]f32),
|
||||
);
|
||||
@ -3531,7 +3525,7 @@ pub const Tensor = struct {
|
||||
const x = try zml.Buffer.fromArray(platform, [2][2]u8{ .{ 1, 2 }, .{ 3, 4 } });
|
||||
{
|
||||
const res = try zml.testing.compileAndCall(platform, Local._toDiag, .{x});
|
||||
try testing.expectEqual(
|
||||
try std.testing.expectEqual(
|
||||
[2][2][2]u8{ .{
|
||||
.{ 1, 0 },
|
||||
.{ 0, 2 },
|
||||
@ -3582,7 +3576,7 @@ pub const Tensor = struct {
|
||||
});
|
||||
{
|
||||
const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 0 });
|
||||
try testing.expectEqual(
|
||||
try std.testing.expectEqual(
|
||||
[3][3]u8{
|
||||
.{ 1, 0, 0 },
|
||||
.{ 1, 1, 0 },
|
||||
@ -3593,7 +3587,7 @@ pub const Tensor = struct {
|
||||
}
|
||||
{
|
||||
const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 1 });
|
||||
try testing.expectEqual(
|
||||
try std.testing.expectEqual(
|
||||
[3][3]u8{
|
||||
.{ 1, 1, 0 },
|
||||
.{ 1, 1, 1 },
|
||||
@ -3604,7 +3598,7 @@ pub const Tensor = struct {
|
||||
}
|
||||
{
|
||||
const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, -1 });
|
||||
try testing.expectEqual(
|
||||
try std.testing.expectEqual(
|
||||
[3][3]u8{
|
||||
.{ 0, 0, 0 },
|
||||
.{ 1, 0, 0 },
|
||||
|
||||
@ -25,7 +25,7 @@ pub const nn = @import("nn.zig");
|
||||
pub const module = @import("module.zig");
|
||||
pub const meta = @import("meta.zig");
|
||||
pub const platform = @import("platform.zig");
|
||||
pub const mlir = @import("mlir.zig");
|
||||
pub const mlir = @import("mlirx.zig");
|
||||
pub const pjrt = @import("pjrtx.zig");
|
||||
pub const testing = @import("testing.zig");
|
||||
pub const torch = @import("torch.zig");
|
||||
|
||||
Loading…
Reference in New Issue
Block a user