// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssa

import (
	"cmd/compile/internal/base"
	"cmd/compile/internal/types"
	"cmd/internal/src"
	"sort"
)

// memcombine combines smaller loads and stores into larger ones.
// We ensure this generates good code for encoding/binary operations.
// It may help other cases also.
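//
// For example (illustrative only), a little-endian byte-assembly pattern
// such as
//
//	uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24
//
// can be rewritten into a single 32-bit load, and the corresponding
// sequence of single-byte stores can be rewritten into a single 32-bit
// store.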
func memcombine(f *Func) {
	// This optimization requires that the architecture has
	// unaligned loads and unaligned stores.
	if !f.Config.unalignedOK {
		return
	}

	memcombineLoads(f)
	memcombineStores(f)
}

func memcombineLoads(f *Func) {
	// Find "OR trees" to start with.
	mark := f.newSparseSet(f.NumValues())
	defer f.retSparseSet(mark)
	var order []*Value

	// Mark all values that are the argument of an OR.
	for _, b := range f.Blocks {
		for _, v := range b.Values {
			if v.Op == OpOr16 || v.Op == OpOr32 || v.Op == OpOr64 {
				mark.add(v.Args[0].ID)
				mark.add(v.Args[1].ID)
			}
		}
	}
	for _, b := range f.Blocks {
		order = order[:0]
		for _, v := range b.Values {
			if v.Op != OpOr16 && v.Op != OpOr32 && v.Op != OpOr64 {
				continue
			}
			if mark.contains(v.ID) {
				// marked - means it is not the root of an OR tree
				continue
			}
			// Add the OR tree rooted at v to the order.
			// We use BFS here, but any walk that puts roots before leaves would work.
			i := len(order)
			order = append(order, v)
			for ; i < len(order); i++ {
				x := order[i]
				for j := 0; j < 2; j++ {
					a := x.Args[j]
					if a.Op == OpOr16 || a.Op == OpOr32 || a.Op == OpOr64 {
						order = append(order, a)
					}
				}
			}
		}
		for _, v := range order {
			max := f.Config.RegSize
			switch v.Op {
			case OpOr64:
			case OpOr32:
				max = 4
			case OpOr16:
				max = 2
			default:
				continue
			}
			for n := max; n > 1; n /= 2 {
				if combineLoads(v, n) {
					break
				}
			}
		}
	}
}

// A BaseAddress represents the address ptr+idx, where
// ptr is a pointer type and idx is an integer type.
// idx may be nil, in which case it is treated as 0.
type BaseAddress struct {
	ptr *Value
	idx *Value
}

// splitPtr returns the base address of ptr and any
// constant offset from that base.
// BaseAddress{ptr,nil},0 is always a valid result, but splitPtr
// tries to peel away as many constants into off as possible.
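//
// For example (sketch), splitPtr of (OffPtr [8] (AddPtr p i)) returns
// BaseAddress{ptr: p, idx: i} and an offset of 8.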
func splitPtr(ptr *Value) (BaseAddress, int64) {
	var idx *Value
	var off int64
	for {
		if ptr.Op == OpOffPtr {
			off += ptr.AuxInt
			ptr = ptr.Args[0]
		} else if ptr.Op == OpAddPtr {
			if idx != nil {
				// We have two or more indexing values.
				// Pick the first one we found.
				return BaseAddress{ptr: ptr, idx: idx}, off
			}
			idx = ptr.Args[1]
			if idx.Op == OpAdd32 || idx.Op == OpAdd64 {
				if idx.Args[0].Op == OpConst32 || idx.Args[0].Op == OpConst64 {
					off += idx.Args[0].AuxInt
					idx = idx.Args[1]
				} else if idx.Args[1].Op == OpConst32 || idx.Args[1].Op == OpConst64 {
					off += idx.Args[1].AuxInt
					idx = idx.Args[0]
				}
			}
			ptr = ptr.Args[0]
		} else {
			return BaseAddress{ptr: ptr, idx: idx}, off
		}
	}
}

func combineLoads(root *Value, n int64) bool {
	orOp := root.Op
	var shiftOp Op
	switch orOp {
	case OpOr64:
		shiftOp = OpLsh64x64
	case OpOr32:
		shiftOp = OpLsh32x64
	case OpOr16:
		shiftOp = OpLsh16x64
	default:
		return false
	}

	// Find n values that are ORed together with the above op.
	a := make([]*Value, 0, 8)
	a = append(a, root)
	for i := 0; i < len(a) && int64(len(a)) < n; i++ {
		v := a[i]
		if v.Uses != 1 && v != root {
			// Something in this subtree is used somewhere else.
			return false
		}
		if v.Op == orOp {
			a[i] = v.Args[0]
			a = append(a, v.Args[1])
			i--
		}
	}
	if int64(len(a)) != n {
		return false
	}

	// Check the first entry to see what ops we're looking for.
	// All the entries should be of the form shift(extend(load)), maybe with no shift.
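	// For example (sketch, writing p+k for an OffPtr of k bytes), a 4-byte
	// little-endian combine has leaves of the form:
	//
	//	ZeroExt8to32(Load p)                 // no shift
	//	Lsh32x64(ZeroExt8to32(Load p+1), 8)
	//	Lsh32x64(ZeroExt8to32(Load p+2), 16)
	//	Lsh32x64(ZeroExt8to32(Load p+3), 24)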
	v := a[0]
	if v.Op == shiftOp {
		v = v.Args[0]
	}
	var extOp Op
	if orOp == OpOr64 && (v.Op == OpZeroExt8to64 || v.Op == OpZeroExt16to64 || v.Op == OpZeroExt32to64) ||
		orOp == OpOr32 && (v.Op == OpZeroExt8to32 || v.Op == OpZeroExt16to32) ||
		orOp == OpOr16 && v.Op == OpZeroExt8to16 {
		extOp = v.Op
		v = v.Args[0]
	} else {
		return false
	}
	if v.Op != OpLoad {
		return false
	}
	base, _ := splitPtr(v.Args[0])
	mem := v.Args[1]
	size := v.Type.Size()

	if root.Block.Func.Config.arch == "S390X" {
		// s390x can't handle unaligned accesses to global variables.
		if base.ptr.Op == OpAddr {
			return false
		}
	}

	// Check all the entries, extract useful info.
	type LoadRecord struct {
		load   *Value
		offset int64 // offset of load address from base
		shift  int64
	}
	r := make([]LoadRecord, n, 8)
	for i := int64(0); i < n; i++ {
		v := a[i]
		if v.Uses != 1 {
			return false
		}
		shift := int64(0)
		if v.Op == shiftOp {
			if v.Args[1].Op != OpConst64 {
				return false
			}
			shift = v.Args[1].AuxInt
			v = v.Args[0]
			if v.Uses != 1 {
				return false
			}
		}
		if v.Op != extOp {
			return false
		}
		load := v.Args[0]
		if load.Op != OpLoad {
			return false
		}
		if load.Uses != 1 {
			return false
		}
		if load.Args[1] != mem {
			return false
		}
		p, off := splitPtr(load.Args[0])
		if p != base {
			return false
		}
		r[i] = LoadRecord{load: load, offset: off, shift: shift}
	}

	// Sort in memory address order.
	sort.Slice(r, func(i, j int) bool {
		return r[i].offset < r[j].offset
	})

	// Check that we have contiguous offsets.
	for i := int64(0); i < n; i++ {
		if r[i].offset != r[0].offset+i*size {
			return false
		}
	}

	// Check for reads in little-endian or big-endian order.
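	// Little-endian means the shift grows by size*8 bits per element as the
	// address increases (e.g. 0, 8, 16, 24 for four byte loads); big-endian
	// means it shrinks by size*8 bits per element.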
	shift0 := r[0].shift
	isLittleEndian := true
	for i := int64(0); i < n; i++ {
		if r[i].shift != shift0+i*size*8 {
			isLittleEndian = false
			break
		}
	}
	isBigEndian := true
	for i := int64(0); i < n; i++ {
		if r[i].shift != shift0-i*size*8 {
			isBigEndian = false
			break
		}
	}
	if !isLittleEndian && !isBigEndian {
		return false
	}

	// Find a place to put the new load.
	// This is tricky, because it has to be at a point where
	// its memory argument is live. We can't just put it in root.Block.
	// We use the block of the latest load.
	loads := make([]*Value, n, 8)
	for i := int64(0); i < n; i++ {
		loads[i] = r[i].load
	}
	loadBlock := mergePoint(root.Block, loads...)
	if loadBlock == nil {
		return false
	}
	// Find a source position to use.
	pos := src.NoXPos
	for _, load := range loads {
		if load.Block == loadBlock {
			pos = load.Pos
			break
		}
	}
	if pos == src.NoXPos {
		return false
	}

	// Check to see if we need a byte swap after loading.
	needSwap := isLittleEndian && root.Block.Func.Config.BigEndian ||
		isBigEndian && !root.Block.Func.Config.BigEndian
	if needSwap && (size != 1 || !root.Block.Func.Config.haveByteSwap(n)) {
		return false
	}

	// This is the commit point.

	// First, issue load at lowest address.
	v = loadBlock.NewValue2(pos, OpLoad, sizeType(n*size), r[0].load.Args[0], mem)

	// Byte swap if needed.
	if needSwap {
		v = byteSwap(loadBlock, pos, v)
	}

	// Extend if needed.
	if n*size < root.Type.Size() {
		v = zeroExtend(loadBlock, pos, v, n*size, root.Type.Size())
	}

	// Shift if needed.
	if isLittleEndian && shift0 != 0 {
		v = leftShift(loadBlock, pos, v, shift0)
	}
	if isBigEndian && shift0-(n-1)*size*8 != 0 {
		v = leftShift(loadBlock, pos, v, shift0-(n-1)*size*8)
	}

	// Install with (Copy v).
	root.reset(OpCopy)
	root.AddArg(v)

	// Clobber the loads, just to prevent additional work being done on
	// subtrees (which are now unreachable).
	for i := int64(0); i < n; i++ {
		clobber(r[i].load)
	}
	return true
}

func memcombineStores(f *Func) {
	mark := f.newSparseSet(f.NumValues())
	defer f.retSparseSet(mark)
	var order []*Value

	for _, b := range f.Blocks {
		// Mark all stores which are not last in a store sequence.
		mark.clear()
		for _, v := range b.Values {
			if v.Op == OpStore {
				mark.add(v.MemoryArg().ID)
			}
		}

		// pick an order for visiting stores such that
		// later stores come earlier in the ordering.
		order = order[:0]
		for _, v := range b.Values {
			if v.Op != OpStore {
				continue
			}
			if mark.contains(v.ID) {
				continue // not last in a chain of stores
			}
			for {
				order = append(order, v)
				v = v.Args[2]
				if v.Block != b || v.Op != OpStore {
					break
				}
			}
		}

		// Look for combining opportunities at each store in queue order.
		for _, v := range order {
			if v.Op != OpStore { // already rewritten
				continue
			}

			size := v.Aux.(*types.Type).Size()
			if size >= f.Config.RegSize || size == 0 {
				continue
			}

			for n := f.Config.RegSize / size; n > 1; n /= 2 {
				if combineStores(v, n) {
					break
				}
			}
		}
	}
}

// Try to combine the n stores ending in root.
// Returns true if successful.
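//
// For example (illustrative), four consecutive one-byte stores of x, x>>8,
// x>>16 and x>>24 (each truncated to a byte) can be replaced by a single
// four-byte store of x on a little-endian machine.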
func combineStores(root *Value, n int64) bool {
	// Helper functions.
	type StoreRecord struct {
		store  *Value
		offset int64
	}
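	// getShiftBase looks at the first two stored values and returns the
	// common value they are truncations (of right shifts) of, or nil if
	// there is no such value. For example (sketch), for stores of
	// Trunc32to8(x) and Trunc32to8(Rsh32Ux64(x, 8)) it returns x.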
	getShiftBase := func(a []StoreRecord) *Value {
		x := a[0].store.Args[1]
		y := a[1].store.Args[1]
		switch x.Op {
		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
			x = x.Args[0]
		default:
			return nil
		}
		switch y.Op {
		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
			y = y.Args[0]
		default:
			return nil
		}
		var x2 *Value
		switch x.Op {
		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
			x2 = x.Args[0]
		default:
		}
		var y2 *Value
		switch y.Op {
		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
			y2 = y.Args[0]
		default:
		}
		if y2 == x {
			// a shift of x and x itself.
			return x
		}
		if x2 == y {
			// a shift of y and y itself.
			return y
		}
		if x2 == y2 {
			// 2 shifts both of the same argument.
			return x2
		}
		return nil
	}
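	// isShiftBase reports whether the value stored by v is a truncation
	// of base or of a right shift of base.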
	isShiftBase := func(v, base *Value) bool {
		val := v.Args[1]
		switch val.Op {
		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
			val = val.Args[0]
		default:
			return false
		}
		if val == base {
			return true
		}
		switch val.Op {
		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
			val = val.Args[0]
		default:
			return false
		}
		return val == base
	}
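	// shift returns the constant amount (in bits) that the stored value
	// was right-shifted by before truncation, 0 if it is base itself,
	// or -1 if it doesn't match that pattern.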
	shift := func(v, base *Value) int64 {
		val := v.Args[1]
		switch val.Op {
		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
			val = val.Args[0]
		default:
			return -1
		}
		if val == base {
			return 0
		}
		switch val.Op {
		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
			val = val.Args[1]
		default:
			return -1
		}
		if val.Op != OpConst64 {
			return -1
		}
		return val.AuxInt
	}

	// Element size of the individual stores.
	size := root.Aux.(*types.Type).Size()
	if size*n > root.Block.Func.Config.RegSize {
		return false
	}

	// Gather n stores to look at. Check easy conditions we require.
	a := make([]StoreRecord, 0, 8)
	rbase, roff := splitPtr(root.Args[0])
	if root.Block.Func.Config.arch == "S390X" {
		// s390x can't handle unaligned accesses to global variables.
		if rbase.ptr.Op == OpAddr {
			return false
		}
	}
	a = append(a, StoreRecord{root, roff})
	for i, x := int64(1), root.Args[2]; i < n; i, x = i+1, x.Args[2] {
		if x.Op != OpStore {
			return false
		}
		if x.Block != root.Block {
			return false
		}
		if x.Uses != 1 { // Note: root can have more than one use.
			return false
		}
		if x.Aux.(*types.Type).Size() != size {
			return false
		}
		base, off := splitPtr(x.Args[0])
		if base != rbase {
			return false
		}
		a = append(a, StoreRecord{x, off})
	}
	// Before we sort, grab the memory arg the result should have.
	mem := a[n-1].store.Args[2]

	// Sort stores in increasing address order.
	sort.Slice(a, func(i, j int) bool {
		return a[i].offset < a[j].offset
	})

	// Check that everything is written to sequential locations.
	for i := int64(0); i < n; i++ {
		if a[i].offset != a[0].offset+i*size {
			return false
		}
	}

	// Memory location we're going to write at (the lowest one).
	ptr := a[0].store.Args[0]

	// Check for constant stores
	isConst := true
	for i := int64(0); i < n; i++ {
		switch a[i].store.Args[1].Op {
		case OpConst32, OpConst16, OpConst8:
		default:
			isConst = false
		}
	}
	if isConst {
		// Modify root to do all the stores.
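		// Assemble the combined constant, e.g. (sketch) four byte stores of
		// 0x12, 0x34, 0x56, 0x78 at increasing addresses become the single
		// 32-bit constant 0x78563412 on a little-endian target.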
		var c int64
		mask := int64(1)<<(8*size) - 1
		for i := int64(0); i < n; i++ {
			s := 8 * size * int64(i)
			if root.Block.Func.Config.BigEndian {
				s = 8*size*(n-1) - s
			}
			c |= (a[i].store.Args[1].AuxInt & mask) << s
		}
		var cv *Value
		switch size * n {
		case 2:
			cv = root.Block.Func.ConstInt16(types.Types[types.TUINT16], int16(c))
		case 4:
			cv = root.Block.Func.ConstInt32(types.Types[types.TUINT32], int32(c))
		case 8:
			cv = root.Block.Func.ConstInt64(types.Types[types.TUINT64], c)
		}

		// Move all the stores to the root.
		for i := int64(0); i < n; i++ {
			v := a[i].store
			if v == root {
				v.Aux = cv.Type // widen store type
				v.SetArg(0, ptr)
				v.SetArg(1, cv)
				v.SetArg(2, mem)
			} else {
				clobber(v)
				v.Type = types.Types[types.TBOOL] // erase memory type
			}
		}
		return true
	}

	// Check that all the shift/trunc are of the same base value.
	shiftBase := getShiftBase(a)
	if shiftBase == nil {
		return false
	}
	for i := int64(0); i < n; i++ {
		if !isShiftBase(a[i].store, shiftBase) {
			return false
		}
	}

	// Check for writes in little-endian or big-endian order.
	isLittleEndian := true
	shift0 := shift(a[0].store, shiftBase)
	for i := int64(1); i < n; i++ {
		if shift(a[i].store, shiftBase) != shift0+i*size*8 {
			isLittleEndian = false
			break
		}
	}
	isBigEndian := true
	for i := int64(1); i < n; i++ {
		if shift(a[i].store, shiftBase) != shift0-i*size*8 {
			isBigEndian = false
			break
		}
	}
	if !isLittleEndian && !isBigEndian {
		return false
	}

	// Check to see if we need byte swap before storing.
	needSwap := isLittleEndian && root.Block.Func.Config.BigEndian ||
		isBigEndian && !root.Block.Func.Config.BigEndian
	if needSwap && (size != 1 || !root.Block.Func.Config.haveByteSwap(n)) {
		return false
	}

	// This is the commit point.

	// Modify root to do all the stores.
	sv := shiftBase
	if isLittleEndian && shift0 != 0 {
		sv = rightShift(root.Block, root.Pos, sv, shift0)
	}
	if isBigEndian && shift0-(n-1)*size*8 != 0 {
		sv = rightShift(root.Block, root.Pos, sv, shift0-(n-1)*size*8)
	}
	if sv.Type.Size() > size*n {
		sv = truncate(root.Block, root.Pos, sv, sv.Type.Size(), size*n)
	}
	if needSwap {
		sv = byteSwap(root.Block, root.Pos, sv)
	}

	// Move all the stores to the root.
	for i := int64(0); i < n; i++ {
		v := a[i].store
		if v == root {
			v.Aux = sv.Type // widen store type
			v.SetArg(0, ptr)
			v.SetArg(1, sv)
			v.SetArg(2, mem)
		} else {
			clobber(v)
			v.Type = types.Types[types.TBOOL] // erase memory type
		}
	}
	return true
}

func sizeType(size int64) *types.Type {
	switch size {
	case 8:
		return types.Types[types.TUINT64]
	case 4:
		return types.Types[types.TUINT32]
	case 2:
		return types.Types[types.TUINT16]
	default:
		base.Fatalf("bad size %d\n", size)
		return nil
	}
}

func truncate(b *Block, pos src.XPos, v *Value, from, to int64) *Value {
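	// Dispatch on the (from, to) byte-size pair, encoded as from*10+to
	// (e.g. truncating 8 bytes to 2 is case 82).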
	switch from*10 + to {
	case 82:
		return b.NewValue1(pos, OpTrunc64to16, types.Types[types.TUINT16], v)
	case 84:
		return b.NewValue1(pos, OpTrunc64to32, types.Types[types.TUINT32], v)
	case 42:
		return b.NewValue1(pos, OpTrunc32to16, types.Types[types.TUINT16], v)
	default:
		base.Fatalf("bad sizes %d %d\n", from, to)
		return nil
	}
}
func zeroExtend(b *Block, pos src.XPos, v *Value, from, to int64) *Value {
	switch from*10 + to {
	case 24:
		return b.NewValue1(pos, OpZeroExt16to32, types.Types[types.TUINT32], v)
	case 28:
		return b.NewValue1(pos, OpZeroExt16to64, types.Types[types.TUINT64], v)
	case 48:
		return b.NewValue1(pos, OpZeroExt32to64, types.Types[types.TUINT64], v)
	default:
		base.Fatalf("bad sizes %d %d\n", from, to)
		return nil
	}
}

func leftShift(b *Block, pos src.XPos, v *Value, shift int64) *Value {
	s := b.Func.ConstInt64(types.Types[types.TUINT64], shift)
	size := v.Type.Size()
	switch size {
	case 8:
		return b.NewValue2(pos, OpLsh64x64, v.Type, v, s)
	case 4:
		return b.NewValue2(pos, OpLsh32x64, v.Type, v, s)
	case 2:
		return b.NewValue2(pos, OpLsh16x64, v.Type, v, s)
	default:
		base.Fatalf("bad size %d\n", size)
		return nil
	}
}
func rightShift(b *Block, pos src.XPos, v *Value, shift int64) *Value {
	s := b.Func.ConstInt64(types.Types[types.TUINT64], shift)
	size := v.Type.Size()
	switch size {
	case 8:
		return b.NewValue2(pos, OpRsh64Ux64, v.Type, v, s)
	case 4:
		return b.NewValue2(pos, OpRsh32Ux64, v.Type, v, s)
	case 2:
		return b.NewValue2(pos, OpRsh16Ux64, v.Type, v, s)
	default:
		base.Fatalf("bad size %d\n", size)
		return nil
	}
}
func byteSwap(b *Block, pos src.XPos, v *Value) *Value {
	switch v.Type.Size() {
	case 8:
		return b.NewValue1(pos, OpBswap64, v.Type, v)
	case 4:
		return b.NewValue1(pos, OpBswap32, v.Type, v)
	case 2:
		return b.NewValue1(pos, OpBswap16, v.Type, v)

	default:
		v.Fatalf("bad size %d\n", v.Type.Size())
		return nil
	}
}