diff --git a/sharktank/sharktank/types/sharding.py b/sharktank/sharktank/types/sharding.py
index ff925ec9c..81d2f31a5 100644
--- a/sharktank/sharktank/types/sharding.py
+++ b/sharktank/sharktank/types/sharding.py
@@ -143,48 +143,43 @@ def theta_sharding(self) -> ThetaSharding:
         )
 
 
-class LinearSplitParallelWeightAndBiasSharding(ThetaLayerSharding):
-    def __init__(self, shard_count: int, weight_and_bias_spit_dim: int = 0):
-        """Split one parallel dimension for both the weight and bias.
-        Since the weight is transposed before multiplying, the weight parallel
-        dimension is the same as the output(bias) dimension."""
+class LinearLayerSharding(ThetaLayerSharding):
+    def __init__(
+        self, premul_input: TensorSharding, weight: TensorSharding, bias: TensorSharding
+    ):
         super().__init__()
-        self.shard_count = shard_count
-        self.weight_and_bias_spit_dim = weight_and_bias_spit_dim
+        self.premul_input = premul_input
+        self.weight = weight
+        self.bias = bias
 
     def theta_sharding(self) -> ThetaSharding:
         return ThetaSharding(
             {
-                "premul_input": Replicated(shard_count=self.shard_count),
-                "weight": Split(
-                    shard_count=self.shard_count,
-                    shard_dim=self.weight_and_bias_spit_dim,
-                ),
-                "bias": Split(
-                    shard_count=self.shard_count,
-                    shard_dim=self.weight_and_bias_spit_dim,
-                ),
+                "premul_input": self.premul_input,
+                "weight": self.weight,
+                "bias": self.bias,
             }
         )
 
 
-class LinearSplitReductionDimSharding(ThetaLayerSharding):
-    def __init__(self, shard_count: int):
-        super().__init__()
-        self.shard_count = shard_count
+class LinearSplitParallelWeightAndBiasSharding(LinearLayerSharding):
+    def __init__(self, shard_count: int, weight_and_bias_spit_dim: int = 0):
+        """Split one parallel dimension for both the weight and bias.
+        Since the weight is transposed before multiplying, the weight parallel
+        dimension is the same as the output(bias) dimension."""
+        super().__init__(
+            premul_input=Replicated(shard_count=shard_count),
+            weight=Split(shard_count=shard_count, shard_dim=weight_and_bias_spit_dim),
+            bias=Split(shard_count=shard_count, shard_dim=weight_and_bias_spit_dim),
+        )
 
-    def theta_sharding(self) -> ThetaSharding:
-        return ThetaSharding(
-            {
-                "premul_input": Replicated(shard_count=self.shard_count),
-                "weight": Split(
-                    shard_count=self.shard_count,
-                    shard_dim=1,
-                ),
-                "bias": Replicated(
-                    shard_count=self.shard_count,
-                ),
-            }
+
+class LinearSplitReductionDimSharding(LinearLayerSharding):
+    def __init__(self, shard_count: int):
+        super().__init__(
+            premul_input=Replicated(shard_count=shard_count),
+            weight=Split(shard_count=shard_count, shard_dim=1),
+            bias=Replicated(shard_count=shard_count),
         )
 
 
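For context, a minimal usage sketch of the refactored hierarchy (not part of the patch; it assumes LinearLayerSharding, LinearSplitParallelWeightAndBiasSharding, Replicated, and Split are importable from sharktank.types.sharding, as the file path suggests). The concrete subclasses now only choose the per-tensor TensorSharding specs and delegate theta_sharding() to the shared base class, so a hand-written spec and the convenience subclass describe the same layout:

    from sharktank.types.sharding import (
        LinearLayerSharding,
        LinearSplitParallelWeightAndBiasSharding,
        Replicated,
        Split,
    )

    shard_count = 4

    # Spelling out the three TensorSharding specs with the new base class...
    explicit = LinearLayerSharding(
        premul_input=Replicated(shard_count=shard_count),
        weight=Split(shard_count=shard_count, shard_dim=0),
        bias=Split(shard_count=shard_count, shard_dim=0),
    )

    # ...matches the existing convenience subclass, which now just forwards
    # its weight/bias split to LinearLayerSharding.
    convenience = LinearSplitParallelWeightAndBiasSharding(shard_count=shard_count)

    # Both produce a ThetaSharding mapping "premul_input", "weight", and
    # "bias" to their respective TensorSharding specs.
    print(explicit.theta_sharding())
    print(convenience.theta_sharding())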